jaxrts.experimental.SiiNN.NNModel.perturb
- NNModel.perturb(name: str, value: typing.Any, variable_type: str | type[flax.nnx.variablelib.Variable[typing.Any]] = flax.nnx.variablelib.Perturbation)
Extract gradients of intermediate values during training.
Used with nnx.capture() to record intermediate values in the forward pass and their gradients in the backward pass. Returns value plus whatever perturbation is stored under name in the current capture context, so that gradients with respect to that perturbation can be computed with nnx.grad. The workflow has four steps:

1. Initialize the perturbations with nnx.capture(model, nnx.Perturbation).
2. Run the model with nnx.capture(model, nnx.Intermediate, init=perturbations).
3. Take gradients with respect to the perturbations using nnx.grad.
4. Combine the results with nnx.merge_state(perturb_grads, intermediates).

Note
This creates extra variables of the same size as value, and therefore uses more memory. Use it only to debug gradients during training.

Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __call__(self, x):
...     x2 = self.perturb('grad_of_x', x)
...     return 3 * x2
>>> model = Model()
>>> x = 1.0
>>> # Step 1: Initialize perturbations
>>> forward = nnx.capture(model, nnx.Perturbation)
>>> _, perturbations = forward(x)
>>> # Steps 2-4: Capture gradients
>>> def train_step(model, perturbations, x):
...   def loss(model, perturbations, x):
...     return nnx.capture(model, nnx.Intermediate, init=perturbations)(x)
...   (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x)
...   return nnx.merge_state(perturb_grads, sowed)
>>> metrics = train_step(model, perturbations, x)
>>> # metrics contains gradients of intermediate values
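The idea behind perturb can be illustrated without Flax. The following is a minimal sketch in plain JAX, not the Flax or jaxrts implementation. It assumes, as the description of the return value suggests, that the stored perturbation is a zero-valued array with the same shape as value that is added to it. Because the loss depends on the perturbation only through value + perturbation, the gradient with respect to the perturbation, evaluated at zero, equals the gradient with respect to the intermediate value. The function name model and the sin-based computation are hypothetical.

```python
import jax
import jax.numpy as jnp


def model(x, perturbation):
    # Intermediate value whose gradient we want (hypothetical computation).
    h = jnp.sin(x)
    # Mimic perturb: add the perturbation and use the *returned* value downstream.
    h = h + perturbation
    return jnp.sum(h ** 2)


x = jnp.array([0.0, 0.5, 1.0])
# Zero perturbation with the same shape as the intermediate value;
# this extra array is the memory cost mentioned in the Note.
perturbation = jnp.zeros_like(x)

# Gradient w.r.t. the perturbation at zero equals dLoss/dh.
grad_of_h = jax.grad(model, argnums=1)(x, perturbation)

# Analytic check: dLoss/dh = 2 * h with h = sin(x).
expected = 2.0 * jnp.sin(x)
print(bool(jnp.allclose(grad_of_h, expected)))  # True
```

In the Flax workflow, nnx.capture and nnx.grad take care of creating, threading, and differentiating this perturbation for you; the sketch only shows why differentiating with respect to it recovers the intermediate gradient.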
- Args:
  - name: A string key for storing the perturbation value.
  - value: The intermediate value to capture gradients for. You must use the returned value (not the original) for gradient capturing to work.
  - variable_type: The Variable type for the stored perturbation. Default is nnx.Perturbation.
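Why the returned value must be used can be seen in a small plain-JAX sketch. It assumes, as above, that the perturbation is a zero-valued array added to value; the two loss functions are hypothetical. If the downstream computation keeps using the original value, the perturbation never reaches the loss and its gradient is identically zero.

```python
import jax
import jax.numpy as jnp


def loss_using_returned(x, perturbation):
    h = jnp.sin(x)
    h_perturbed = h + perturbation  # value returned by a perturb-like call
    return jnp.sum(h_perturbed ** 2)  # correct: downstream uses the returned value


def loss_using_original(x, perturbation):
    h = jnp.sin(x)
    h_perturbed = h + perturbation  # returned value is computed but ignored
    return jnp.sum(h ** 2)  # mistake: downstream uses the original value


x = jnp.array([0.0, 0.5, 1.0])
perturbation = jnp.zeros_like(x)

good = jax.grad(loss_using_returned, argnums=1)(x, perturbation)
bad = jax.grad(loss_using_original, argnums=1)(x, perturbation)

print(bool(jnp.allclose(good, 2.0 * jnp.sin(x))))  # True
print(bool(jnp.all(bad == 0.0)))  # True: the perturbation never reaches the loss
```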