jaxrts.experimental.SiiNN.NNModel

class jaxrts.experimental.SiiNN.NNModel(*args: Any, **kwargs: Any)

A class inheriting from nnx.Module that adds quality-of-life features and defines the normalizations applied to the inputs of the NN.

Creates the network as a set of fully connected flax.nnx.Linear layers.

Parameters:
  • din (int) – Number of input nodes. For a typical setup, this should be 3 + no_of_atoms.

  • dhid (list[int]) – List of integers, containing the size of each hidden layer.

  • dout (int) – Number of output nodes. Typically, this is no_of_ions * (no_of_ions + 1) / 2, which equals the number of unordered pairs (a, b) of ion species with a ≤ b.

  • rngs (nnx.Rngs) – Random number generator used to initialize the parameters of the layers.

  • no_of_atoms (int or None, defaults to None) – Number of species in the plasma. If this value is not given explicitly, it is calculated from din as din - 3.
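The relations between din, dout, and no_of_atoms stated above can be written out as a small pure-Python sketch. The helper names infer_no_of_atoms and typical_dims are hypothetical and not part of jaxrts, and the sketch assumes that the ion count in the dout formula equals no_of_atoms.

```python
from typing import Optional


def infer_no_of_atoms(din: int, no_of_atoms: Optional[int] = None) -> int:
    """Hypothetical helper mirroring the documented default: din - 3 if not given."""
    return din - 3 if no_of_atoms is None else no_of_atoms


def typical_dims(no_of_atoms: int) -> tuple[int, int]:
    """Hypothetical helper returning the typical (din, dout) for no_of_atoms species."""
    din = 3 + no_of_atoms
    # Unordered species pairs (a, b) with a <= b: n * (n + 1) / 2.
    dout = no_of_atoms * (no_of_atoms + 1) // 2
    return din, dout


din, dout = typical_dims(2)
print(din, dout)               # 5 3
print(infer_no_of_atoms(din))  # 2
```

For a two-component plasma this gives five input nodes and three output nodes, one per species pair (1, 1), (1, 2), (2, 2).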

Methods

__init__(din, dhid, dout, rngs[, no_of_atoms])

Creates the network as a set of fully connected flax.nnx.Linear layers.

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use the function flax.nnx.iter_children() instead.

iter_modules()

Warning: this method is deprecated; use the function flax.nnx.iter_modules() instead.

perturb(name, value[, variable_type])

Add a zero-valued variable ("perturbation") to the intermediate value.

set_attributes(*filters[, raise_if_not_found])

Sets the attributes of nested Modules including the current Module.

set_norms(theta, rho, Z, k_over_qk)

Sets the normalizations applied to the inputs of the input layer.

sow(variable_type, name, value[, reduce_fn, ...])

sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

train(**attributes)

Sets the Module to training mode.

Attributes

dhid

Shape of the hidden layers.

din

Shape of the input layer.

dout

Shape of the output layer.

norms

Get the input layer normalizations as a dictionary.

shape

Shape of the full model (input, hidden, and output layers).
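As a sketch, assuming shape lists the layer sizes in order (input, hidden..., output), the full shape and the resulting number of trainable parameters can be computed as follows. Both functions are hypothetical illustrations; the parameter count assumes every layer is an nnx.Linear with its default bias.

```python
def model_shape(din: int, dhid: list[int], dout: int) -> tuple[int, ...]:
    # Hypothetical equivalent of NNModel.shape: input, hidden, and output sizes.
    return (din, *dhid, dout)


def n_parameters(shape: tuple[int, ...]) -> int:
    # Each fully connected layer n_in -> n_out has an n_in x n_out kernel
    # plus n_out biases (nnx.Linear uses a bias by default).
    return sum(n_in * n_out + n_out for n_in, n_out in zip(shape[:-1], shape[1:]))


shape = model_shape(5, [16, 16], 3)
print(shape)                # (5, 16, 16, 3)
print(n_parameters(shape))  # 419
```

The count breaks down as 5·16 + 16 = 96, 16·16 + 16 = 272, and 16·3 + 3 = 51.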

norm_theta

Normalization for theta.

norm_rho

Normalization for rho.

norm_Z

Normalization for the ionization (an array with one entry for each component).

norm_k_over_qk

Normalization for k_over_qk.