jaxrts.experimental.SiiNN.NNModel.sow
NNModel.sow(variable_type: type[flax.nnx.variablelib.Variable[flax.nnx.module.B]] | str, name: str, value: flax.nnx.module.A, reduce_fn: typing.Callable[[flax.nnx.module.B, flax.nnx.module.A], flax.nnx.module.B] = <function <lambda>>, init_fn: typing.Callable[[], flax.nnx.module.B] = <function <lambda>>) → bool
sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. sow() stores a value in a new Module attribute, denoted by name. The value is wrapped in a Variable of type variable_type, which makes it possible to select it with filters in split(), state() and pop().
By default, the values are stored in a tuple and each stored value is appended at the end. This way, all intermediates can be tracked when the same module is called multiple times.
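The accumulation rule described above can be sketched without Flax. The class TinyModule below is hypothetical and only mimics the documented bookkeeping: the result is stored under name, the first call starts from init_fn(), and later calls combine the stored result with the new value through reduce_fn. It skips the Variable wrapping entirely and is not Flax's implementation.

```python
# Hypothetical, framework-free sketch of the bookkeeping described for sow().
# It is NOT Flax's implementation and does not wrap values in a Variable.
from typing import Any, Callable


class TinyModule:
    """Stand-in for a Module that records intermediates as attributes."""

    def sow(
        self,
        name: str,
        value: Any,
        # Documented defaults: append to a tuple, starting from an empty tuple.
        reduce_fn: Callable[[Any, Any], Any] = lambda prev, curr: (*prev, curr),
        init_fn: Callable[[], Any] = lambda: (),
    ) -> None:
        if hasattr(self, name):
            # Later calls: combine the stored result with the new value.
            current = reduce_fn(getattr(self, name), value)
        else:
            # First call: combine init_fn() with the new value.
            current = reduce_fn(init_fn(), value)
        setattr(self, name, current)


m = TinyModule()
assert not hasattr(m, "i")
m.sow("i", 1.0)
m.sow("i", 2.0)
print(m.i)  # (1.0, 2.0)
```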
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i) == 1 # tuple of length 1
>>> assert model.i[0].shape == (1, 3)
>>> y = model(x, add=1)
>>> assert len(model.i) == 2 # tuple of length 2
>>> assert (model.i[0] + 1 == model.i[1]).all()
Alternatively, a custom init/reduce function can be passed:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
>>> assert (model.sum[...] == model.product[...]).all()
>>> intermediate = model.sum[...]
>>> y = model(x)
>>> assert (model.sum[...] == intermediate*2).all()
>>> assert (model.product[...] == intermediate**2).all()
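The assertions in the custom init/reduce example above can be checked by hand. Write x for the output of linear1; it is identical on both calls because the input and the parameters do not change, and every operation is element-wise:

```latex
% First call: reduce_fn(init_fn(), x)
\mathrm{sum}^{(1)} = 0 + x = x, \qquad \mathrm{product}^{(1)} = 1 \cdot x = x
% Second call: reduce_fn(previous value, x)
\mathrm{sum}^{(2)} = x + x = 2x, \qquad \mathrm{product}^{(2)} = x \cdot x = x^{2}
```

With intermediate equal to sum after the first call (that is, x), the checks sum == intermediate*2 and product == intermediate**2 after the second call follow directly.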
Args:
- variable_type: The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.
- name: A string denoting the Module attribute name under which the sown value is stored.
- value: The value to be stored.
- reduce_fn: The function used to combine the existing value with the new value. The default is to append the value to a tuple.
- init_fn: For the first value stored, reduce_fn is passed the result of init_fn together with the value to be stored. The default returns an empty tuple.
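Because reduce_fn always receives the previously stored result, repeated sow() calls under one name behave like a left fold that starts from init_fn(). The helper fold_sown below is hypothetical; it replays that rule on plain Python numbers to compare the documented default with two custom init_fn/reduce_fn pairs.

```python
# Hypothetical helper: replays a sequence of sown values with a given
# (init_fn, reduce_fn) pair, following the documented rule that the first
# value is combined with init_fn() and each later value with the result so far.
import functools
from typing import Any, Callable, Iterable


def fold_sown(
    values: Iterable[Any],
    init_fn: Callable[[], Any],
    reduce_fn: Callable[[Any, Any], Any],
) -> Any:
    return functools.reduce(reduce_fn, values, init_fn())


calls = [3.0, 1.0, 4.0]

# Documented default: append each value to a tuple, starting from ().
default = fold_sown(calls, lambda: (), lambda prev, curr: (*prev, curr))

# Keep only the most recent value (the result of init_fn is discarded).
latest = fold_sown(calls, lambda: None, lambda prev, curr: curr)

# Running maximum across calls, starting from -inf.
running_max = fold_sown(calls, lambda: float("-inf"), max)

print(default, latest, running_max)  # (3.0, 1.0, 4.0) 4.0 4.0
```

With real intermediates, which are JAX arrays, an element-wise reducer such as jnp.maximum would take the place of Python's built-in max, since max needs a single truth value per comparison.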