jaxrts.experimental.SiiNN.NNModelExpandedZ.sow

NNModelExpandedZ.sow(variable_type: type[~flax.nnx.variablelib.Variable[~flax.nnx.module.B]] | str, name: str, value: ~flax.nnx.module.A, reduce_fn: ~typing.Callable[[~flax.nnx.module.B, ~flax.nnx.module.A], ~flax.nnx.module.B] = <function <lambda>>, init_fn: ~typing.Callable[[], ~flax.nnx.module.B] = <function <lambda>>) bool

Store intermediate values during module execution for later extraction.

Use it together with the nnx.capture() decorator to collect intermediate values without explicitly passing containers through module calls. Values are stored under the specified name in the collection associated with variable_type.

By default, values are appended to a tuple, allowing multiple values to be tracked when the same module is called multiple times.
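
As a rough mental model of the default behavior, the plain-Python sketch below folds each sown value into a store keyed by (variable type, name). It assumes the defaults are equivalent to init_fn=lambda: () and reduce_fn=lambda prev, curr: (*prev, curr). The Intermediate class and the sow function here are hypothetical stand-ins, not Flax code.

```python
# Hypothetical, framework-free model of sow's default accumulation.
# Assumption: defaults behave like init_fn=lambda: () and
# reduce_fn=lambda prev, curr: (*prev, curr). This is NOT the Flax implementation.
from typing import Any, Callable, Dict, Tuple


class Intermediate:
    """Hypothetical stand-in for nnx.Intermediate."""


def sow(
    store: Dict[Tuple[type, str], Any],
    variable_type: type,
    name: str,
    value: Any,
    reduce_fn: Callable[[Any, Any], Any] = lambda prev, curr: (*prev, curr),
    init_fn: Callable[[], Any] = lambda: (),
) -> None:
    key = (variable_type, name)
    # The first sow under this key starts from init_fn(); later sows reuse the stored value.
    prev = store[key] if key in store else init_fn()
    store[key] = reduce_fn(prev, value)


store: Dict[Tuple[type, str], Any] = {}
for step in range(3):  # e.g. the same module being called three times
    sow(store, Intermediate, "features", step * 10)

print(store[(Intermediate, "features")])  # (0, 10, 20)
```

Each call appends to the tuple, so three calls leave three tracked values under the same name.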

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'features', x)
...     x = self.linear2(x)
...     return x

>>> # nnx.capture makes forward also return the values sown under nnx.Intermediate
>>> model = Model(rngs=nnx.Rngs(0))
>>> @nnx.capture(nnx.Intermediate)
... def forward(model, x):
...   return model(x)
>>> result, intermediates = forward(model, jnp.ones(2))
>>> assert 'features' in intermediates

Pass custom init_fn and reduce_fn functions to control how values are accumulated. Here, repeated calls keep a running sum instead of a tuple:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     return x
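
The custom example above defines the accumulation but does not run it. The following framework-free sketch shows what the two callbacks compute over two calls. The accumulate helper and the float outputs 1.5 and 2.0 are hypothetical stand-ins for the module and its linear-layer outputs.

```python
# Hypothetical simulation (not Flax code) of the custom accumulation:
# init_fn=lambda: 0 and reduce_fn=lambda prev, curr: prev + curr.
from typing import Any, Callable, Dict


def accumulate(
    store: Dict[str, Any],
    name: str,
    value: Any,
    reduce_fn: Callable[[Any, Any], Any],
    init_fn: Callable[[], Any],
) -> None:
    # Start from init_fn() the first time `name` is sown, then fold in each new value.
    prev = store[name] if name in store else init_fn()
    store[name] = reduce_fn(prev, value)


store: Dict[str, Any] = {}
for output in (1.5, 2.0):  # stand-ins for the layer's outputs on two calls
    accumulate(store, "sum", output,
               reduce_fn=lambda prev, curr: prev + curr,
               init_fn=lambda: 0)

print(store["sum"])  # 3.5
```

Because reduce_fn replaces the default tuple-append, the stored value stays a single running total no matter how many times the module runs.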

Args:
  variable_type: The Variable type for the stored value. Typically
    Intermediate or a subclass is used.
  name: A string key for storing the value in the collection.
  value: The value to be stored.
  reduce_fn: Function that combines the existing value with the new one.
    The default appends the new value to a tuple.
  init_fn: Function providing the initial value for the first reduce_fn
    call. The default returns an empty tuple.
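
To make the init_fn contract concrete, this hypothetical pure-Python sketch (not Flax code) counts how often init_fn runs while a running maximum is accumulated. The sow helper, max_init, and max_reduce are illustrative names that only model the documented rule that init_fn supplies the value for the first reduce_fn call.

```python
# Hypothetical sketch (not Flax code) of the documented init_fn contract:
# init_fn supplies the value for the first reduce_fn call only.
from typing import Any, Callable, Dict, List

init_calls: List[int] = []


def max_init() -> float:
    init_calls.append(1)  # record every time the initial value is requested
    return float("-inf")


def max_reduce(prev: float, curr: float) -> float:
    return max(prev, curr)  # keep a running maximum instead of a tuple


def sow(store: Dict[str, Any], name: str, value: Any,
        reduce_fn: Callable[[Any, Any], Any], init_fn: Callable[[], Any]) -> None:
    # init_fn is consulted only when nothing is stored under `name` yet.
    prev = store[name] if name in store else init_fn()
    store[name] = reduce_fn(prev, value)


store: Dict[str, Any] = {}
for v in (3.0, 7.0, 5.0):
    sow(store, "peak", v, max_reduce, max_init)

print(store["peak"], len(init_calls))  # 7.0 1
```

Choosing an identity element such as -inf for max (or 0 for a sum) as the initial value means the first reduce_fn call needs no special case.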