jaxrts.saving.load
- jaxrts.saving.load(fp, unit_reg, additional_mappings=None, *args, **kwargs)
Load an object from a file. Uses json.load() under the hood and forwards any additional args and kwargs to it.
- Parameters:
fp – The file object to load from, opened in read mode.
unit_reg – The pint unit registry to use for loading.
additional_mappings (Optional) – A dict mapping model names to additional model classes to consider when loading. Only needed if the saved object contains custom models.
- Returns:
The deserialized object.
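A minimal sketch of the argument forwarding described above, using only the standard library. `load_sketch` is a hypothetical stand-in, not the jaxrts implementation: it omits the unit registry and model lookup entirely. It shows that a keyword option such as `parse_float` reaches json.load() unchanged, and that json.load() itself accepts only keyword arguments after fp.

```python
import io
import json
from decimal import Decimal


def load_sketch(fp, *args, **kwargs):
    """Hypothetical loader that simply delegates to json.load.

    Not the jaxrts implementation: unit handling and model lookup are
    omitted. Extra arguments are passed through to json.load unchanged.
    """
    return json.load(fp, *args, **kwargs)


# An in-memory text stream stands in for a file opened in read mode.
fp = io.StringIO('{"density": 1.5, "steps": 3}')

# parse_float is a standard json.load keyword; it is forwarded as-is.
data = load_sketch(fp, parse_float=Decimal)
print(data)  # {'density': Decimal('1.5'), 'steps': 3}

# json.load takes only keyword arguments after fp, so extra positional
# arguments are rejected with a TypeError.
try:
    load_sketch(io.StringIO("{}"), Decimal)
except TypeError as err:
    print("rejected:", err)
```

In practice this means options intended for json.load() are best passed as keyword arguments.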
Examples
>>> import jaxrts
>>> from jaxrts.saving import load
>>> with open("state.json", "r") as f:
...     state = load(f, unit_reg=jaxrts.ureg)
Custom models must be passed to load() via additional_mappings, as shown below.
>>> import jax.numpy as jnp
>>> import jaxrts
>>> from jaxrts import saving
>>> class AlwaysPiModel(jaxrts.models.Model):
...     allowed_keys = ["test"]
...     __name__ = "AlwaysPiModel"
...
...     def evaluate(self, plasma_state, setup) -> jnp.ndarray:
...         return jnp.array([jnp.pi])
>>> with open("file", "r") as f:
...     loaded_state = saving.load(
...         f,
...         jaxrts.ureg,
...         additional_mappings={"AlwaysPiModel": AlwaysPiModel},
...     )
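To illustrate why custom models must be supplied, here is a hypothetical sketch of name-based dispatch: a loader that merges additional_mappings over a built-in registry and rebuilds objects tagged with a model name. The `_model` tag, `BUILTIN_MODELS`, `ConstantModel`, and `load_models` are all invented for illustration; the actual jaxrts file format and lookup logic may differ.

```python
import io
import json


class Model:
    """Hypothetical base class standing in for jaxrts.models.Model."""


class ConstantModel(Model):
    """Hypothetical model that the loader knows about out of the box."""


class AlwaysPiModel(Model):
    """Hypothetical custom model the loader does not know by default."""


# Hypothetical registry of built-in models, keyed by saved name.
BUILTIN_MODELS = {"ConstantModel": ConstantModel}


def load_models(fp, additional_mappings=None, **kwargs):
    """Rebuild model instances from JSON objects tagged with "_model".

    The "_model" tag is an assumed file layout, not the jaxrts format.
    Entries in additional_mappings extend (or override) the built-ins.
    """
    mappings = dict(BUILTIN_MODELS)
    mappings.update(additional_mappings or {})

    def hook(obj):
        name = obj.get("_model")
        if name is None:
            return obj  # ordinary JSON object: leave it unchanged
        if name not in mappings:
            raise KeyError(f"Unknown model {name!r}; pass it via additional_mappings")
        return mappings[name]()

    # object_hook is applied to every decoded JSON object, innermost first.
    return json.load(fp, object_hook=hook, **kwargs)


text = '{"models": {"ionization": {"_model": "AlwaysPiModel"}}}'

# Without the mapping, the custom model name cannot be resolved.
try:
    load_models(io.StringIO(text))
except KeyError as err:
    print("without mapping:", err)

# With the mapping, the tagged entry is rebuilt as an AlwaysPiModel.
state = load_models(
    io.StringIO(text),
    additional_mappings={"AlwaysPiModel": AlwaysPiModel},
)
print(type(state["models"]["ionization"]).__name__)  # AlwaysPiModel
```

Keeping the registry as a plain name-to-class dict means user code can extend it per call without modifying any global state.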