jaxrts.experimental.SiiNN.NNModelExpandedZ.sow
- NNModelExpandedZ.sow(variable_type: type[~flax.nnx.variablelib.Variable[~flax.nnx.module.B]] | str, name: str, value: ~flax.nnx.module.A, reduce_fn: ~typing.Callable[[~flax.nnx.module.B, ~flax.nnx.module.A], ~flax.nnx.module.B] = <function <lambda>>, init_fn: ~typing.Callable[[], ~flax.nnx.module.B] = <function <lambda>>) → bool
Store intermediate values during module execution for later extraction.
Used with the nnx.capture() decorator to collect intermediate values without explicitly passing containers through module calls. Values are stored under the specified name in a collection associated with variable_type. By default, values are appended to a tuple, allowing multiple values to be tracked when the same module is called multiple times.
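The default accumulation can be imitated without Flax. The following is a minimal pure-Python sketch, assuming that the first sow under a given name stores reduce_fn(init_fn(), value) and every later call stores reduce_fn(previous, value), with the defaults described under Args (an empty tuple and append-to-tuple). The ToyModule class and its store dict are hypothetical stand-ins for a module's collections, not part of the Flax or jaxrts API.

```python
from typing import Any, Callable, Dict, Tuple


class ToyModule:
    """Hypothetical stand-in that records sown values per (collection, name)."""

    def __init__(self) -> None:
        # Hypothetical storage: maps (collection, name) to the accumulated value.
        self.store: Dict[Tuple[str, str], Any] = {}

    def sow(
        self,
        collection: str,
        name: str,
        value: Any,
        reduce_fn: Callable[[Any, Any], Any] = lambda prev, curr: prev + (curr,),
        init_fn: Callable[[], Any] = lambda: (),
    ) -> None:
        key = (collection, name)
        # The first call under this key starts from init_fn(); later calls
        # combine the previously stored value with the new one.
        prev = self.store[key] if key in self.store else init_fn()
        self.store[key] = reduce_fn(prev, value)

    def __call__(self, x: float) -> float:
        y = 2.0 * x
        self.sow("intermediates", "features", y)
        return y


model = ToyModule()
model(1.0)  # sows 2.0
model(3.0)  # sows 6.0
print(model.store[("intermediates", "features")])  # (2.0, 6.0)
```

Calling the toy module twice leaves both intermediate values in the tuple, which is the behavior the default reduce_fn and init_fn are described as producing.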
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'features', x)
...     x = self.linear2(x)
...     return x
>>> # With the capture decorator, the wrapped call also returns the intermediates
>>> model = Model(rngs=nnx.Rngs(0))
>>> @nnx.capture(nnx.Intermediate)
... def forward(model, x):
...   return model(x)
>>> result, intermediates = forward(model, jnp.ones(2))
>>> assert 'features' in intermediates
Custom init/reduce functions can be passed to control accumulation:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     return x
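Sowing under the same name repeatedly amounts to a left fold of the sown values that starts from init_fn(). The sketch below checks this for the 'sum' pair with plain Python, assuming, as the Args section states, that init_fn supplies the value for the first reduce_fn call. The sow_sequence helper is hypothetical, and the integer inputs stand in for the nnx.Linear outputs purely for illustration.

```python
from functools import reduce
from typing import Any, Callable, Iterable


def sow_sequence(values: Iterable[Any],
                 reduce_fn: Callable[[Any, Any], Any],
                 init_fn: Callable[[], Any]) -> Any:
    """Hypothetical helper: the value stored after sowing each item in order."""
    stored = None  # nothing is stored until the first sow
    first = True
    for v in values:
        # The first call starts from init_fn(); later calls reuse the stored value.
        stored = reduce_fn(init_fn() if first else stored, v)
        first = False
    return stored


values = [1, 2, 3]
init_fn = lambda: 0
reduce_fn = lambda prev, curr: prev + curr

stepwise = sow_sequence(values, reduce_fn, init_fn)
folded = reduce(reduce_fn, values, init_fn())
print(stepwise, folded)  # 6 6
```

Viewed this way, reduce_fn only needs to be a valid step of a left fold; it does not have to be associative or commutative, because values are combined in call order.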
- Args:
  - variable_type: The Variable type for the stored value. Typically Intermediate or a subclass is used.
  - name: A string key for storing the value in the collection.
  - value: The value to be stored.
  - reduce_fn: Function that combines the previously stored value with the new value, called as reduce_fn(prev, value). By default it appends the value to a tuple.
  - init_fn: Function providing the initial value for the first reduce_fn call. By default it returns an empty tuple.
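In the signature, reduce_fn has type Callable[[B, A], B] and init_fn has type Callable[[], B], so the accumulated type B need not match the type A of each sown value. As an illustrative assumption (not taken from the jaxrts documentation), the pair below keeps a (count, total) tuple from which a running mean can be read afterwards; it is exercised with a plain loop rather than a real module.

```python
from typing import Callable, Tuple

# B: (number of values seen, running total); A: a single float value.
Acc = Tuple[int, float]

# Hypothetical init/reduce pair with an accumulator type different from the value type.
init_fn: Callable[[], Acc] = lambda: (0, 0.0)
reduce_fn: Callable[[Acc, float], Acc] = lambda prev, curr: (prev[0] + 1, prev[1] + curr)

acc = init_fn()
for value in [2.0, 4.0, 9.0]:  # each value plays the role of one sow call
    acc = reduce_fn(acc, value)

count, total = acc
mean = total / count
print(acc, mean)  # (3, 15.0) 5.0
```

Inside a module, the same pair would be passed to self.sow as init_fn=init_fn, reduce_fn=reduce_fn; for array values, the addition would operate elementwise on the arrays instead of on Python floats.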