jaxrts.experimental.SiiNN.NNModel
class jaxrts.experimental.SiiNN.NNModel(*args: Any, **kwargs: Any)
A class inheriting from nnx.Module that adds quality-of-life features and defines normalizations for the inputs of the NN. The network is created as a set of fully connected flax.nnx.Linear layers.
Parameters:
din (int) – Number of input nodes. For a typical setup, this should be 3 + no_of_atoms.
dhid (list[int]) – List of integers containing the size of each hidden layer.
dout (int) – Number of output nodes. Typically, this is no_of_ions * (no_of_ions + 1) / 2.
rngs (nnx.Rngs) – A random number generator.
no_of_atoms (int or None, defaults to None) – Number of species in the plasma. If this value is not given explicitly, it is calculated from din as din - 3.
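The size relations in the parameter list can be checked with a few lines of plain Python. This is a minimal sketch, not jaxrts code: the helpers `input_size`, `output_size` and `atoms_from_input` are hypothetical and only encode the formulas stated above (din = 3 + no_of_atoms, dout = no_of_ions * (no_of_ions + 1) / 2, and the fallback no_of_atoms = din - 3). Since n * (n + 1) is always even, integer division gives the exact value.

```python
from typing import Optional

# Hypothetical helpers encoding the documented size rules of NNModel;
# they are NOT part of the jaxrts API.


def input_size(no_of_atoms: int) -> int:
    """Typical number of input nodes: 3 + number of species."""
    return 3 + no_of_atoms


def output_size(no_of_ions: int) -> int:
    """Typical number of output nodes: no_of_ions * (no_of_ions + 1) / 2.

    n * (n + 1) is always even, so integer division is exact.
    """
    return no_of_ions * (no_of_ions + 1) // 2


def atoms_from_input(din: int, no_of_atoms: Optional[int] = None) -> int:
    """Fallback when no_of_atoms is not given explicitly: din - 3."""
    return din - 3 if no_of_atoms is None else no_of_atoms


# Example: a plasma with two species.
din = input_size(2)
dout = output_size(2)
print(din, dout, atoms_from_input(din))  # 5 3 2
```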
Methods
__init__(din, dhid, dout, rngs[, no_of_atoms]) – Creates the network as a set of fully connected flax.nnx.Linear layers.
eval(**attributes) – Sets the Module to evaluation mode.
Warning: this method is deprecated; use iter_children() instead.
Warning: this method is deprecated; use iter_modules() instead.
perturb(name, value[, variable_type]) – Extract gradients of intermediate values during training.
set_attributes(*filters[, ...]) – Sets the attributes of nested Modules including the current Module.
set_norms(T, rho, Z, k) – Set the normalization of the input layers.
sow(variable_type, name, value[, reduce_fn, ...]) – Store intermediate values during module execution for later extraction.
train(**attributes) – Sets the Module to training mode.
Attributes
Shape of the hidden layers.
Shape of the input layer.
Shape of the output layer.
Get the input layer normalizations as a dictionary.
Shape of the full model (input, hidden, and output layers).
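For a stack of fully connected layers, the input, hidden, and output shapes fit together as [din, *dhid, dout], and a flax.nnx.Linear layer mapping a inputs to b outputs holds an a × b kernel plus b biases when `use_bias=True` (the flax default). The sketch below assumes exactly this layout; `full_shape` and `count_parameters` are hypothetical helpers for illustration, not attributes or methods of NNModel.

```python
from typing import List


def full_shape(din: int, dhid: List[int], dout: int) -> List[int]:
    # Assumed layout: input layer, hidden layers in order, output layer.
    return [din, *dhid, dout]


def count_parameters(shape: List[int], use_bias: bool = True) -> int:
    # Each dense layer maps shape[i] -> shape[i + 1]: a weight matrix with
    # shape[i] * shape[i + 1] entries plus, optionally, one bias per output.
    total = 0
    for n_in, n_out in zip(shape[:-1], shape[1:]):
        total += n_in * n_out + (n_out if use_bias else 0)
    return total


shape = full_shape(5, [64, 64], 3)
print(shape)                    # [5, 64, 64, 3]
print(count_parameters(shape))  # 4739
```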
Normalization for T.
Normalization for rho.
Normalization for the ionization (an array with one entry for each component).
Normalization for k, in units of inverse ångström.
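The normalizations are set via set_norms(T, rho, Z, k), and with din = 3 + no_of_atoms they are consistent with three scalar inputs (T, rho, k) plus one ionization per species. The following is a sketch under two loud assumptions: that each raw input is divided by its normalization, and that the inputs are concatenated in the order T, rho, Z, k. `normalize_inputs` and the `norms` dictionary layout are hypothetical and do not describe how jaxrts actually assembles its input vector.

```python
from typing import Any, Dict, List, Sequence


def normalize_inputs(
    T: float,
    rho: float,
    Z: Sequence[float],
    k: float,
    norms: Dict[str, Any],
) -> List[float]:
    """Scale each raw input by its normalization (ASSUMED: plain division).

    Z holds one ionization per species, so the result has 3 + len(Z)
    entries, matching din = 3 + no_of_atoms. k and its normalization are
    taken to share the same unit (inverse angstrom), so the ratio is unitless.
    """
    z_norm = norms["Z"]
    if len(z_norm) != len(Z):
        raise ValueError("need one Z normalization per species")
    scaled_Z = [z / zn for z, zn in zip(Z, z_norm)]
    # ASSUMED ordering: T, rho, Z_1 ... Z_n, k.
    return [T / norms["T"], rho / norms["rho"], *scaled_Z, k / norms["k"]]


# Hypothetical normalizations, one dictionary entry per input.
norms = {"T": 10.0, "rho": 2.0, "Z": [4.0, 1.0], "k": 1.5}
x = normalize_inputs(T=50.0, rho=1.0, Z=[2.0, 1.0], k=3.0, norms=norms)
print(x)  # [5.0, 0.5, 0.5, 1.0, 2.0]
```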