jaxrts.experimental.SiiNN.NNModelExpandedZ.perturb

NNModelExpandedZ.perturb(name: str, value: Any, variable_type: str | type[nnx.Variable[Any]] = nnx.Perturbation)

Extract gradients of intermediate values during training.

Used with nnx.capture() to record intermediate values in the forward pass and their gradients in the backward pass. It returns value plus the perturbation stored under name in the current capture context, so the gradient with respect to that intermediate can be obtained by differentiating with respect to the perturbation via nnx.grad.
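
The underlying idea can be sketched in plain JAX, without Flax. This is an illustration under stated assumptions, not the Flax implementation: the function name model_with_perturbation is hypothetical, and the perturbation is passed as an explicit argument instead of being looked up in a capture context. Because the perturbation is all zeros, the forward value is unchanged, and the gradient of the loss with respect to the perturbation equals the gradient with respect to the intermediate value.

```python
import jax
import jax.numpy as jnp


def model_with_perturbation(x, perturbation):
    # Hypothetical model: h is the intermediate whose gradient we want.
    h = jnp.sin(x)
    # Mirrors the idea behind perturb(): add a zero-valued perturbation
    # and keep using the sum, so gradients flow through the perturbation.
    h = h + perturbation
    return jnp.sum(3.0 * h ** 2)


x = jnp.array([0.5, 1.0, 2.0])
# Same shape as the intermediate (hence the extra memory), all zeros,
# so the forward pass computes exactly the same value.
perturbation = jnp.zeros_like(x)

# Gradient w.r.t. the perturbation, evaluated at perturbation = 0,
# equals the gradient of the loss w.r.t. the intermediate h.
grad_h = jax.grad(model_with_perturbation, argnums=1)(x, perturbation)

# Analytic check: d/dh sum(3 h^2) = 6 h, with h = sin(x).
expected = 6.0 * jnp.sin(x)
```

The extra perturbation array is why the method costs memory proportional to the size of value.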

The workflow has four steps:

1. Initialize the perturbations with nnx.capture(model, nnx.Perturbation).
2. Run the model with nnx.capture(model, nnx.Intermediate, init=perturbations).
3. Take gradients with respect to the perturbations using nnx.grad.
4. Combine the results with nnx.merge_state(perturb_grads, intermediates).

Note

This creates extra variables of the same size as value and therefore uses more memory. Use it only to debug gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __call__(self, x):
...     x2 = self.perturb('grad_of_x', x)
...     return 3 * x2

>>> model = Model()
>>> x = 1.0

>>> # Step 1: Initialize perturbations
>>> forward = nnx.capture(model, nnx.Perturbation)
>>> _, perturbations = forward(x)

>>> # Steps 2-4: Capture gradients
>>> def train_step(model, perturbations, x):
...   def loss(model, perturbations, x):
...     return nnx.capture(model, nnx.Intermediate, init=perturbations)(x)
...   (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x)
...   return nnx.merge_state(perturb_grads, sowed)

>>> metrics = train_step(model, perturbations, x)
>>> # metrics contains gradients of intermediate values

Args:

name: A string key for storing the perturbation value.

value: The intermediate value to capture gradients for. You must use the returned value (not the original) for gradient capturing to work.

variable_type: The Variable type for the stored perturbation. Default is nnx.Perturbation.