jaxrts.experimental.SiiNN.NNModel.eval

NNModel.eval(**attributes)

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set attributes deterministic=True and use_running_average=True of all nested Modules that have these attributes. Its primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Args:

**attributes: additional attributes passed to set_attributes.