jaxrts.experimental.SiiNN.NNModel.set_norms

NNModel.set_norms(theta: float, rho: float, Z: list[float], k_over_qk: float)

Set the normalization of the input layers.
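The page does not say how the stored norms are applied. The sketch below is only a guess at one common pattern, and it is not the jaxrts implementation. It assumes that each input feature is divided by a reference value, so that all inputs reach the network at order unity. The class `NormalizedInputs` and its `normalize` method are hypothetical. Only the `set_norms` argument names and types are taken from the signature above.

```python
import numpy as np


class NormalizedInputs:
    """Hypothetical stand-in, NOT jaxrts' NNModel.

    Assumption: normalization means dividing each input feature by a stored
    reference value (one per feature).
    """

    def __init__(self):
        self.norms = None

    def set_norms(self, theta: float, rho: float, Z: list[float], k_over_qk: float):
        # Pack the scalar features and the per-species charges into one vector
        # of reference values, in the same order the inputs are assembled below.
        self.norms = np.array([theta, rho, *Z, k_over_qk], dtype=float)

    def normalize(self, theta: float, rho: float, Z: list[float], k_over_qk: float):
        if self.norms is None:
            raise RuntimeError("call set_norms() before normalize()")
        x = np.array([theta, rho, *Z, k_over_qk], dtype=float)
        # Element-wise division rescales every feature by its reference value.
        return x / self.norms


# Usage with made-up reference and input values (illustrative only).
layer = NormalizedInputs()
layer.set_norms(theta=2.0, rho=10.0, Z=[1.0, 4.0], k_over_qk=5.0)
x = layer.normalize(theta=1.0, rho=5.0, Z=[1.0, 2.0], k_over_qk=2.5)
print(x)
```

Because `Z` is a list, the length of the input vector depends on the number of ion species. Under this assumption, the norms must be set using the same number of species as the inputs they are later applied to.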