jaxrts.experimental.SiiNN.NNModel.perturb

NNModel.perturb(name: str, value: Any, variable_type: str | type[flax.nnx.variablelib.Variable[Any]] = flax.nnx.variablelib.Perturbation)

Add a zero-valued variable (a "perturbation") to the intermediate value.

The gradient with respect to `value` is identical to the gradient with respect to this perturbation variable. Therefore, if you define your loss function with both the params and the perturbations as separate arguments, you can obtain the intermediate gradient of `value` by running jax.grad with respect to the perturbation variable.
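The mechanism can be checked in plain JAX without Flax. The sketch below assumes a hypothetical two-layer linear model with weight arrays `w1` and `w2` (these names and shapes are illustrative only). Because the perturbation `eps` is zero and enters as `h + eps`, the derivative of the loss with respect to `eps` equals the derivative with respect to `h` itself.

```python
import jax
import jax.numpy as jnp


def loss_with_perturbation(w1, w2, eps, x, y):
    # Intermediate value h with a zero-valued perturbation added to it.
    h = x @ w1 + eps
    pred = h @ w2
    return jnp.mean((pred - y) ** 2)


def loss_from_intermediate(h, w2, y):
    # The same loss, written directly as a function of the intermediate value.
    pred = h @ w2
    return jnp.mean((pred - y) ** 2)


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
w1 = jax.random.normal(key1, (2, 3))
w2 = jax.random.normal(key2, (3, 4))
x = jnp.ones((1, 2))
y = jnp.ones((1, 4))

# Gradient with respect to the zero perturbation (argument index 2).
eps = jnp.zeros((1, 3))
grad_eps = jax.grad(loss_with_perturbation, argnums=2)(w1, w2, eps, x, y)

# Gradient with respect to the intermediate value, computed directly.
h = x @ w1
grad_h = jax.grad(loss_from_intermediate)(h, w2, y)
```

Both gradients have the shape of the intermediate value, (1, 3), and agree elementwise.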

Since the shape of the perturbation value depends on the shape of the input, a perturbation variable is only created after you run a sample input through the model once.
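The lazy-creation pattern can be sketched in plain JAX as follows. This is an illustration of the idea, not Flax's implementation: the dictionary `perturbations` and the function `forward` are hypothetical stand-ins for the Module attribute that perturb creates. The zero array is allocated only when the intermediate value, and therefore its shape, first becomes available.

```python
import jax.numpy as jnp


def forward(w1, w2, x, perturbations):
    h = x @ w1
    # The perturbation's shape is only known once h has been computed,
    # so it is created lazily on the first forward pass.
    if "xgrad" not in perturbations:
        perturbations["xgrad"] = jnp.zeros_like(h)
    h = h + perturbations["xgrad"]
    return h @ w2


w1 = jnp.ones((2, 3))
w2 = jnp.ones((3, 4))
perturbations = {}

created_before_call = "xgrad" in perturbations  # nothing allocated yet
out = forward(w1, w2, jnp.ones((5, 2)), perturbations)
```

After the first call, the stored perturbation has shape (5, 3): it tracks the batch dimension of the sample input as well as the width of the intermediate layer.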

Note

This creates extra dummy variables of the same size as value and therefore uses more memory. Use it only to debug gradients during training.
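As a rough sizing aid, the extra memory for each perturbation equals the size of the intermediate value it shadows. The sketch below assumes a hypothetical float32 activation of shape (1024, 4096); the numbers are illustrative, not measured from any particular model.

```python
import jax.numpy as jnp

# Hypothetical intermediate activation: batch of 1024, 4096 float32 features.
h = jnp.ones((1024, 4096), dtype=jnp.float32)

# A perturbation has the same shape and dtype as the intermediate value.
perturbation = jnp.zeros_like(h)

extra_bytes = perturbation.nbytes  # 1024 * 4096 * 4 bytes
extra_mib = extra_bytes / 2**20
```

Here each perturbation adds 16 MiB, on top of the gradient buffer of the same size that is produced for it when differentiating.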

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.shape == (1, 3)   # same as the intermediate value
>>> graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=(0, 1))
... def grad_loss(params, perturbations, inputs, targets):
...   model = nnx.merge(graphdef, params, perturbations)
...   return jnp.mean((model(inputs) - targets) ** 2)

>>> grads, perturbation_grads = grad_loss(params, perturbations, x, y)
>>> # `perturbation_grads.xgrad[...]` is the intermediate gradient
>>> assert not jnp.array_equal(perturbation_grads.xgrad[...], jnp.zeros((1, 3)))

Args:
  name: A string denoting the Module attribute name for the perturbation value.
  value: The value whose intermediate gradient is to be taken.
  variable_type: The Variable type for the stored perturbation. Defaults to nnx.Perturbation.