jaxrts.experimental.SiiNN

Interpolate static structure factors with a neural network. This file defines the NN layout and also creates a jaxrts.Model that makes it easy to use a trained NN with jaxrts.

Models have to be trained separately for each sample type. Tools for doing so are provided in the tools/SiiInterpolation/ directory of the jaxrts repository. The trained network is saved as an orbax checkpoint (with slight additions that store properties of the network architecture).
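The exact form of the "slight additions" to the orbax checkpoint is not described here. As a hypothetical illustration of the idea, architecture properties can be written as a small JSON file next to the stored weights and read back before the network is rebuilt. The file name architecture.json and the helpers save_architecture and load_architecture are assumptions for this sketch, not the format jaxrts actually uses.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sidecar file name; the real jaxrts checkpoint layout may differ.
ARCH_FILE = "architecture.json"


def save_architecture(checkpoint_dir, **properties) -> Path:
    """Store architecture properties (e.g. layer sizes) as JSON next to the weights."""
    path = Path(checkpoint_dir) / ARCH_FILE
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(properties, indent=2, sort_keys=True))
    return path


def load_architecture(checkpoint_dir) -> dict:
    """Read the properties back so the same network layout can be rebuilt."""
    path = Path(checkpoint_dir) / ARCH_FILE
    return json.loads(path.read_text())


if __name__ == "__main__":
    # Demo in a throwaway directory standing in for a checkpoint directory.
    with tempfile.TemporaryDirectory() as tmp:
        save_architecture(tmp, hidden_sizes=[64, 64], activation="tanh")
        print(load_architecture(tmp))
```

Keeping the layout information beside the weights means a loader only needs the checkpoint directory to reconstruct the network before restoring its parameters.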

Functions

set_sharding(x)

Classes

NNModel(*args, **kwargs)

A class inheriting from nnx.Module that adds quality-of-life features and defines the normalization of the NN input.
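The normalization scheme itself is not spelled out here. A minimal sketch, assuming a per-feature min-max rescaling onto [0, 1] with bounds taken from the training data (NNModel may well use a different scheme, such as standardization), is shown below. MinMaxNormalizer and from_data are hypothetical names, and plain Python is used instead of nnx to keep the sketch dependency-free.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence


@dataclass
class MinMaxNormalizer:
    """Hypothetical helper (not part of jaxrts): per-feature min-max scaling.

    Assumption: each input feature is rescaled linearly so that the
    training range [lower, upper] maps onto [0, 1].
    """

    lower: Sequence[float]
    upper: Sequence[float]

    def __post_init__(self) -> None:
        # A feature with zero spread cannot be rescaled.
        for lo, hi in zip(self.lower, self.upper):
            if hi <= lo:
                raise ValueError("each upper bound must exceed its lower bound")

    @classmethod
    def from_data(cls, samples: Sequence[Sequence[float]]) -> MinMaxNormalizer:
        # Transpose rows of training samples into per-feature columns.
        columns = list(zip(*samples))
        return cls(
            lower=[min(col) for col in columns],
            upper=[max(col) for col in columns],
        )

    def __call__(self, x: Sequence[float]) -> list[float]:
        # Forward map: physical units -> normalized NN input.
        return [
            (xi - lo) / (hi - lo)
            for xi, lo, hi in zip(x, self.lower, self.upper)
        ]

    def inverse(self, y: Sequence[float]) -> list[float]:
        # Backward map: normalized values -> physical units.
        return [
            yi * (hi - lo) + lo
            for yi, lo, hi in zip(y, self.lower, self.upper)
        ]
```

For example, `MinMaxNormalizer.from_data([[0.0, 10.0], [2.0, 30.0]])` maps the input `[1.0, 20.0]` to `[0.5, 0.5]`. Storing the bounds with the trained model ensures that inference sees inputs scaled exactly as during training.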

NNSiiModel(checkpoint_dir)

A jaxrts.model.IonFeatModel that uses a neural network to obtain ion-ion static structure factors.