jaxrts.experimental.SiiNN.NNModel.__init__
NNModel.__init__(din: int, dhid: list[int], dout: int, rngs: Rngs, no_of_atoms: int | None = None)
Creates the network as a set of fully connected flax.nnx.Linear layers.
Parameters:
- din (int) – Number of input nodes. For a typical setup, this should be 3 + no_of_atoms.
- dhid (list[int]) – List of integers containing the size of each hidden layer.
- dout (int) – Number of output nodes. Typically, this is no_of_atoms * (no_of_atoms + 1) / 2.
- rngs (nnx.Rngs) – A random number generator.
- no_of_atoms (int or None, defaults to None) – Number of species in the plasma. If this value is not given explicitly, it is calculated from din as din - 3.
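The size conventions above can be checked with a few lines of plain Python before building the model. This is a minimal sketch, not part of jaxrts: the helper names (expected_io_sizes, resolve_no_of_atoms, linear_layer_shapes) are hypothetical, and linear_layer_shapes assumes the Linear layers are chained as din → dhid[0] → … → dhid[-1] → dout, which the docstring implies but does not state outright.

```python
from __future__ import annotations  # allow `int | None` annotations on older Pythons


def expected_io_sizes(no_of_atoms: int) -> tuple[int, int]:
    """Typical (din, dout) for a plasma with `no_of_atoms` species, per the docstring."""
    din = 3 + no_of_atoms
    # One output per unordered species pair (i, j) with i <= j.
    dout = no_of_atoms * (no_of_atoms + 1) // 2
    return din, dout


def resolve_no_of_atoms(din: int, no_of_atoms: int | None = None) -> int:
    """Mirror the documented default: use din - 3 when no_of_atoms is not given."""
    return din - 3 if no_of_atoms is None else no_of_atoms


def linear_layer_shapes(din: int, dhid: list[int], dout: int) -> list[tuple[int, int]]:
    """(in_features, out_features) of each fully connected layer.

    Assumption (hypothetical): the layers are chained input -> hidden... -> output.
    """
    sizes = [din, *dhid, dout]
    return list(zip(sizes[:-1], sizes[1:]))


if __name__ == "__main__":
    din, dout = expected_io_sizes(2)
    print(din, dout)                                 # 5 3
    print(resolve_no_of_atoms(din))                  # 2
    print(linear_layer_shapes(din, [64, 64], dout))  # [(5, 64), (64, 64), (64, 3)]
```

For a two-species plasma this gives din = 5 and dout = 3, matching the three unordered pairs (1, 1), (1, 2) and (2, 2).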