jaxrts.experimental.SiiNN.NNModel.perturb
- NNModel.perturb(name: str, value: ~typing.Any, variable_type: str | type[~flax.nnx.variablelib.Variable[~typing.Any]] = <class 'flax.nnx.variablelib.Perturbation'>)
Add a zero-value variable (“perturbation”) to the intermediate value.
The gradient of `value` is the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of `value` by running `jax.grad` on the perturbation variable.

Since the shape of the perturbation value depends on the shape of the input, a perturbation variable is only created after you run a sample input through the model once.
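The claim that the perturbation's gradient equals the gradient of the intermediate value can be checked in plain JAX, without any Flax machinery. The sketch below is only an illustration under stated assumptions: the function names `loss_with_perturbation` and `loss_from_hidden`, the weights `w1` and `w2`, and the shapes are hypothetical and are not part of the `perturb` API.

```python
import jax
import jax.numpy as jnp


def loss_with_perturbation(w1, w2, p, x, y):
    """Hypothetical two-layer loss with a perturbation added to the hidden value."""
    h = x @ w1   # intermediate value
    h = h + p    # zero-valued perturbation, same shape as h
    return jnp.mean((h @ w2 - y) ** 2)


def loss_from_hidden(h, w2, y):
    """The same loss, written as a direct function of the intermediate value."""
    return jnp.mean((h @ w2 - y) ** 2)


x = jnp.ones((1, 2))
y = jnp.ones((1, 4))
w1 = jnp.full((2, 3), 0.1)
w2 = jnp.full((3, 4), 0.2)
p = jnp.zeros((1, 3))  # adding zeros leaves the forward pass unchanged

# Gradient with respect to the perturbation (argument index 2).
grad_p = jax.grad(loss_with_perturbation, argnums=2)(w1, w2, p, x, y)

# Gradient with respect to the intermediate value itself.
grad_h = jax.grad(loss_from_hidden)(x @ w1, w2, y)

assert grad_p.shape == (1, 3)
assert jnp.allclose(grad_p, grad_h)
```

Because the perturbation enters the computation only through the sum `h + p`, the chain rule gives it the same gradient as `h`, which is why differentiating with respect to a zero-valued perturbation exposes the intermediate gradient.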
Note
This creates extra dummy variables of the same size as `value` and therefore uses more memory. Use it only to debug gradients in training.

Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.shape == (1, 3)  # same as the intermediate value

>>> graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=(0, 1))
... def grad_loss(params, perturbations, inputs, targets):
...   model = nnx.merge(graphdef, params, perturbations)
...   return jnp.mean((model(inputs) - targets) ** 2)

>>> (grads, perturbations) = grad_loss(params, perturbations, x, y)
>>> # `perturbations.xgrad[...]` is the intermediate gradient
>>> assert not jnp.array_equal(perturbations.xgrad[...], jnp.zeros((1, 3)))
- Args:
- name: A string denoting the `Module` attribute name for the perturbation value.
- value: The value whose intermediate gradient is taken.
- variable_type: The `Variable` type for the stored perturbation. Defaults to `nnx.Perturbation`.