Creating your own model

In jaxrts, we strongly encourage community-driven collaboration. Users are welcome to create, implement, or contribute new models to the codebase, targeting specific components of the simulation. This section provides a brief overview of how to implement a new model within the framework.

To showcase this, we’ll demonstrate how to create a simple, hypothetical model for Ionization Potential Depression (IPD) that always returns \(\pi\), measured in electron volts, for any plasma conditions.

First, create a low-level function that calculates the respective quantity from a set of parameters:

import jax.numpy as jnp
from .units import Quantity, ureg
...


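def ipd_pi(
   Zi: float,
   ne: Quantity,
   ni: Quantity,
   Te: Quantity,
   Ti: Quantity,
   Zbar: float | None = None,
) -> Quantity:
   # The plasma parameters are ignored on purpose: this toy model
   # returns pi eV regardless of the conditions.
   return jnp.pi * ureg.electron_volt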

Next, create the corresponding jaxrts.models.Model class. Use the existing implementations for the same model type as a reference for structure and style.

import jax
import jax.numpy as jnp
import jaxrts
from jaxrts import ipd
from jaxrts.setup import Setup
from jaxrts.units import Quantity, ureg
...

class PiIPD(jaxrts.models.Model):
   """
   Hypothetical IPD Model, in which the IPD is always Pi, measured in electron volts.

   See Also
   --------
   jaxrts.ipd.ipd_pi
      Function used to calculate the IPD
   """

   # The allowed model keys for a plasma state.
   allowed_keys = ["ipd"]
   __name__ = "PiIPD"

   # Citation keys for reference
   cite_keys = ["JohnDoe.2025"]

   def __init__(self):
      super().__init__()

   # This function is required by every model.
   @jax.jit
   def evaluate(self, plasma_state: "PlasmaState", setup: Setup) -> Quantity:
      return ipd.ipd_pi(
            plasma_state.Z_free,
            plasma_state.n_e,
            plasma_state.n_i,
            plasma_state.T_e,
            plasma_state.T_i,
            Zbar=plasma_state.Z_free
      )

   # This is important for IPD models only.
   @jax.jit
   def all_element_states(
      self, plasma_state: "PlasmaState", ion_population=None
   ) -> list[jnp.ndarray]:
      out = []
      for element in plasma_state.ions:
         out.append(
            jnp.array(
               [
                  # ipd_pi returns a scalar, so no per-ion indexing is needed.
                  ipd.ipd_pi(
                     Z,
                     plasma_state.n_e,
                     plasma_state.n_i,
                     plasma_state.T_e,
                     plasma_state.T_i,
                     Zbar=plasma_state.Z_free,
                  ).m_as(ureg.electron_volt)
                  for Z in jnp.arange(element.Z)
               ]
            )
            * ureg.electron_volt
         )
      return out

   # The following is required to jit a Model.
   # Here, 'children' are attributes of the class that can be traced
   # by jax, e.g. plasma states or floats, while aux_data holds static arguments.
   def _tree_flatten(self):
      children = ()
      # Static values; the trailing comma makes this a one-element tuple.
      aux_data = (self.model_key,)
      return (children, aux_data)

   @classmethod
   def _tree_unflatten(cls, aux_data, children):
      # Rebuild the object without calling __init__.
      obj = object.__new__(cls)
      (obj.model_key,) = aux_data
      return obj

Finally, register your new model as a pytree node by calling:

jax.tree_util.register_pytree_node(
  PiIPD,
  PiIPD._tree_flatten,
  PiIPD._tree_unflatten,
)
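
The flatten/unflatten pair has to round-trip: whatever `_tree_flatten` returns must be enough for `_tree_unflatten` to rebuild an equivalent object. The following is a minimal sketch of that contract in plain Python, without JAX or jaxrts. `ToyModel` is a hypothetical stand-in for a model class, and setting `model_key` in `__init__` is an assumption made only to keep the example self-contained.

```python
class ToyModel:
    """Hypothetical stand-in for a jaxrts.models.Model subclass."""

    def __init__(self):
        # Assumption for this sketch: the key is hard-coded here.
        self.model_key = "ipd"

    def _tree_flatten(self):
        children = ()  # traced values (this toy model has none)
        aux_data = (self.model_key,)  # static values, as a one-element tuple
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)  # bypasses __init__
        (obj.model_key,) = aux_data
        return obj


# Flatten an instance and rebuild it from the flattened pieces.
original = ToyModel()
children, aux_data = original._tree_flatten()
restored = ToyModel._tree_unflatten(aux_data, children)
```

Note the trailing comma in `(self.model_key,)`: in Python, `(x)` is just `x`, while `(x,)` is a one-element tuple. Keeping `aux_data` a tuple makes it straightforward to add further static attributes later, as long as both methods pack and unpack them in the same order.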

Congratulations, you’ve successfully created your own model!

Note

A jaxrts.models.Model provides the jaxrts.models.Model.check() and jaxrts.models.Model.prepare() methods. The former should raise an error when the model is not applicable, e.g., if it is only valid for one-component systems. The latter can modify the passed jaxrts.plasmastate.PlasmaState, e.g., to set sane defaults for other, subsequent models using jaxrts.plasmastate.PlasmaState.update_default_model().
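
As a rough illustration of how the two hooks divide the work, here is a sketch that uses stub classes instead of the real jaxrts objects. Everything in it is hypothetical: `StubPlasmaState`, `OneComponentOnlyModel`, the `"other-model"` key, the method signatures `check(plasma_state)` and `prepare(plasma_state, key)`, and the "only set if missing" behavior of `update_default_model` are assumptions for this example. Consult the jaxrts API reference for the actual signatures.

```python
class StubPlasmaState:
    """Hypothetical, minimal stand-in for jaxrts.plasmastate.PlasmaState."""

    def __init__(self, ions):
        self.ions = ions
        self.models = {}

    def update_default_model(self, key, model):
        # Assumed behavior: keep an existing model, otherwise set the default.
        self.models.setdefault(key, model)


class OneComponentOnlyModel:
    """Hypothetical model that is only valid for a single ion species."""

    def check(self, plasma_state):
        # Validation only: refuse plasma states the model cannot describe.
        if len(plasma_state.ions) != 1:
            raise ValueError("This model supports one-component systems only.")

    def prepare(self, plasma_state, key):
        # Side effects only: provide a fallback that later models rely on.
        plasma_state.update_default_model("other-model", "placeholder default")


model = OneComponentOnlyModel()

# A one-component state passes the check and receives the default.
state = StubPlasmaState(ions=["C"])
model.check(state)
model.prepare(state, "ipd")

# A two-component state is rejected by check().
try:
    model.check(StubPlasmaState(ions=["C", "H"]))
    rejected = False
except ValueError:
    rejected = True
```

Keeping validation in `check()` and state modification in `prepare()` means a failed check leaves the plasma state untouched.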