Creating your own model
In jaxrts, we strongly encourage community-driven collaboration. Users are welcome to create, implement, or contribute new models to the codebase, targeting specific components of the simulation. This section provides a brief overview of how to implement a new model within the framework.
To showcase this, we’ll demonstrate how to create a simple, hypothetical model for Ionization Potential Depression (IPD) that always returns \(\pi\), measured in electron volts, for any plasma conditions.
First, create a low-level function that calculates the respective quantity from a set of parameters:
import jax.numpy as jnp
from .units import Quantity, ureg
...
def ipd_pi(
    Zi: float,
    ne: Quantity,
    ni: Quantity,
    Te: Quantity,
    Ti: Quantity,
    Zbar: float | None = None,
) -> Quantity:
    return jnp.pi * ureg.electron_volt
Next, create the corresponding jaxrts.models.Model class. Use the existing implementations of the same model type as a reference for structure and style:
import jax
import jax.numpy as jnp
import jaxrts
from jaxrts import ipd
from jaxrts.setup import Setup
from jaxrts.units import Quantity, ureg
...
class PiIPD(jaxrts.models.Model):
    """
    Hypothetical IPD model in which the IPD is always pi, measured in
    electron volts.

    See Also
    --------
    jaxrts.ipd.ipd_pi
        Function used to calculate the IPD
    """

    # The allowed model keys for a plasma state.
    allowed_keys = ["ipd"]
    __name__ = "PiIPD"
    # Citation keys for reference
    cite_keys = ["JohnDoe.2025"]

    def __init__(self):
        super().__init__()

    # This function is required by every model.
    @jax.jit
    def evaluate(self, plasma_state: "PlasmaState", setup: Setup) -> Quantity:
        return ipd.ipd_pi(
            plasma_state.Z_free,
            plasma_state.n_e,
            plasma_state.n_i,
            plasma_state.T_e,
            plasma_state.T_i,
            Zbar=plasma_state.Z_free,
        )

    # This is important for IPD models only.
    @jax.jit
    def all_element_states(
        self, plasma_state: "PlasmaState", ion_population=None
    ) -> list[jnp.ndarray]:
        out = []
        # One array per element, with one IPD value per charge state Z.
        for element in plasma_state.ions:
            out.append(
                jnp.array(
                    [
                        # ipd_pi returns a scalar, so no per-species
                        # indexing of the result is needed here.
                        ipd.ipd_pi(
                            Z,
                            plasma_state.n_e,
                            plasma_state.n_i,
                            plasma_state.T_e,
                            plasma_state.T_i,
                            Zbar=plasma_state.Z_free,
                        ).m_as(ureg.electron_volt)
                        for Z in jnp.arange(element.Z)
                    ]
                )
                * ureg.electron_volt
            )
        return out

    # The following is required to jit a Model.
    # Here, 'children' are attributes of the class that can be traced
    # using jax, e.g. plasma states, floats etc., while aux_data are
    # static arguments.
    def _tree_flatten(self):
        children = ()
        aux_data = (self.model_key,)  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
        (obj.model_key,) = aux_data
        return obj
Finally, register your new model as a JAX pytree node by calling:
jax.tree_util.register_pytree_node(
    PiIPD,
    PiIPD._tree_flatten,
    PiIPD._tree_unflatten,
)
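To see what the registration does, the following minimal sketch applies the same flatten/unflatten pattern to a toy class and round-trips it through JAX. ToyModel, its scale attribute, and the "toy" key are hypothetical stand-ins chosen for illustration; they are not part of jaxrts, and a real model would inherit from jaxrts.models.Model instead.

```python
import jax
import jax.numpy as jnp


class ToyModel:
    """Hypothetical stand-in for a model; not part of jaxrts."""

    def __init__(self, scale, model_key):
        self.scale = scale          # traced value, stored as a pytree child
        self.model_key = model_key  # static value, stored in aux_data

    @jax.jit
    def evaluate(self, x):
        # `self` is passed to the jitted function as a pytree, which only
        # works because ToyModel is registered below.
        return self.scale * x

    def _tree_flatten(self):
        children = (self.scale,)
        aux_data = (self.model_key,)  # trailing comma makes this a 1-tuple
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # Bypass __init__, mirroring the pattern used for the model above.
        obj = object.__new__(cls)
        (obj.scale,) = children
        (obj.model_key,) = aux_data
        return obj


jax.tree_util.register_pytree_node(
    ToyModel,
    ToyModel._tree_flatten,
    ToyModel._tree_unflatten,
)

model = ToyModel(scale=2.0, model_key="toy")

# Flattening separates the traced leaves from the static tree definition.
leaves, treedef = jax.tree_util.tree_flatten(model)
restored = jax.tree_util.tree_unflatten(treedef, leaves)

# The jitted method now accepts the instance as its first argument.
result = model.evaluate(jnp.array([1.0, 2.0]))
```

Because aux_data is treated as static, it has to be hashable, and changing it triggers a recompilation of jitted functions; values that vary between calls are better kept among the children.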
Congratulations, you’ve successfully created your own model!
Note
A jaxrts.models.Model provides the jaxrts.models.Model.check() and jaxrts.models.Model.prepare() methods. Use the former to raise errors, e.g., if the model is only applicable to one-component systems. Use the latter to modify the passed jaxrts.plasmastate.PlasmaState, e.g., to set sane defaults for other, subsequent models via jaxrts.plasmastate.PlasmaState.update_default_model().
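The division of labour between the two hooks can be sketched without jaxrts itself. Everything below is a hypothetical stand-in: StandInPlasmaState and StandInModel only mimic the roles of jaxrts.plasmastate.PlasmaState and jaxrts.models.Model, and the method signatures, the key argument, and the placeholder model value are assumptions made for this illustration. Consult the existing models in the codebase for the actual signatures.

```python
class StandInPlasmaState:
    """Hypothetical stand-in for jaxrts.plasmastate.PlasmaState."""

    def __init__(self, ions):
        self.ions = ions
        self.models = {}

    def update_default_model(self, model_name, model):
        # Assumed behaviour: only set a model if the user has not chosen one.
        self.models.setdefault(model_name, model)


class StandInModel:
    """Hypothetical stand-in for jaxrts.models.Model."""

    def check(self, plasma_state):
        pass

    def prepare(self, plasma_state, key):
        pass


class OneComponentOnly(StandInModel):
    def check(self, plasma_state):
        # Reject plasma states this model cannot describe.
        if len(plasma_state.ions) != 1:
            raise ValueError(
                "This model only supports one-component systems."
            )

    def prepare(self, plasma_state, key):
        # Provide a default for a model that this one relies on.
        plasma_state.update_default_model("ipd", "placeholder IPD model")


single = StandInPlasmaState(ions=["C"])
model = OneComponentOnly()
model.check(single)            # passes silently
model.prepare(single, "demo")  # fills in the default

try:
    model.check(StandInPlasmaState(ions=["C", "H"]))
    rejected = False
except ValueError:
    rejected = True
```

Keeping validation in check() and side effects in prepare() means a model can fail early with a clear message before any plasma state is modified.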