jaxrts.experimental.SiiNN.NNModel.eval
- NNModel.eval(**attributes)
Sets the Module to evaluation mode.
eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.
Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
- Args:
**attributes: additional attributes passed to set_attributes.
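The recursive behavior described above, including how extra **attributes travel alongside the two evaluation-mode defaults, can be illustrated with a toy model in plain Python. This is a minimal sketch and not the Flax or jaxrts implementation: ToyModule, ToyDropout, ToyBatchNorm, ToyBlock, iter_modules and the verbose attribute are all hypothetical names introduced only for illustration, and the sketch assumes that attributes are set only on modules that already define them.

```python
class ToyModule:
    """Hypothetical stand-in for a module tree; not the Flax NNX class."""

    def iter_modules(self):
        # Yield this module, then every ToyModule nested in its attributes.
        yield self
        for value in vars(self).values():
            if isinstance(value, ToyModule):
                yield from value.iter_modules()

    def set_attributes(self, **attributes):
        # Assumption: only attributes a module already defines are updated.
        for module in self.iter_modules():
            for name, value in attributes.items():
                if hasattr(module, name):
                    setattr(module, name, value)

    def eval(self, **attributes):
        # Evaluation-mode defaults plus any caller-supplied extras.
        self.set_attributes(
            deterministic=True, use_running_average=True, **attributes
        )


class ToyDropout(ToyModule):
    def __init__(self):
        self.deterministic = False


class ToyBatchNorm(ToyModule):
    def __init__(self):
        self.use_running_average = False


class ToyBlock(ToyModule):
    def __init__(self):
        self.dropout = ToyDropout()
        self.batch_norm = ToyBatchNorm()
        self.verbose = True  # hypothetical extra attribute


block = ToyBlock()
block.eval(verbose=False)  # the extra keyword is forwarded to set_attributes
print(block.dropout.deterministic, block.batch_norm.use_running_average, block.verbose)
```

In this sketch, modules lacking a given attribute are skipped silently, which is why ToyDropout never gains a use_running_average attribute.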