jaxrts.experimental.SiiNN.NNModel

class jaxrts.experimental.SiiNN.NNModel(*args: Any, **kwargs: Any)

A class inheriting from nnx.Module that adds quality-of-life features and defines the normalization applied to the network inputs.

Creates the network as a set of fully connected flax.nnx.Linear layers.
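The following is a minimal sketch that assumes the layers are chained in order, from din through each entry of dhid to dout. It lists the (in_features, out_features) pair that each flax.nnx.Linear layer would need. The helper linear_layer_dims is hypothetical and not part of jaxrts.

```python
from typing import List, Tuple


def linear_layer_dims(din: int, dhid: List[int], dout: int) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each fully connected layer.

    Hypothetical helper: assumes input -> dhid[0] -> ... -> dhid[-1] -> output.
    """
    sizes = [din, *dhid, dout]
    # Each pair of consecutive sizes defines one Linear layer.
    return list(zip(sizes[:-1], sizes[1:]))


# Two species: din = 3 + 2 = 5, dout = 2 * 3 / 2 = 3, two hidden layers of 64.
print(linear_layer_dims(5, [64, 64], 3))  # [(5, 64), (64, 64), (64, 3)]
```

With an empty dhid, the sketch degenerates to a single Linear layer mapping din directly to dout.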

Parameters:
  • din (int) – Number of input nodes. For a typical setup, this should be 3 + no_of_atoms.

  • dhid (list[int]) – List of integers, containing the size of each hidden layer.

  • dout (int) – Number of output nodes. Typically, this is no_of_atoms * (no_of_atoms + 1) / 2, the number of unordered pairs of species (including like pairs).

  • rngs (nnx.Rngs) – A random number generator.

  • no_of_atoms (int or None, defaults to None) – Number of species in the plasma. If this value is not given explicitly, it is calculated from din as din - 3.
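As a quick sanity check before constructing the model, the documented relationships between din, dout, and the number of species can be written out in plain Python. The helper names below are hypothetical and not part of jaxrts; they only encode the "typical setup" formulas and the din - 3 default stated in the parameter list.

```python
from typing import Optional


def expected_din(no_of_atoms: int) -> int:
    # Typical setup from the parameter list: 3 + no_of_atoms input nodes.
    return 3 + no_of_atoms


def expected_dout(no_of_atoms: int) -> int:
    # n * (n + 1) / 2 is always an integer, so integer division is exact.
    return no_of_atoms * (no_of_atoms + 1) // 2


def resolve_no_of_atoms(din: int, no_of_atoms: Optional[int] = None) -> int:
    # Mirrors the documented default: fall back to din - 3 if not given.
    return din - 3 if no_of_atoms is None else no_of_atoms


# A three-species plasma:
n = 3
print(expected_din(n), expected_dout(n))  # 6 6
print(resolve_no_of_atoms(expected_din(n)))  # 3
```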

Methods

__init__(din, dhid, dout, rngs[, no_of_atoms])

Creates the network as a set of fully connected flax.nnx.Linear layers.

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use the module-level function nnx.iter_children() instead.

iter_modules()

Warning: this method is deprecated; use the module-level function nnx.iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

set_attributes(*filters[, ...])

Sets the attributes of nested Modules including the current Module.

set_norms(T, rho, Z, k)

Set the normalizations applied to the network inputs.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.

Attributes

dhid

Shape of the hidden layers.

din

Shape of the input layer.

dout

Shape of the output layer.

norms

Get the input layer normalizations as a dictionary.

shape

Shape of the full model (input, hidden, and output layers).
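Assuming shape is the sequence (din, *dhid, dout) and that each flax.nnx.Linear layer keeps its default bias (use_bias=True), the number of trainable parameters follows directly from shape. The sketch below illustrates that arithmetic; count_parameters is a hypothetical helper, not a jaxrts method.

```python
from typing import Sequence


def count_parameters(shape: Sequence[int]) -> int:
    """Count weights and biases of a chain of fully connected layers.

    Assumes every layer has a bias vector, as flax.nnx.Linear does by default.
    """
    total = 0
    for n_in, n_out in zip(shape[:-1], shape[1:]):
        total += n_in * n_out  # weight matrix (kernel)
        total += n_out         # bias vector
    return total


# din = 5, two hidden layers of 64 nodes, dout = 3:
print(count_parameters((5, 64, 64, 3)))  # 4739
```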

norm_T

Normalization for T.

norm_rho

Normalization for rho.

norm_Z

Normalization for the ionization Z (an array with one entry for each component).

norm_k

Normalization for k, in units of inverse ångström (1/Å).
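The norm_* attributes and the signature set_norms(T, rho, Z, k) suggest that the network sees T, rho, one ionization value per species, and k, which would account for din = 3 + no_of_atoms. The sketch below assumes, without confirmation from this page, that each input is divided by its norm and that the inputs are concatenated in the order of the set_norms arguments. normalize_inputs is a hypothetical helper, not part of jaxrts; the only unit stated here is that k is in inverse ångström.

```python
from typing import List, Sequence


def normalize_inputs(
    T: float,
    rho: float,
    Z: Sequence[float],
    k: float,
    norm_T: float,
    norm_rho: float,
    norm_Z: Sequence[float],
    norm_k: float,
) -> List[float]:
    """Hypothetical normalization: divide each input by its norm.

    Assumed order: [T, rho, Z_1, ..., Z_n, k], i.e. 3 + n entries.
    k and norm_k are both expected in inverse ångström (1/Å).
    """
    if len(Z) != len(norm_Z):
        raise ValueError("Z and norm_Z need one entry per species")
    scaled_Z = [z / nz for z, nz in zip(Z, norm_Z)]
    return [T / norm_T, rho / norm_rho, *scaled_Z, k / norm_k]


# A wavenumber given in 1/m must be converted first: 1 / Å = 1e10 / m.
k_in_per_m = 2.0e10
k_in_per_angstrom = k_in_per_m / 1e10  # 2.0

x = normalize_inputs(
    T=10.0, rho=4.0, Z=[1.0, 3.0], k=k_in_per_angstrom,
    norm_T=10.0, norm_rho=2.0, norm_Z=[1.0, 6.0], norm_k=4.0,
)
print(x)  # [1.0, 2.0, 1.0, 0.5, 0.5]
```

Keeping Z and norm_Z as per-species sequences mirrors the description of norm_Z as an array with one entry for each component.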