jaxrts.experimental.SiiNN.NNModel.__init__

NNModel.__init__(din: int, dhid: list[int], dout: int, rngs: Rngs, no_of_atoms: int | None = None)

Creates the network as a set of fully connected flax.nnx.Linear layers.

Parameters:
  • din (int) – Number of input nodes. For a typical setup, this should be 3 + no_of_atoms.

  • dhid (list[int]) – Sizes of the hidden layers, one integer per layer.

  • dout (int) – Number of output nodes. Typically, this is no_of_atoms * (no_of_atoms + 1) / 2.

  • rngs (nnx.Rngs) – Random number generator state used to initialize the layer weights.

  • no_of_atoms (int or None, defaults to None) – Number of species in the plasma. If this value is not given explicitly, it is calculated from din as din - 3.
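The typical layer sizes listed in the parameters can be checked with a small pure-Python helper. The function names network_sizes and infer_no_of_atoms are hypothetical and not part of jaxrts; they only encode the relations din = 3 + no_of_atoms, dout = no_of_atoms * (no_of_atoms + 1) / 2 and the fallback no_of_atoms = din - 3. Note that n * (n + 1) / 2 is the number of unordered species pairs (a, b) with a ≤ b, so it is always an integer.

```python
def network_sizes(no_of_atoms: int) -> tuple[int, int]:
    """Hypothetical helper: (din, dout) for the typical setup."""
    din = 3 + no_of_atoms
    # One output per unordered species pair (a, b) with a <= b.
    dout = no_of_atoms * (no_of_atoms + 1) // 2
    return din, dout


def infer_no_of_atoms(din: int) -> int:
    # Mirrors the documented fallback when no_of_atoms is None.
    return din - 3


for n in (1, 2, 3):
    print(n, network_sizes(n))
# 1 (4, 1)
# 2 (5, 3)
# 3 (6, 6)
```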