jaxrts.experimental.SiiNN.NNModel
class jaxrts.experimental.SiiNN.NNModel(*args: Any, **kwargs: Any)
A class inheriting from nnx.Module that adds quality-of-life features and defines normalizations for the inputs of the neural network.

Creates the network as a set of fully connected flax.nnx.Linear layers.

Parameters:
- din (int) – Number of input nodes. For a typical setup, this should be 3 + no_of_atoms.
- dhid (list[int]) – List of integers containing the size of each hidden layer.
- dout (int) – Number of output nodes. Typically, this is no_of_ions * (no_of_ions + 1) / 2, i.e., the number of unordered species pairs (i, j) with i ≤ j.
- rngs (nnx.Rngs) – A random number generator.
- no_of_atoms (int or None, defaults to None) – Number of species in the plasma. If this value is not given explicitly, it is calculated from din as din - 3.
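The parameter descriptions tie the layer widths to the number of species. The helpers below are hypothetical (they are not part of jaxrts) and only restate that arithmetic. Plausibly, the 3 in 3 + no_of_atoms stands for the scalar inputs theta, rho and k_over_qk, with one ionization Z per species, mirroring the arguments of set_norms; that reading is an assumption, as is treating no_of_ions and no_of_atoms as the same count.

```python
def typical_din(no_of_atoms: int) -> int:
    # Recommended input width: 3 scalar inputs plus one entry per species
    return 3 + no_of_atoms


def typical_dout(no_of_ions: int) -> int:
    # One output per unordered species pair (i, j) with i <= j
    return no_of_ions * (no_of_ions + 1) // 2


def infer_no_of_atoms(din: int) -> int:
    # Fallback used when no_of_atoms is not passed: din - 3
    return din - 3


# Two-species plasma, assuming no_of_ions == no_of_atoms
din, dout = typical_din(2), typical_dout(2)  # 5 and 3
```

Integer division is used for dout because n * (n + 1) is always even, so the result is an exact integer suitable as a layer width.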
Methods
- __init__(din, dhid, dout, rngs[, no_of_atoms]) – Creates the network as a set of fully connected flax.nnx.Linear layers.
- eval(**attributes) – Sets the Module to evaluation mode.
- Warning: this method is deprecated; use iter_children() instead.
- Warning: this method is deprecated; use iter_modules() instead.
- perturb(name, value[, variable_type]) – Adds a zero-value variable ("perturbation") to the intermediate value.
- set_attributes(*filters[, raise_if_not_found]) – Sets the attributes of nested Modules, including the current Module.
- set_norms(theta, rho, Z, k_over_qk) – Sets the normalization of the input layer.
- sow(variable_type, name, value[, reduce_fn, ...]) – sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.
- train(**attributes) – Sets the Module to training mode.
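set_norms takes one normalization per input quantity, with Z carrying one entry per component. This page does not say how NNModel applies these values, so the following is only a sketch under two stated assumptions: each raw input is divided by its normalization, and the features are ordered (theta, rho, Z_1..Z_n, k_over_qk) as in the set_norms signature. The function normalize_inputs and the norms dictionary layout are hypothetical.

```python
from typing import Sequence


def normalize_inputs(
    theta: float,
    rho: float,
    Z: Sequence[float],
    k_over_qk: float,
    norms: dict,
) -> list[float]:
    """Hypothetical helper: scale raw inputs by per-quantity normalizations.

    Assumes plain division and the feature order
    (theta, rho, Z_1..Z_n, k_over_qk); neither is confirmed by the docs.
    """
    if len(Z) != len(norms["Z"]):
        raise ValueError("Z and its normalization need one entry per species")
    features = [theta / norms["theta"], rho / norms["rho"]]
    # One ionization entry per component, each with its own normalization
    features += [z / zn for z, zn in zip(Z, norms["Z"])]
    features.append(k_over_qk / norms["k_over_qk"])
    return features


norms = {"theta": 2.0, "rho": 4.0, "Z": [1.0, 2.0], "k_over_qk": 0.5}
x = normalize_inputs(1.0, 2.0, [1.0, 3.0], 1.0, norms)
# len(x) == 3 + number of species, matching the typical din
```

Scaling inputs to order-unity values is a common way to keep the first Linear layer well conditioned when raw quantities span very different magnitudes.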
Attributes
- Shape of the hidden layers.
- Shape of the input layer.
- Shape of the output layer.
- Input-layer normalizations as a dictionary.
- Shape of the full model (input, hidden, and output layers).
- Normalization for theta.
- Normalization for rho.
- Normalization for the ionization (an array with one entry per component).
- Normalization for k_over_qk.
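The shape attributes follow from the constructor arguments: the full model is the input width, then every hidden width, then the output width. The sketch below shows that composition and the resulting number of trainable parameters for a chain of fully connected layers. The function names are hypothetical, and the parameter count assumes every layer has a bias, which is the default for flax.nnx.Linear but is not confirmed for NNModel.

```python
def model_shape(din: int, dhid: list[int], dout: int) -> tuple[int, ...]:
    # Full model: input layer, every hidden layer, output layer
    return (din, *dhid, dout)


def count_linear_params(shape: tuple[int, ...]) -> int:
    # Each fully connected layer maps shape[i] -> shape[i + 1] with a
    # weight matrix (shape[i] x shape[i + 1]) plus a bias vector
    # (shape[i + 1]); assumes biases are enabled in every layer
    return sum(a * b + b for a, b in zip(shape[:-1], shape[1:]))


shape = model_shape(5, [16, 16], 3)  # (5, 16, 16, 3)
n_params = count_linear_params(shape)  # 96 + 272 + 51 = 419
```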