jaxrts.experimental.SiiNN.NNModel.set_attributes

NNModel.set_attributes(*filters: filterlib.Filter, raise_if_not_found: bool = True, **attributes: tp.Any) -> None

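Sets the attributes of nested Modules, including the current Module. A selected Module that does not have a given attribute is skipped for that attribute. By default, a ValueError is raised only if an attribute is not found in any of the selected Modules (see raise_if_not_found).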

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to restrict which Modules have their attributes set:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)

Args:

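*filters: Filters that select the Modules whose attributes are set. If no filters are given, all nested Modules, including the current one, are selected.

raise_if_not_found: If True (default), raises a ValueError if an attribute name is not present in any of the selected Modules. If False, such attributes are silently ignored.

**attributes: The attributes to set, given as name=value keyword arguments.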
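The selection and error semantics above can be modeled in a few lines of plain Python. This is a simplified, hypothetical sketch for illustration only: `Module`, `Dropout`, `BatchNorm`, `Block`, `iter_modules`, `to_predicate` and the free function `set_attributes` are stand-ins, not flax or jaxrts APIs. Filters here are limited to types and plain callables (flax filters are richer), the sketch assumes a Module matching any one of several filters is selected, and the error message text is made up.

```python
from typing import Any, Callable, Iterator, Union


# Hypothetical stand-ins for nested Modules; these are NOT flax or jaxrts classes.
class Module:
    pass


class Dropout(Module):
    def __init__(self, rate: float, deterministic: bool = False):
        self.rate = rate
        self.deterministic = deterministic


class BatchNorm(Module):
    def __init__(self, use_running_average: bool = False):
        self.use_running_average = use_running_average


class Block(Module):
    def __init__(self):
        self.dropout = Dropout(0.5)
        self.batch_norm = BatchNorm()


# In this sketch a filter is either a type or a predicate on a Module.
Filter = Union[type, Callable[[Module], bool]]


def iter_modules(module: Module) -> Iterator[Module]:
    """Yield `module` itself, then every Module stored in its attributes (depth-first)."""
    yield module
    for value in vars(module).values():
        if isinstance(value, Module):
            yield from iter_modules(value)


def to_predicate(f: Filter) -> Callable[[Module], bool]:
    # A type selects instances of that type; any other callable is used as-is.
    if isinstance(f, type):
        return lambda m: isinstance(m, f)
    return f


def set_attributes(root: Module, *filters: Filter,
                   raise_if_not_found: bool = True, **attributes: Any) -> None:
    # With no filters, every Module (including `root`) is selected.
    predicates = [to_predicate(f) for f in filters] or [lambda m: True]
    not_found = set(attributes)
    for module in iter_modules(root):
        # Assumption of this sketch: a Module matching ANY filter is selected.
        if not any(p(module) for p in predicates):
            continue
        for name, value in attributes.items():
            # Attributes the Module does not have are skipped for that Module.
            if hasattr(module, name):
                setattr(module, name, value)
                not_found.discard(name)
    # Only attributes never found in any selected Module trigger the error.
    if not_found and raise_if_not_found:
        raise ValueError(f"attributes not found in any selected Module: {sorted(not_found)}")


block = Block()
set_attributes(block, deterministic=True, use_running_average=True)
print(block.dropout.deterministic, block.batch_norm.use_running_average)  # True True

block = Block()
set_attributes(block, Dropout, deterministic=True)  # only the Dropout is selected
print(block.dropout.deterministic, block.batch_norm.use_running_average)  # True False

try:
    set_attributes(block, Dropout, use_running_average=True)
except ValueError as err:
    print(err)  # attributes not found in any selected Module: ['use_running_average']

# With raise_if_not_found=False the same call is a silent no-op.
set_attributes(block, Dropout, use_running_average=True, raise_if_not_found=False)
```

Tracking the not-found names across the whole traversal, rather than per Module, is what lets a call such as `set_attributes(deterministic=True, use_running_average=True)` succeed even though neither layer has both attributes.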