jaxrts.experimental.SiiNN.NNModel.sow

NNModel.sow(variable_type: type[~flax.nnx.variablelib.Variable[~flax.nnx.module.B]] | str, name: str, value: ~flax.nnx.module.A, reduce_fn: ~typing.Callable[[~flax.nnx.module.B, ~flax.nnx.module.A], ~flax.nnx.module.B] = <function <lambda>>, init_fn: ~typing.Callable[[], ~flax.nnx.module.B] = <function <lambda>>) -> bool

sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. sow() stores a value in a new Module attribute, denoted by name. The value will be wrapped by a Variable of type variable_type, which can be useful to filter for in split(), state() and pop().

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i) == 1 # tuple of length 1
>>> assert model.i[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i) == 2 # tuple of length 2
>>> assert (model.i[0] + 1 == model.i[1]).all()

Alternatively, a custom init/reduce function can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum[...] == model.product[...]).all()
>>> intermediate = model.sum[...]

>>> y = model(x)
>>> assert (model.sum[...] == intermediate*2).all()
>>> assert (model.product[...] == intermediate**2).all()

Args:

variable_type: The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

name: A string denoting the Module attribute name where the sowed value is stored.

value: The value to be stored.

reduce_fn: The function used to combine the existing value with the new value. The default is to append the value to a tuple.

init_fn: For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.
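
The interplay of init_fn and reduce_fn described above can be illustrated with a small pure-Python sketch. It is not the Flax implementation: the class SowSketch and its store dictionary are hypothetical stand-ins that only mimic the accumulation rule (the first call combines init_fn() with the value, later calls combine the stored result with the new value). Variable wrapping, variable_type and the bool return value are deliberately left out.

```python
from typing import Any, Callable


class SowSketch:
    """Hypothetical stand-in that mimics only the accumulation rule of sow()."""

    def __init__(self) -> None:
        # Maps each name to its accumulated value (assumption: a plain dict,
        # where the real method would use a Module attribute).
        self.store: dict[str, Any] = {}

    def sow(
        self,
        name: str,
        value: Any,
        reduce_fn: Callable[[Any, Any], Any] = lambda prev, curr: prev + (curr,),
        init_fn: Callable[[], Any] = lambda: (),
    ) -> None:
        # On the first call for this name, start from init_fn();
        # afterwards, start from whatever was stored previously.
        prev = self.store[name] if name in self.store else init_fn()
        self.store[name] = reduce_fn(prev, value)


collector = SowSketch()

# Default behaviour: values are appended to a tuple.
collector.sow("i", 1.0)
collector.sow("i", 2.0)

# Custom behaviour: a running sum that starts at 0.
collector.sow("sum", 3.0, init_fn=lambda: 0, reduce_fn=lambda prev, curr: prev + curr)
collector.sow("sum", 4.0, init_fn=lambda: 0, reduce_fn=lambda prev, curr: prev + curr)
```

In this sketch, init_fn is only consulted the first time a given name is sown. On later calls it is ignored and reduce_fn receives the previously stored result instead.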