
Solver

felice.solver

Classes

ClipSolver

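Bases: equinox.Module

Wraps another Diffrax solver and, after every step, clips the solution to be non-negative by applying jax.nn.relu to each leaf of y1. Attributes that ClipSolver does not define itself (for example term_structure or init) are looked up on the wrapped solver through __getattr__.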

Source code in felice/solver.py
class ClipSolver(eqx.Module):
    solver: AbstractSolver

    def __getattr__(self, name):
        return getattr(self.solver, name)

    def step(
        self,
        terms: PyTree[AbstractTerm],
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
        solver_state: _SolverState,
        made_jump: BoolScalarLike,
    ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
        """Make a single step of the wrapped solver, then clip the solution to be non-negative.

        Each step is made over the specified interval $[t_0, t_1]$.

        **Arguments:**

        - `terms`: The PyTree of terms representing the vector fields and controls.
        - `t0`: The start of the interval that the step is made over.
        - `t1`: The end of the interval that the step is made over.
        - `y0`: The current value of the solution at `t0`.
        - `args`: Any extra arguments passed to the vector field.
        - `solver_state`: Any evolving state for the solver itself, at `t0`.
        - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.
            Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there
            are no jumps and for efficiency re-use information between steps; this
            indicates that a jump has just occurred and this assumption is not true.

        **Returns:**

        A tuple of several objects:

        - The value of the solution at `t1`, with negative entries clipped to zero.
        - A local error estimate made during the step. (Used by adaptive step size
            controllers to change the step size.) May be `None` if no estimate was
            made.
        - Some dictionary of information that is passed to the solver's interpolation
            routine to calculate dense output. (Used with `SaveAt(ts=...)` or
            `SaveAt(dense=...)`.)
        - The value of the solver state at `t1`.
        - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step
            happened successfully, or if (unusually) it failed for some reason.
        """
        y1, y_error, dense_info, solver_state, result = self.solver.step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
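        # Clip every leaf of the solution at zero (elementwise max(y, 0)); the error
        # estimate, dense info, solver state and result are passed through unchanged.
        y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)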
        return y1_clipped, y_error, dense_info, solver_state, result
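ClipSolver combines two ideas: attribute forwarding through __getattr__, so the wrapper exposes whatever the wrapped solver defines without redeclaring it, and a leaf-wise projection of y1 onto non-negative values after each step. The sketch below shows the same pattern without Diffrax or Equinox. EulerSolver, ClipWrapper, velocity and the simplified step(f, t0, t1, y0) signature are hypothetical stand-ins chosen for illustration; they are not Diffrax's solver interface.

```python
import jax
import jax.numpy as jnp


class EulerSolver:
    """Hypothetical toy solver (not Diffrax's API): one explicit Euler step."""

    order = 1  # an attribute the wrapper should forward

    def step(self, f, t0, t1, y0):
        # y1 = y0 + (t1 - t0) * f(t0, y0), applied leaf-wise over a PyTree state.
        dt = t1 - t0
        dy = f(t0, y0)
        return jax.tree_util.tree_map(lambda y, d: y + dt * d, y0, dy)


class ClipWrapper:
    """Hypothetical stand-in for ClipSolver: forward attributes, clip after each step."""

    def __init__(self, solver):
        self.solver = solver

    def __getattr__(self, name):
        # Python calls this only when normal lookup fails, so anything the
        # wrapper does not define itself (here: `order`) comes from the inner solver.
        return getattr(self.solver, name)

    def step(self, f, t0, t1, y0):
        y1 = self.solver.step(f, t0, t1, y0)
        # Clip every leaf at zero: relu(y) == max(y, 0).
        return jax.tree_util.tree_map(jax.nn.relu, y1)


def velocity(t, y):
    # Hypothetical constant vector field with the same PyTree structure as the state.
    return {"v": jnp.array([-3.0, 0.5])}


y0 = {"v": jnp.array([1.0, 1.0])}
inner = EulerSolver()
wrapped = ClipWrapper(inner)

y1_raw = inner.step(velocity, 0.0, 0.5, y0)        # v == [-0.5, 1.25]
y1_clipped = wrapped.step(velocity, 0.0, 0.5, y0)  # v == [0.0, 1.25]
print(wrapped.order)  # forwarded from EulerSolver
```

Because the clip is applied with jax.tree_util.tree_map, the state can be any PyTree (here a dict) and keeps its structure; only negative entries change.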
Functions
step(terms: PyTree[AbstractTerm], t0: RealScalarLike, t1: RealScalarLike, y0: Y, args: Args, solver_state: _SolverState, made_jump: BoolScalarLike) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]

Make a single step of the wrapped solver, then clip the solution to be non-negative.

Each step is made over the specified interval $[t_0, t_1]$.

Arguments:

  • terms: The PyTree of terms representing the vector fields and controls.
  • t0: The start of the interval that the step is made over.
  • t1: The end of the interval that the step is made over.
  • y0: The current value of the solution at t0.
  • args: Any extra arguments passed to the vector field.
  • solver_state: Any evolving state for the solver itself, at t0.
  • made_jump: Whether there was a discontinuity in the vector field at t0. Some solvers (notably FSAL Runge–Kutta solvers) usually assume that there are no jumps and, for efficiency, reuse information between steps; this flag indicates that a jump has just occurred, so that assumption does not hold.

Returns:

A tuple of several objects:

  • The value of the solution at t1, with negative entries clipped to zero.
  • A local error estimate made by the wrapped solver during the step, before clipping. (Used by adaptive step size controllers to change the step size.) May be None if no estimate was made.
  • A dictionary of information passed to the solver's interpolation routine to calculate dense output (used with SaveAt(ts=...) or SaveAt(dense=...)). It is returned unchanged from the wrapped solver, so it is not clipped.
  • The value of the solver state at t1.
  • An integer (corresponding to diffrax.RESULTS) indicating whether the step happened successfully, or if (unusually) it failed for some reason.
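ClipSolver.step changes only the first element of this tuple; the error estimate, dense-output information, solver state and result are returned exactly as the wrapped solver produced them. One consequence worth checking: if the wrapped solver builds its interpolant from the unclipped endpoint, dense output can disagree with the clipped y1 and even go negative. The sketch below illustrates this with hypothetical euler_step, clipped_step, interpolate and push_down helpers, assuming a linear interpolant that stores both endpoints; it is an illustration, not Diffrax's implementation.

```python
import jax
import jax.numpy as jnp


def euler_step(f, t0, t1, y0):
    """Hypothetical toy step returning (y1, y_error, dense_info).

    dense_info stores both endpoints, assuming a linear interpolant between them.
    """
    y1 = y0 + (t1 - t0) * f(t0, y0)
    y_error = None  # no error estimate in this sketch
    dense_info = {"y0": y0, "y1": y1}
    return y1, y_error, dense_info


def clipped_step(f, t0, t1, y0):
    # Hypothetical helper mirroring ClipSolver.step: clip only the solution,
    # pass everything else through untouched.
    y1, y_error, dense_info = euler_step(f, t0, t1, y0)
    return jax.nn.relu(y1), y_error, dense_info


def interpolate(dense_info, t0, t1, t):
    # Hypothetical linear interpolation between the stored (unclipped) endpoints.
    s = (t - t0) / (t1 - t0)
    return dense_info["y0"] + s * (dense_info["y1"] - dense_info["y0"])


def push_down(t, y):
    # Hypothetical constant vector field that drives the state below zero.
    return jnp.array([-3.0])


y1, y_error, info = clipped_step(push_down, 0.0, 0.5, jnp.array([1.0]))
y_dense_end = interpolate(info, 0.0, 0.5, 0.5)
# y1 is [0.0], but the dense output at t1 is [-0.5]: it still sees the unclipped step.
```

If interpolated values must also respect the non-negativity constraint, they would need to be clipped separately after interpolation.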