jaxrts.saving.load

jaxrts.saving.load(fp, unit_reg, additional_mappings=None, *args, **kwargs)

Load an object from a file. Uses json.load() under the hood and forwards args and kwargs to it.

Parameters:
  • fp – An open, readable file object containing the saved JSON data.

  • unit_reg – The pint unit registry to use for loading.

  • additional_mappings (Optional) – Additional models to consider when loading, given as a dictionary that maps each model's name to its class. This is only relevant when custom models were saved.

Returns:

The deserialized object.

Examples

>>> import jaxrts
>>> with open("state.json", "r") as f:
...     state = jaxrts.saving.load(f, unit_reg=jaxrts.ureg)

Custom models must be passed to load() through additional_mappings, keyed by their name, as shown below.

>>> import jax.numpy as jnp
>>> import jaxrts
>>> from jaxrts import saving
>>> class AlwaysPiModel(jaxrts.models.Model):
...     allowed_keys = ["test"]
...     __name__ = "AlwaysPiModel"
...     def evaluate(self, plasma_state, setup) -> jnp.ndarray:
...         return jnp.array([jnp.pi])
>>> with open("file", "r") as f:
...     loaded_state = saving.load(
...         f,
...         jaxrts.ureg,
...         additional_mappings={"AlwaysPiModel": AlwaysPiModel},
...     )
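To illustrate why additional_mappings is needed, the following is a minimal, self-contained sketch of one way a json.load()-based loader can rebuild model objects from a name-to-class mapping. It is not the jaxrts implementation. The Model and BuiltinModel classes, the DEFAULT_MAPPINGS registry, the "_type" key convention and the sketch_load function are all hypothetical, and unit handling via unit_reg is left out. The sketch forwards only keyword arguments, because json.load() accepts its extra options (object_hook, parse_float, ...) as keyword-only parameters.

```python
import io
import json


class Model:
    """Hypothetical stand-in for a model base class."""

    def __init__(self, **params):
        self.params = params


class BuiltinModel(Model):
    """Hypothetical model that the loader knows about by default."""


# Hypothetical registry of models known to the loader without extra help.
DEFAULT_MAPPINGS = {"BuiltinModel": BuiltinModel}


def sketch_load(fp, additional_mappings=None, **kwargs):
    """Decode JSON from fp, rebuilding tagged dicts as model instances.

    Extra keyword arguments are forwarded to json.load().
    """
    # Custom models extend (and may override) the default registry.
    mappings = dict(DEFAULT_MAPPINGS)
    if additional_mappings is not None:
        mappings.update(additional_mappings)

    def hook(obj):
        # Hypothetical convention: a "_type" key names the class to rebuild.
        name = obj.get("_type")
        if name is None:
            return obj
        if name not in mappings:
            raise KeyError(f"Unknown model {name!r}; pass it via additional_mappings")
        params = {k: v for k, v in obj.items() if k != "_type"}
        return mappings[name](**params)

    return json.load(fp, object_hook=hook, **kwargs)


# Usage: a custom model is only found because it is passed in explicitly.
class AlwaysPiModel(Model):
    pass


text = '{"a": {"_type": "BuiltinModel", "x": 1}, "b": {"_type": "AlwaysPiModel"}}'
state = sketch_load(io.StringIO(text), additional_mappings={"AlwaysPiModel": AlwaysPiModel})
print(type(state["b"]).__name__)  # AlwaysPiModel
print(state["a"].params)  # {'x': 1}
```

Without the additional_mappings argument, the same input raises a KeyError for "AlwaysPiModel", which mirrors the situation the documentation describes for custom models.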