jaxrts.experimental.SiiNN.NNModel.train
- NNModel.train(**attributes)
Sets the Module to training mode.
train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.
Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
- Args:
**attributes: Additional attributes passed to set_attributes.
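To make the recursive mechanism concrete, here is a minimal pure-Python sketch that does not depend on flax. FakeDropout, FakeBatchNorm, FakeBlock, set_attributes_sketch and train_sketch are hypothetical stand-ins, not flax APIs. The sketch assumes, following the description above, that an attribute is only overwritten on modules that already define it and that any extra keyword arguments are forwarded alongside the two training flags. The real nnx.Module.set_attributes may support further options (such as filtering) that are not modeled here.

```python
class Node:
    """Hypothetical base class standing in for a Module in a tree."""

    def children(self):
        # Collect instance attributes that are themselves Node objects.
        return [v for v in vars(self).values() if isinstance(v, Node)]


class FakeDropout(Node):
    def __init__(self, rate=0.5):
        self.rate = rate
        self.deterministic = True  # start in eval mode


class FakeBatchNorm(Node):
    def __init__(self):
        self.use_running_average = True  # start in eval mode


class FakeBlock(Node):
    def __init__(self):
        self.dropout = FakeDropout()
        self.batch_norm = FakeBatchNorm()


def set_attributes_sketch(node, **attributes):
    """Set each attribute on every node in the tree that already has it."""
    for name, value in attributes.items():
        if hasattr(node, name):
            setattr(node, name, value)
    for child in node.children():
        set_attributes_sketch(child, **attributes)


def train_sketch(node, **attributes):
    # Force the two training-mode flags and forward any extra attributes.
    set_attributes_sketch(
        node, deterministic=False, use_running_average=False, **attributes
    )


block = FakeBlock()
train_sketch(block)
print(block.dropout.deterministic, block.batch_norm.use_running_average)  # False False
```

Because the sketch only touches attributes a node already defines, FakeBatchNorm never gains a deterministic flag and FakeDropout never gains use_running_average, which matches the "all nested Modules that have these attributes" wording.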