jaxrts.helpers.secant_extrema_finding

jaxrts.helpers.secant_extrema_finding(func, xmin, xmax, tol=1e-07, max_iter=100000.0)

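Use the secant method to find an extremum of a function within the specified bounds. The derivative of func is obtained with jax.grad(), and the secant method is then applied to that derivative to locate a root of it, i.e. a stationary point of func.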
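The approach described above can be sketched in a few lines of JAX. This is a minimal, hypothetical illustration, not the jaxrts implementation: the name secant_stationary_point, the two starting points x0 and x1 (assumed here to play the role of xmin and xmax), the stopping rule on successive iterates, and the guard against a zero denominator are all assumptions. Because the sketch drives jax.grad(func) to zero, it converges to a nearby stationary point, which may be a maximum as well as a minimum.

```python
import jax
import jax.numpy as jnp


# Hypothetical sketch, not the jaxrts source: secant root-finding applied to
# the derivative f'(x) = jax.grad(func)(x), which yields a stationary point.
def secant_stationary_point(func, x0, x1, tol=1e-7, max_iter=100_000):
    dfunc = jax.grad(func)  # derivative of the scalar-valued function

    def cond(state):
        x_prev, x_curr, i = state
        # Continue while successive iterates differ by more than tol
        # and the iteration budget is not exhausted.
        return (jnp.abs(x_curr - x_prev) > tol) & (i < max_iter)

    def body(state):
        x_prev, x_curr, i = state
        g_prev, g_curr = dfunc(x_prev), dfunc(x_curr)
        denom = g_curr - g_prev
        # Guard: if the derivative values coincide, take no step, which
        # makes the next cond() call stop the loop instead of producing NaN.
        safe_denom = jnp.where(denom == 0, 1.0, denom)
        step = jnp.where(denom == 0, 0.0, g_curr * (x_curr - x_prev) / safe_denom)
        return x_curr, x_curr - step, i + 1

    # Explicit float32/int32 so the while_loop carry keeps a fixed type.
    init = (jnp.float32(x0), jnp.float32(x1), jnp.int32(0))
    _, x_final, n_iter = jax.lax.while_loop(cond, body, init)
    return x_final, n_iter


x_star, n_iter = secant_stationary_point(lambda x: (x - 2.0) ** 2, 0.0, 4.0)
print(float(x_star), int(n_iter))
```

For (x - 2) ** 2 the derivative is linear, so the first secant update lands exactly on x = 2.0 and the following iteration ends the loop.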

Parameters:
  • func (callable) – The function whose extremum is sought. It should take a single scalar input and return a scalar output, since it is differentiated with jax.grad().

  • xmin (float) – The lower bound for the variable x.

  • xmax (float) – The upper bound for the variable x.

  • tol (float, optional) – The tolerance for the stopping criterion. The default is 1e-7.

  • max_iter (int, optional) – The maximum number of iterations to perform. The default is 100000.

Returns:

float – The x value at which the function has an extremum (a root of its derivative) within the specified bounds.
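Since the secant iteration only drives the derivative to zero, the returned point may be a maximum rather than a minimum. One way to check, sketched here under the assumption that func is twice differentiable, is the second-derivative test with nested jax.grad() calls. classify_stationary_point is a hypothetical helper, not part of jaxrts.

```python
import jax


# Hypothetical helper, not part of jaxrts: classify a stationary point
# x_star of a scalar function by the sign of its second derivative.
def classify_stationary_point(func, x_star):
    d2f = jax.grad(jax.grad(func))(float(x_star))  # f''(x_star)
    if d2f > 0:
        return "minimum"
    if d2f < 0:
        return "maximum"
    # The second-derivative test is inconclusive when f''(x_star) == 0.
    return "inconclusive"


print(classify_stationary_point(lambda x: (x - 2.0) ** 2, 2.0))
```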

Examples

>>> def example_func(x):
...     return (x - 2) ** 2
>>> minimum, n_iter = secant_extrema_finding(example_func, 0.0, 4.0)
>>> print(minimum)
2.0