Skip to content

Neuron Models

felice.neuron_models

Classes

Boomerang

Bases: Module

Source code in felice/neuron_models/boomerang.py
class Boomerang(eqx.Module):
    """Two-variable (u, v) neuron model with fixpoint-based spike detection.

    Each neuron carries the state pair ``[u, v]``. Its time derivative is
    given by `vector_field`; `spike_condition` turns positive once the norm
    of that derivative drops below ``atol + rtol * ||y||``, i.e. when the
    state is close to a fixpoint of the vector field.

    The class holds no weights or connectivity: every field below is a
    scalar hyper-parameter declared as a static field.
    """

    # Tolerances read by `spike_condition`.
    rtol: float = eqx.field(static=True)
    atol: float = eqx.field(static=True)

    # Root of `vector_field`, computed once in `__init__`; `init_state`
    # places every neuron at this point.
    u0: float = eqx.field(static=True)
    v0: float = eqx.field(static=True)

    alpha: float = eqx.field(static=True)  # I_n0 / I_bias ratio
    beta: float = eqx.field(static=True)  # k / U_t (inverse thermal scale)
    gamma: float = eqx.field(static=True)  # coupling coefficient
    rho: float = eqx.field(static=True)  # tanh steepness
    sigma: float = eqx.field(static=True)  # bias scaling (s * I_bias)

    # dtype of the arrays produced by `init_state`.
    dtype: DTypeLike = eqx.field(static=True)

    def __init__(
        self,
        *,
        atol: float = 1e-6,
        rtol: float = 1e-4,
        alpha: float = 0.0129,
        beta: float = 15.6,
        gamma: float = 0.26,
        rho: float = 30.0,
        sigma: float = 0.6,
        dtype: DTypeLike = jnp.float32,
    ):
        r"""Initialize the Boomerang neuron model.

        All arguments are keyword-only. Besides storing them, the
        constructor solves ``vector_field(u, v) = 0`` with a Newton root
        finder and stores the solution as ``u0`` / ``v0``.

        Args:
            atol: Absolute tolerance for the spiking fixpoint calculation
                (default: 1e-6).
            rtol: Relative tolerance for the spiking fixpoint calculation
                (default: 1e-4).
            alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$
                (default: 0.0129).
            beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6).
            gamma: Coupling parameter $\gamma$ (default: 0.26).
            rho: Steepness of the tanh function $\rho$ (default: 30.0).
            sigma: Fixpoint distance scaling $\sigma$ (default: 0.6).
            dtype: Data type for the state arrays (default: float32).
        """
        self.dtype = dtype

        # The model parameters must be assigned before the root find below,
        # because `vector_field` reads them from `self`.
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.rho = rho
        self.sigma = sigma

        self.rtol = rtol
        self.atol = atol

        # Residual for the root finder: the vector field at the (u, v) pair.
        # The second argument (solver `args`) is unused.
        def fn(y, _):
            return self.vector_field(y[0], y[1])

        # Solver tolerances are fixed here and are independent of the
        # `rtol` / `atol` arguments, which only affect `spike_condition`.
        solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)
        # Newton iteration starts from (0.3, 0.3).
        y0 = (jnp.array(0.3), jnp.array(0.3))
        u0, v0 = optx.root_find(fn, solver, y0).value
        # `.item()` turns the scalar results into plain Python numbers for
        # the static `u0` / `v0` fields.
        self.u0 = u0.item()
        self.v0 = v0.item()

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 2) containing [u, v].
            Every row equals ``[self.u0, self.v0]``, the root of
            `vector_field` found in `__init__`.
        """

        u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)
        v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)
        x = jnp.stack([u, v], axis=1)
        return x

    def vector_field(
        self, u: Float[Array, "..."], v: Float[Array, "..."]
    ) -> Tuple[Float[Array, "..."], Float[Array, "..."]]:
        """Evaluate the right-hand side of the (u, v) system element-wise.

        With ``z = tanh(rho * (v - u))`` the returned derivatives are::

            du =  1 - alpha * exp(beta * v) * (1 - gamma * (0.3 - u)) + sigma * z
            dv = -1 + alpha * exp(beta * u) * (1 + gamma * (0.3 - v)) + sigma * z

        Args:
            u: First state variable.
            v: Second state variable.

        Returns:
            Tuple ``(du, dv)``.
        """
        alpha = self.alpha
        beta = self.beta
        gamma = self.gamma
        sigma = self.sigma
        rho = self.rho

        # NOTE(review): the literal 0.3 below is hard-coded; the same value
        # is the Newton starting guess in `__init__`. Its physical meaning
        # is not stated in this file.
        z = jax.nn.tanh(rho * (v - u))
        du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z
        dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z

        return du, dv

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute time derivatives of the neuron state variables.

        Splits the state into its ``u`` and ``v`` columns, evaluates
        `vector_field` on them and stacks the result back column-wise.

            - du/dt: derivative of the first state column (u)
            - dv/dt: derivative of the second state column (v)

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 2) containing [u, v].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        u = y[:, 0]
        v = y[:, 1]

        du, dv = self.vector_field(u, v)
        dxdt = jnp.stack([du, dv], axis=1)

        return dxdt

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        A spike is triggered when the system reaches a fixpoint: for each
        neuron the condition is ``atol + rtol * ||y_i|| - ||dy_i/dt||``
        (RMS norms), which is positive once the derivative is small
        relative to the tolerances.

        NOTE(review): earlier documentation mentioned a `has_spiked` state
        column used to avoid re-detecting a spike while sitting at the
        fixpoint. This method neither reads nor writes such a column.

        Args:
            t: Current simulation time (forwarded to `dynamics`, which
                ignores it).
            y: State array of shape (neurons, 2) containing [u, v].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            1-D spike condition array in which each neuron's value appears
            twice in a row (length 2 * neurons, because of `.repeat(2)`).
            Positive values indicate spike.
            NOTE(review): the return annotation says ``neurons``; confirm
            which length the event framework expects.
        """
        _atol = self.atol
        _rtol = self.rtol
        _norm = optx.rms_norm

        vf = self.dynamics(t, y, {})

        # vmap over the neuron axis so each norm is taken over one [u, v] row.
        @jax.vmap
        def calculate_norm(vf, y):
            return _atol + _rtol * _norm(y) - _norm(vf)

        # Duplicate every per-neuron value: [c0, c0, c1, c1, ...].
        base_cond = calculate_norm(vf, y).repeat(2)

        return base_cond
Functions
__init__(*, atol: float = 1e-06, rtol: float = 0.0001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 30.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)

Initialize the Boomerang neuron model.

Parameters:

Name Type Description Default
key

JAX random key for weight initialization.

required
n_neurons

Number of neurons in this layer.

required
in_size

Number of input connections (excluding recurrent connections).

required
wmask

Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).

required
rtol float

Relative tolerance for the spiking fixpoint calculation.

0.0001
atol float

Absolute tolerance for the spiking fixpoint calculation.

1e-06
alpha float

Current scaling parameter \(\alpha = I_{n0}/I_{bias}\) (default: 0.0129)

0.0129
beta float

Exponential slope \(\beta = \kappa/U_t\) (default: 15.6)

15.6
gamma float

Coupling parameter \(\gamma = 26 \times 10^{-2}\) (default: 0.26)

0.26
rho float

Steepness of the tanh function \(\rho\) (default: 30)

30.0
sigma float

Fixpoint distance scaling \(\sigma\) (default: 0.6)

0.6
wlim

Limit for weight initialization. If None, uses init_weights.

required
wmean

Mean value for weight initialization.

required
init_weights

Optional initial weight values. If None, weights are randomly initialized.

required
fan_in_mode

Mode for fan-in based weight initialization ('sqrt', 'linear').

required
dtype DTypeLike

Data type for arrays (default: float32).

float32
Source code in felice/neuron_models/boomerang.py
def __init__(
    self,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-4,
    alpha: float = 0.0129,
    beta: float = 15.6,
    gamma: float = 0.26,
    rho: float = 30.0,
    sigma: float = 0.6,
    dtype: DTypeLike = jnp.float32,
):
    r"""Initialize the Boomerang neuron model.

    All arguments are keyword-only. Besides storing them, the constructor
    solves ``vector_field(u, v) = 0`` with a Newton root finder and stores
    the solution as ``u0`` / ``v0``.

    Args:
        atol: Absolute tolerance for the spiking fixpoint calculation
            (default: 1e-6).
        rtol: Relative tolerance for the spiking fixpoint calculation
            (default: 1e-4).
        alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$
            (default: 0.0129).
        beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6).
        gamma: Coupling parameter $\gamma$ (default: 0.26).
        rho: Steepness of the tanh function $\rho$ (default: 30.0).
        sigma: Fixpoint distance scaling $\sigma$ (default: 0.6).
        dtype: Data type for the state arrays (default: float32).
    """
    self.dtype = dtype

    # The model parameters must be assigned before the root find below,
    # because `vector_field` reads them from `self`.
    self.alpha = alpha
    self.beta = beta
    self.gamma = gamma
    self.rho = rho
    self.sigma = sigma

    self.rtol = rtol
    self.atol = atol

    # Residual for the root finder: the vector field at the (u, v) pair.
    # The second argument (solver `args`) is unused.
    def fn(y, _):
        return self.vector_field(y[0], y[1])

    # Solver tolerances are fixed here and are independent of the
    # `rtol` / `atol` arguments stored above.
    solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)
    # Newton iteration starts from (0.3, 0.3).
    y0 = (jnp.array(0.3), jnp.array(0.3))
    u0, v0 = optx.root_find(fn, solver, y0).value
    # `.item()` turns the scalar results into plain Python numbers.
    self.u0 = u0.item()
    self.v0 = v0.item()
init_state(n_neurons: int) -> Float[Array, 'neurons 2']

Initialize the neuron state variables.

Parameters:

Name Type Description Default
n_neurons int

Number of neurons to initialize.

required

Returns:

Type Description
Float[Array, 'neurons 2']

Initial state array of shape (neurons, 2) containing [u, v],

Float[Array, 'neurons 2']

where u and v are the predator/prey membrane voltages.

Source code in felice/neuron_models/boomerang.py
def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
    """Create the initial state with every neuron at the point (u0, v0).

    Args:
        n_neurons: Number of neurons (rows) to create.

    Returns:
        Array of shape (n_neurons, 2) with dtype ``self.dtype``; each row is
        ``[self.u0, self.v0]``.
    """
    # One constant column per state variable, then join them side by side.
    columns = [
        jnp.full((n_neurons,), start_value, dtype=self.dtype)
        for start_value in (self.u0, self.v0)
    ]
    return jnp.stack(columns, axis=1)
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']

Compute time derivatives of the neuron state variables.

This implements the WereRabbit dynamics

- du/dt: Predator dynamics
- dv/dt: WerePrey dynamics

Parameters:

Name Type Description Default
t float

Current simulation time (unused but required by framework).

required
y Float[Array, 'neurons 2']

State array of shape (neurons, 2) containing [u, v].

required
args Dict[str, Any]

Additional arguments (unused but required by framework).

required

Returns:

Type Description
Float[Array, 'neurons 2']

Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].

Source code in felice/neuron_models/boomerang.py
def dynamics(
    self,
    t: float,
    y: Float[Array, "neurons 2"],
    args: Dict[str, Any],
) -> Float[Array, "neurons 2"]:
    """Return the time derivative of the state, one row per neuron.

    The two state columns are passed to ``self.vector_field`` and the two
    derivatives it returns are stacked back into columns.

    Args:
        t: Current simulation time (unused but required by framework).
        y: State array of shape (neurons, 2) containing [u, v].
        args: Additional arguments (unused but required by framework).

    Returns:
        Array of shape (neurons, 2) containing [du/dt, dv/dt].
    """
    rate_u, rate_v = self.vector_field(y[:, 0], y[:, 1])
    return jnp.stack((rate_u, rate_v), axis=1)
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']

Compute spike condition for event detection.

A spike is triggered when the system reaches a fixpoint.

INFO

has_spiked is used so that the system does not detect a continuous spike once it has reached a fixpoint.

Parameters:

Name Type Description Default
t float

Current simulation time (unused but required by the framework).

required
y Float[Array, 'neurons 2']

State array of shape (neurons, 2) containing [u, v].

required
**kwargs Dict[str, Any]

Additional keyword arguments (unused).

{}

Returns:

Type Description
Float[Array, ' neurons']

Spike condition array in which each neuron's value appears twice in a row (length 2 × neurons). Positive values indicate spike.

Source code in felice/neuron_models/boomerang.py
def spike_condition(
    self,
    t: float,
    y: Float[Array, "neurons 2"],
    **kwargs: Dict[str, Any],
) -> Float[Array, " neurons"]:
    """Compute the event function that signals a spike near a fixpoint.

    For each neuron the value is ``atol + rtol * ||y_i|| - ||dy_i/dt||``
    using RMS norms, so it is positive once the derivative is small
    relative to the tolerances.

    NOTE(review): earlier documentation mentioned a `has_spiked` state
    column used to avoid re-detecting a spike while sitting at the fixpoint.
    This method neither reads nor writes such a column.

    Args:
        t: Current simulation time (only forwarded to ``self.dynamics``).
        y: State array of shape (neurons, 2) containing [u, v].
        **kwargs: Additional keyword arguments (unused).

    Returns:
        1-D array in which each neuron's value appears twice in a row
        (length 2 * neurons). Positive values indicate spike.
    """
    flow = self.dynamics(t, y, {})

    def margin(flow_row, state_row):
        # Tolerance band around zero, minus the size of the derivative.
        tolerance = self.atol + self.rtol * optx.rms_norm(state_row)
        return tolerance - optx.rms_norm(flow_row)

    per_neuron = jax.vmap(margin)(flow, y)
    # Duplicate every per-neuron value: [c0, c0, c1, c1, ...].
    return per_neuron.repeat(2)

FHNRS

Bases: Module

FitzHugh-Nagumo neuron model

Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast and slow currents to produce oscillatory spiking behavior.

The dynamics are governed by:

\[ \begin{align} C\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\ \frac{dv_{slow}}{dt} &= \frac{v - v_{slow}}{\tau_{slow}} \\ \frac{dI_{app}}{dt} &= -\frac{I_{app}}{\tau_{syn}} \end{align} \]

where the currents are:

  • \(I_{passive} = g_{max}(v - E_{rev})\)
  • \(I_{fast} = a_{fast} \tanh(v - v_{off,fast})\)
  • \(I_{slow} = a_{slow} \tanh(v_{slow} - v_{off,slow})\)
References
  • Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.

Attributes:

Name Type Description
reset_grad_preserve

Preserve the gradient when the neuron spikes by doing a soft reset.

gmax_pasive float

Maximal conductance of the passive current.

Erev_pasive float

Reversal potential for the passive current.

a_fast float

Amplitude parameter for the fast current dynamics.

voff_fast float

Voltage offset for the fast current activation.

tau_fast float

Time constant for the fast current (typically zero for instantaneous).

a_slow float

Amplitude parameter for the slow current dynamics.

voff_slow float

Voltage offset for the slow current activation.

tau_slow float

Time constant for the slow recovery variable.

vthr float

Voltage threshold for spike generation.

C float

Membrane capacitance.

tsyn float

Synaptic time constant for input current decay.

weights float

Synaptic weight matrix of shape (in_plus_neurons, neurons).

Source code in felice/neuron_models/fhn.py
class FHNRS(eqx.Module):
    r"""FitzHugh-Nagumo neuron model (Ribar-Sepulchre formulation).

    A FitzHugh-Nagumo style neuron following the hardware-oriented
    formulation of Ribar and Sepulchre. A fast and a slow current act on two
    time scales and together produce oscillatory spiking.

    The dynamics are governed by:

    $$
    \begin{align}
        C\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\
        \frac{dv_{slow}}{dt} &= \frac{v - v_{slow}}{\tau_{slow}} \\
        \frac{dI_{app}}{dt} &= -\frac{I_{app}}{\tau_{syn}}
    \end{align}
    $$

    where the currents are:

    - $I_{passive} = g_{max}(v - E_{rev})$
    - $I_{fast} = a_{fast} \tanh(v - v_{off,fast})$
    - $I_{slow} = a_{slow} \tanh(v_{slow} - v_{off,slow})$

    References:
        - Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.

    Attributes:
        gmax_pasive: Maximal conductance of the passive current.
        Erev_pasive: Reversal potential for the passive current.
        a_fast: Amplitude parameter for the fast current.
        voff_fast: Voltage offset for the fast current activation.
        tau_fast: Time constant for the fast current. It is stored but not
            read by any method of this class (the fast current responds
            instantaneously to ``v``).
        a_slow: Amplitude parameter for the slow current.
        voff_slow: Voltage offset for the slow current activation.
        tau_slow: Time constant for the slow recovery variable.
        vthr: Voltage threshold for spike generation.
        C: Membrane capacitance.
        tsyn: Synaptic time constant for input current decay.
        dtype: Data type of the state array created by `init_state`.
    """

    # Passive current
    gmax_pasive: float = eqx.field(static=True)
    Erev_pasive: float = eqx.field(static=True)

    # Fast current
    a_fast: float = eqx.field(static=True)
    voff_fast: float = eqx.field(static=True)
    tau_fast: float = eqx.field(static=True)

    # Slow current
    a_slow: float = eqx.field(static=True)
    voff_slow: float = eqx.field(static=True)
    tau_slow: float = eqx.field(static=True)

    # Spike threshold and membrane capacitance
    vthr: float = eqx.field(static=True)
    C: float = eqx.field(static=True, default=1.0)

    # Decay time constant of the synaptic input current
    tsyn: float = eqx.field(static=True)

    dtype: DTypeLike = eqx.field(static=True)

    def __init__(
        self,
        *,
        tsyn: Union[int, float, jnp.ndarray] = 1.0,
        C: Union[int, float, jnp.ndarray] = 1.0,
        gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,
        Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,
        a_fast: Union[int, float, jnp.ndarray] = -2.0,
        voff_fast: Union[int, float, jnp.ndarray] = 0.0,
        tau_fast: Union[int, float, jnp.ndarray] = 0.0,
        a_slow: Union[int, float, jnp.ndarray] = 2.0,
        voff_slow: Union[int, float, jnp.ndarray] = 0.0,
        tau_slow: Union[int, float, jnp.ndarray] = 50.0,
        vthr: Union[int, float, jnp.ndarray] = 2.0,
        dtype: DTypeLike = jnp.float32,
    ):
        """Store the model parameters; no other work is done.

        All arguments are keyword-only and are stored unchanged. The
        annotations allow a scalar or an array for every numeric parameter.
        NOTE(review): the fields are declared static; confirm that array
        values are supported by the surrounding framework.

        Args:
            tsyn: Synaptic time constant for input current decay.
            C: Membrane capacitance.
            gmax_pasive: Maximal conductance of the passive current.
            Erev_pasive: Reversal potential for the passive current.
            a_fast: Amplitude of the fast current.
            voff_fast: Voltage offset for the fast current activation.
            tau_fast: Time constant for the fast current (default 0).
            a_slow: Amplitude of the slow current.
            voff_slow: Voltage offset for the slow current activation.
            tau_slow: Time constant for the slow recovery variable.
            vthr: Voltage threshold for spike generation.
            dtype: Data type for arrays (default: float32).
        """
        # Passive current
        self.gmax_pasive = gmax_pasive
        self.Erev_pasive = Erev_pasive

        # Fast current
        self.a_fast = a_fast
        self.voff_fast = voff_fast
        self.tau_fast = tau_fast

        # Slow current
        self.a_slow = a_slow
        self.voff_slow = voff_slow
        self.tau_slow = tau_slow

        # Membrane, threshold and synapse
        self.C = C
        self.vthr = vthr
        self.tsyn = tsyn

        self.dtype = dtype

    def _membrane_current(self, v_passive, v_fast, v_slow):
        """Sum the passive, fast and slow currents, each at its own voltage."""
        passive = self.gmax_pasive * (v_passive - self.Erev_pasive)
        fast = self.a_fast * jnp.tanh(v_fast - self.voff_fast)
        slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)
        return passive + fast + slow

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 3"]:
        """Create an all-zero initial state.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Zero array of shape (neurons, 3) with dtype ``self.dtype``. The
            columns are [v, v_slow, i_app]: membrane voltage, slow recovery
            variable and applied synaptic current.
        """
        state_shape = (n_neurons, 3)
        return jnp.zeros(state_shape, dtype=self.dtype)

    def IV_inst(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Instantaneous I-V curve: fast and slow currents held at rest.

        Args:
            v: Membrane voltage (drives the passive current only).
            Vrest: Voltage at which both the fast and the slow current are
                evaluated (default: 0).

        Returns:
            Passive current at ``v`` plus fast and slow currents at ``Vrest``.
        """
        return self._membrane_current(v, Vrest, Vrest)

    def IV_fast(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Fast I-V curve: fast current follows ``v``, slow current at rest.

        Args:
            v: Membrane voltage for the passive and fast currents.
            Vrest: Voltage at which the slow current is evaluated
                (default: 0).

        Returns:
            Passive and fast currents at ``v`` plus the slow current at
            ``Vrest``.
        """
        return self._membrane_current(v, v, Vrest)

    def IV_slow(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Steady-state I-V curve: every current evaluated at ``v``.

        Args:
            v: Membrane voltage for all currents.
            Vrest: Unused; kept so the three ``IV_*`` methods share one
                signature (default: 0).

        Returns:
            Sum of the passive, fast and slow currents at ``v``.
        """
        return self._membrane_current(v, v, v)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 3"]:
        """Compute time derivatives of the neuron state variables.

        Per neuron:

        - dv/dt = (i_app - (I_passive + I_fast + I_slow)) / C
        - dv_slow/dt = (v - v_slow) / tau_slow
        - di_app/dt = -i_app / tsyn

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
            args: Additional arguments (unused but required by framework).

        Returns:
            Array of shape (neurons, 3) containing
            [dv/dt, dv_slow/dt, di_app/dt].
        """
        v, v_slow, i_app = y[:, 0], y[:, 1], y[:, 2]

        # Passive and fast currents follow v; the slow current follows v_slow.
        outward = self._membrane_current(v, v, v_slow)

        rates = [
            (i_app - outward) / self.C,
            (v - v_slow) / self.tau_slow,
            -i_app / self.tsyn,
        ]
        return jnp.stack(rates, axis=1)

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        The value is ``v - vthr``; a spike is triggered when it crosses zero.

        Args:
            t: Current simulation time (unused but required by event detection).
            y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Array of shape (neurons,). Positive values indicate v > vthr.
        """
        membrane_voltage = y[:, 0]
        return membrane_voltage - self.vthr
Functions
__init__(*, tsyn: Union[int, float, jnp.ndarray] = 1.0, C: Union[int, float, jnp.ndarray] = 1.0, gmax_pasive: Union[int, float, jnp.ndarray] = 1.0, Erev_pasive: Union[int, float, jnp.ndarray] = 0.0, a_fast: Union[int, float, jnp.ndarray] = -2.0, voff_fast: Union[int, float, jnp.ndarray] = 0.0, tau_fast: Union[int, float, jnp.ndarray] = 0.0, a_slow: Union[int, float, jnp.ndarray] = 2.0, voff_slow: Union[int, float, jnp.ndarray] = 0.0, tau_slow: Union[int, float, jnp.ndarray] = 50.0, vthr: Union[int, float, jnp.ndarray] = 2.0, dtype: DTypeLike = jnp.float32)

Initialize the FitzHugh-Nagumo neuron model.

Parameters:

Name Type Description Default
tsyn Union[int, float, ndarray]

Synaptic time constant for input current decay. Can be scalar or per-neuron array.

1.0
C Union[int, float, ndarray]

Membrane capacitance. Can be scalar or per-neuron array.

1.0
gmax_pasive Union[int, float, ndarray]

Maximal conductance of passive current. Can be scalar or per-neuron array.

1.0
Erev_pasive Union[int, float, ndarray]

Reversal potential for passive current. Can be scalar or per-neuron array.

0.0
a_fast Union[int, float, ndarray]

Amplitude of fast current. Can be scalar or per-neuron array.

-2.0
voff_fast Union[int, float, ndarray]

Voltage offset for fast current activation. Can be scalar or per-neuron array.

0.0
tau_fast Union[int, float, ndarray]

Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.

0.0
a_slow Union[int, float, ndarray]

Amplitude of slow current. Can be scalar or per-neuron array.

2.0
voff_slow Union[int, float, ndarray]

Voltage offset for slow current activation. Can be scalar or per-neuron array.

0.0
tau_slow Union[int, float, ndarray]

Time constant for slow recovery variable. Can be scalar or per-neuron array.

50.0
vthr Union[int, float, ndarray]

Voltage threshold for spike generation. Can be scalar or per-neuron array.

2.0
dtype DTypeLike

Data type for arrays (default: float32).

float32
Source code in felice/neuron_models/fhn.py
def __init__(
    self,
    *,
    tsyn: Union[int, float, jnp.ndarray] = 1.0,
    C: Union[int, float, jnp.ndarray] = 1.0,
    gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,
    Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,
    a_fast: Union[int, float, jnp.ndarray] = -2.0,
    voff_fast: Union[int, float, jnp.ndarray] = 0.0,
    tau_fast: Union[int, float, jnp.ndarray] = 0.0,
    a_slow: Union[int, float, jnp.ndarray] = 2.0,
    voff_slow: Union[int, float, jnp.ndarray] = 0.0,
    tau_slow: Union[int, float, jnp.ndarray] = 50.0,
    vthr: Union[int, float, jnp.ndarray] = 2.0,
    dtype: DTypeLike = jnp.float32,
):
    """Store the FitzHugh-Nagumo model parameters; no other work is done.

    All arguments are keyword-only and are stored unchanged on ``self``. The
    annotations allow a scalar or an array for every numeric parameter.

    Args:
        tsyn: Synaptic time constant for input current decay.
        C: Membrane capacitance.
        gmax_pasive: Maximal conductance of the passive current.
        Erev_pasive: Reversal potential for the passive current.
        a_fast: Amplitude of the fast current.
        voff_fast: Voltage offset for the fast current activation.
        tau_fast: Time constant for the fast current (default 0).
        a_slow: Amplitude of the slow current.
        voff_slow: Voltage offset for the slow current activation.
        tau_slow: Time constant for the slow recovery variable.
        vthr: Voltage threshold for spike generation.
        dtype: Data type for arrays (default: float32).
    """
    # Passive current
    self.gmax_pasive = gmax_pasive
    self.Erev_pasive = Erev_pasive

    # Fast current
    self.a_fast = a_fast
    self.voff_fast = voff_fast
    self.tau_fast = tau_fast

    # Slow current
    self.a_slow = a_slow
    self.voff_slow = voff_slow
    self.tau_slow = tau_slow

    # Membrane, threshold and synapse
    self.C = C
    self.vthr = vthr
    self.tsyn = tsyn

    self.dtype = dtype
init_state(n_neurons: int) -> Float[Array, 'neurons 3']

Initialize the neuron state variables.

Parameters:

Name Type Description Default
n_neurons int

Number of neurons to initialize.

required

Returns:

Type Description
Float[Array, 'neurons 3']

Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],

Float[Array, 'neurons 3']

where v is membrane voltage, v_slow is the slow recovery variable,

Float[Array, 'neurons 3']

and i_app is the applied synaptic current.

Source code in felice/neuron_models/fhn.py
def init_state(self, n_neurons: int) -> Float[Array, "neurons 3"]:
    """Create an all-zero initial state.

    Args:
        n_neurons: Number of neurons to initialize.

    Returns:
        Zero array of shape (neurons, 3) with dtype ``self.dtype``. The
        columns are [v, v_slow, i_app]: membrane voltage, slow recovery
        variable and applied synaptic current.
    """
    state_shape = (n_neurons, 3)
    return jnp.zeros(state_shape, dtype=self.dtype)
IV_inst(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]

Compute instantaneous I-V relationship with fast and slow currents at rest.

Parameters:

Name Type Description Default
v Float[Array, ...]

Membrane voltage.

required
Vrest float

Resting voltage for both fast and slow currents (default: 0).

0

Returns:

Type Description
Float[Array, ...]

Total current at voltage v with both fast and slow currents evaluated at Vrest.

Source code in felice/neuron_models/fhn.py
def IV_inst(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
    """Instantaneous I-V curve: fast and slow currents held at rest.

    Args:
        v: Membrane voltage (drives the passive current only).
        Vrest: Voltage at which both the fast and the slow current are
            evaluated (default: 0).

    Returns:
        Passive current at ``v`` plus fast and slow currents at ``Vrest``.
    """
    return (
        self.gmax_pasive * (v - self.Erev_pasive)
        + self.a_fast * jnp.tanh(Vrest - self.voff_fast)
        + self.a_slow * jnp.tanh(Vrest - self.voff_slow)
    )
IV_fast(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]

Compute I-V relationship with fast current at voltage v and slow current at rest.

Parameters:

Name Type Description Default
v Float[Array, ...]

Membrane voltage for passive and fast currents.

required
Vrest float

Resting voltage for slow current (default: 0).

0

Returns:

Type Description
Float[Array, ...]

Total current with fast dynamics responding to v and slow current at Vrest.

Source code in felice/neuron_models/fhn.py
def IV_fast(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
    """Fast I-V curve: fast current follows ``v``, slow current at rest.

    Args:
        v: Membrane voltage for the passive and fast currents.
        Vrest: Voltage at which the slow current is evaluated (default: 0).

    Returns:
        Passive and fast currents at ``v`` plus the slow current at ``Vrest``.
    """
    return (
        self.gmax_pasive * (v - self.Erev_pasive)
        + self.a_fast * jnp.tanh(v - self.voff_fast)
        + self.a_slow * jnp.tanh(Vrest - self.voff_slow)
    )
IV_slow(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]

Compute steady-state I-V relationship with all currents at voltage v.

Parameters:

Name Type Description Default
v Float[Array, ...]

Membrane voltage for all currents.

required
Vrest float

Unused parameter for API consistency (default: 0).

0

Returns:

Type Description
Float[Array, ...]

Total steady-state current with all currents responding to v.

Source code in felice/neuron_models/fhn.py
def IV_slow(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
    """Evaluate the steady-state I-V curve, with every current following ``v``.

    Args:
        v: Membrane voltage at which the passive, fast, and slow currents
            are all evaluated.
        Vrest: Not read by this method; kept so the signature matches the
            other I-V helpers (default: 0).

    Returns:
        Sum of the passive, fast, and slow currents, all evaluated at ``v``.
    """
    leak_current = self.gmax_pasive * (v - self.Erev_pasive)
    fast_current = self.a_fast * jnp.tanh(v - self.voff_fast)
    slow_current = self.a_slow * jnp.tanh(v - self.voff_slow)

    total_current = leak_current + fast_current
    total_current = total_current + slow_current
    return total_current
dynamics(t: float, y: Float[Array, 'neurons 3'], args: Dict[str, Any]) -> Float[Array, 'neurons 3']

Compute time derivatives of the neuron state variables.

This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents: - dv/dt: Fast membrane voltage dynamics - dv_slow/dt: Slow recovery variable dynamics - di_app/dt: Synaptic current decay

Parameters:

Name Type Description Default
t float

Current simulation time (unused but required by framework).

required
y Float[Array, 'neurons 3']

State array of shape (neurons, 3) containing [v, v_slow, i_app].

required
args Dict[str, Any]

Additional arguments (unused but required by framework).

required

Returns:

Type Description
Float[Array, 'neurons 3']

Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].

Source code in felice/neuron_models/fhn.py
def dynamics(
    self,
    t: float,
    y: Float[Array, "neurons 3"],
    args: Dict[str, Any],
) -> Float[Array, "neurons 3"]:
    """Return the time derivative of every neuron's state.

    Each row of ``y`` holds ``[v, v_slow, i_app]``. The derivatives are:

    - dv/dt: (i_app minus the passive, fast, and slow currents) / C
    - dv_slow/dt: relaxation of v_slow towards v with time constant tau_slow
    - di_app/dt: decay of i_app with time constant tsyn

    Args:
        t: Current simulation time (not read by this method).
        y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
        args: Additional arguments (not read by this method).

    Returns:
        Array of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].
    """
    v, v_slow, i_app = y[:, 0], y[:, 1], y[:, 2]

    # Passive and fast currents follow v; the slow current follows v_slow.
    leak_current = self.gmax_pasive * (v - self.Erev_pasive)
    fast_current = self.a_fast * jnp.tanh(v - self.voff_fast)
    slow_current = self.a_slow * jnp.tanh(v_slow - self.voff_slow)
    membrane_current = leak_current + fast_current + slow_current

    derivatives = (
        (i_app - membrane_current) / self.C,
        (v - v_slow) / self.tau_slow,
        -i_app / self.tsyn,
    )
    return jnp.stack(derivatives, axis=1)
spike_condition(t: float, y: Float[Array, 'neurons 3'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']

Compute spike condition for event detection.

A spike is triggered when this function crosses zero (v >= vthr).

Parameters:

Name Type Description Default
t float

Current simulation time (unused but required by event detection).

required
y Float[Array, 'neurons 3']

State array of shape (neurons, 3) containing [v, v_slow, i_app].

required
**kwargs Dict[str, Any]

Additional keyword arguments (unused).

{}

Returns:

Type Description
Float[Array, ' neurons']

Spike condition array of shape (neurons,). Positive values indicate v > vthr.

Source code in felice/neuron_models/fhn.py
def spike_condition(
    self,
    t: float,
    y: Float[Array, "neurons 3"],
    **kwargs: Dict[str, Any],
) -> Float[Array, " neurons"]:
    """Return the signed distance of each membrane voltage from threshold.

    The result is ``v - vthr``; it crosses zero when v reaches vthr.

    Args:
        t: Current simulation time (not read by this method).
        y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
        **kwargs: Additional keyword arguments (not read by this method).

    Returns:
        Array of shape (neurons,). Positive values indicate v > vthr.
    """
    membrane_voltage = y[:, 0]
    return membrane_voltage - self.vthr

WereRabbit

Bases: Module

WereRabbit Neuron Model

The WereRabbit model implements a predator-prey dynamic with bistable switching behavior controlled by a "moon phase" parameter \(z\).

The dynamics are governed by:

\[ \begin{align} z &= \tanh(\rho (u-v)) \\ \frac{du}{dt} &= z - z \alpha e^{\beta v} [1 + \gamma (0.5 - u)] - \sigma \\ \frac{dv}{dt} &= -z + z \alpha e^{\beta u} [1 + \gamma (0.5 - v)] - \sigma \end{align} \]

where \(z\) represents the "moon phase" that switches the predator-prey roles.

Attributes:

Name Type Description
alpha float

Current scaling parameter \(\alpha = I_{n0}/I_{bias}\) (default: 0.0129)

beta float

Exponential slope \(\beta = \kappa/U_t\) (default: 15.6)

gamma float

Coupling parameter \(\gamma = 26 \times 10^{-2}\) (default: 0.26)

rho float

Steepness of the tanh function \(\rho\) (default: 5)

sigma float

Fixpoint distance scaling \(\sigma\) (default: 0.6)

rtol float

Relative tolerance for the spiking fixpoint calculation.

atol float

Absolute tolerance for the spiking fixpoint calculation.

weight_u float

Input weight for the predator.

weight_v float

Input weight for the prey.

Source code in felice/neuron_models/wererabbit.py
class WereRabbit(eqx.Module):
    r"""
    WereRabbit Neuron Model

    The WereRabbit model implements a predator-prey dynamic with bistable
    switching behavior controlled by a "moon phase" parameter $z$.

    The dynamics are governed by:

    $$
    \begin{align}
        z &= \tanh(\rho (u-v)) \\
        \frac{du}{dt} &= z - z \alpha e^{\beta v} [1 + \gamma (0.5 - u)] - \sigma \\
        \frac{dv}{dt} &= -z + z \alpha e^{\beta u} [1 + \gamma (0.5 - v)] - \sigma
    \end{align}
    $$

    where $z$ represents the "moon phase" that switches the predator-prey roles.

    Attributes:
        dtype: Data type used for the state arrays.

        rtol: Relative tolerance for the spiking fixpoint calculation.
        atol: Absolute tolerance for the spiking fixpoint calculation.

        alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
        beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
        gamma: Coupling parameter $\gamma$ (default: 0.26)
        rho: Steepness of the tanh function $\rho$ (default: 5)
        sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
    """

    dtype: DTypeLike = eqx.field(static=True)
    rtol: float = eqx.field(static=True)
    atol: float = eqx.field(static=True)

    alpha: float = eqx.field(static=True)  # I_n0 / I_bias ratio
    beta: float = eqx.field(static=True)  # k / U_t (inverse thermal scale)
    gamma: float = eqx.field(static=True)  # coupling coefficient
    rho: float = eqx.field(static=True)  # tanh steepness
    sigma: float = eqx.field(static=True)  # bias scaling (s * I_bias)

    def __init__(
        self,
        *,
        atol: float = 1e-3,
        rtol: float = 1e-3,
        alpha: float = 0.0129,
        beta: float = 15.6,
        gamma: float = 0.26,
        rho: float = 5.0,
        sigma: float = 0.6,
        dtype: DTypeLike = jnp.float32,
    ):
        r"""Initialize the WereRabbit neuron model.

        Args:
            rtol: Relative tolerance for the spiking fixpoint calculation.
            atol: Absolute tolerance for the spiking fixpoint calculation.
            alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
            beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
            gamma: Coupling parameter $\gamma$ (default: 0.26)
            rho: Steepness of the tanh function $\rho$ (default: 5)
            sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
            dtype: Data type for arrays (default: float32).
        """
        self.dtype = dtype
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.rho = rho
        self.sigma = sigma

        self.rtol = rtol
        self.atol = atol

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 2) containing [u, v],
            the predator/prey membrane voltages, all set to zero.
        """
        x1 = jnp.zeros((n_neurons,), dtype=self.dtype)
        x2 = jnp.zeros((n_neurons,), dtype=self.dtype)
        return jnp.stack([x1, x2], axis=1)

    def vector_field(self, y: Float[Array, "neurons 2"]) -> Float[Array, "neurons 2"]:
        """Compute vector field of the neuron state variables.

        This implements the WereRabbit dynamics

            - du/dt: Predator dynamics
            - dv/dt: WerePrey dynamics

        Args:
            y: State array of shape (neurons, 2) containing [u, v].

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        u = y[:, 0]
        v = y[:, 1]

        z = jax.nn.tanh(self.rho * (u - v))
        du = (
            z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))
            - self.sigma
        )
        dv = (
            z
            * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))
            - self.sigma
        )

        # Element-wise mask of neurons sitting at the switching point (z ~ 0).
        # `jnp.isclose` keeps the decision per neuron; `jnp.allclose` would
        # collapse it to one scalar shared by the whole population.
        at_switch = jnp.isclose(z, 0.0)
        # Where z ~ 0 the derivative is scaled by the sign of its own state,
        # which makes it exactly zero for a state that is exactly zero.
        dv = jnp.where(at_switch, dv * jnp.sign(v), dv)
        du = jnp.where(at_switch, du * jnp.sign(u), du)

        return jnp.stack([du, dv], axis=1)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute time derivatives of the neuron state variables.

        This implements the WereRabbit dynamics

            - du/dt: Predator dynamics
            - dv/dt: WerePrey dynamics

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 2) containing [u, v].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        dxdt = self.vector_field(y)

        return dxdt

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        A spike is triggered when the system reaches a fixpoint, i.e. when the
        norm of the vector field drops below ``atol + rtol * norm(state)``.

        Args:
            t: Current simulation time (forwarded to ``dynamics``).
            y: State array of shape (neurons, 2) containing [u, v].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Spike condition array of shape (neurons,). Positive values indicate spike.
        """
        _atol = self.atol
        _rtol = self.rtol
        _norm = optx.rms_norm

        vf = self.dynamics(t, y, {})

        @jax.vmap
        def calculate_norm(vf, y):
            # Select the (u, v) pair explicitly. The former `[:-1]` slice dates
            # from a state with a trailing `has_spiked` column; on the current
            # two-column state it dropped v, so dv/dt was ignored.
            return _atol + _rtol * _norm(y[:2]) - _norm(vf[:2])

        base_cond = calculate_norm(vf, y)

        return base_cond
Functions
__init__(*, atol: float = 0.001, rtol: float = 0.001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 5.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)

Initialize the WereRabbit neuron model.

Parameters:

Name Type Description Default
rtol float

Relative tolerance for the spiking fixpoint calculation.

0.001
atol float

Absolute tolerance for the spiking fixpoint calculation.

0.001
alpha float

Current scaling parameter \(\alpha = I_{n0}/I_{bias}\) (default: 0.0129)

0.0129
beta float

Exponential slope \(\beta = \kappa/U_t\) (default: 15.6)

15.6
gamma float

Coupling parameter \(\gamma = 26 \times 10^{-2}\) (default: 0.26)

0.26
rho float

Steepness of the tanh function \(\rho\) (default: 5)

5.0
sigma float

Fixpoint distance scaling \(\sigma\) (default: 0.6)

0.6
dtype DTypeLike

Data type for arrays (default: float32).

float32
Source code in felice/neuron_models/wererabbit.py
def __init__(
    self,
    *,
    atol: float = 1e-3,
    rtol: float = 1e-3,
    alpha: float = 0.0129,
    beta: float = 15.6,
    gamma: float = 0.26,
    rho: float = 5.0,
    sigma: float = 0.6,
    dtype: DTypeLike = jnp.float32,
):
    r"""Store the WereRabbit model parameters on the instance.

    All arguments are keyword-only and are copied unchanged onto the
    attribute of the same name.

    Args:
        rtol: Relative tolerance for the spiking fixpoint calculation.
        atol: Absolute tolerance for the spiking fixpoint calculation.
        alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
        beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
        gamma: Coupling parameter $\gamma$ (default: 0.26)
        rho: Steepness of the tanh function $\rho$ (default: 5)
        sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
        dtype: Data type for arrays (default: float32).
    """
    # Fixpoint-detection tolerances.
    self.atol = atol
    self.rtol = rtol

    # Array data type.
    self.dtype = dtype

    # Parameters of the vector field.
    self.sigma = sigma
    self.rho = rho
    self.gamma = gamma
    self.beta = beta
    self.alpha = alpha
init_state(n_neurons: int) -> Float[Array, 'neurons 2']

Initialize the neuron state variables.

Parameters:

Name Type Description Default
n_neurons int

Number of neurons to initialize.

required

Returns:

Type Description
Float[Array, 'neurons 2']

Initial state array of shape (neurons, 2) containing [u, v], where u and v are the predator/prey membrane voltages, both initialized to zero.

Source code in felice/neuron_models/wererabbit.py
def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
    """Build the all-zero initial state for a population of neurons.

    Args:
        n_neurons: Number of neurons to initialize.

    Returns:
        Array of shape (neurons, 2) containing [u, v], the predator/prey
        membrane voltages, all set to zero with dtype ``self.dtype``.
    """
    # One row per neuron, one column each for u and v.
    state_shape = (n_neurons, 2)
    return jnp.zeros(state_shape, dtype=self.dtype)
vector_field(y: Float[Array, 'neurons 2']) -> Float[Array, 'neurons 2']

Compute vector field of the neuron state variables.

This implements the WereRabbit dynamics

- du/dt: Predator dynamics
- dv/dt: WerePrey dynamics

Parameters:

Name Type Description Default
y Float[Array, 'neurons 2']

State array of shape (neurons, 2) containing [u, v].

required

Returns:

Type Description
Float[Array, 'neurons 2']

Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].

Source code in felice/neuron_models/wererabbit.py
def vector_field(self, y: Float[Array, "neurons 2"]) -> Float[Array, "neurons 2"]:
    """Compute vector field of the neuron state variables.

    This implements the WereRabbit dynamics

        - du/dt: Predator dynamics
        - dv/dt: WerePrey dynamics

    Args:
        y: State array of shape (neurons, 2) containing [u, v].

    Returns:
        Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
    """
    u = y[:, 0]
    v = y[:, 1]

    z = jax.nn.tanh(self.rho * (u - v))
    du = (
        z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))
        - self.sigma
    )
    dv = (
        z
        * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))
        - self.sigma
    )

    # Element-wise mask of neurons sitting at the switching point (z ~ 0).
    # `jnp.isclose` keeps the decision per neuron; `jnp.allclose` would
    # collapse it to one scalar shared by the whole population.
    at_switch = jnp.isclose(z, 0.0)
    # Where z ~ 0 the derivative is scaled by the sign of its own state,
    # which makes it exactly zero for a state that is exactly zero.
    dv = jnp.where(at_switch, dv * jnp.sign(v), dv)
    du = jnp.where(at_switch, du * jnp.sign(u), du)

    return jnp.stack([du, dv], axis=1)
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']

Compute time derivatives of the neuron state variables.

This implements the WereRabbit dynamics

- du/dt: Predator dynamics
- dv/dt: WerePrey dynamics

Parameters:

Name Type Description Default
t float

Current simulation time (unused but required by framework).

required
y Float[Array, 'neurons 2']

State array of shape (neurons, 2) containing [u, v].

required
args Dict[str, Any]

Additional arguments (unused but required by framework).

required

Returns:

Type Description
Float[Array, 'neurons 2']

Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].

Source code in felice/neuron_models/wererabbit.py
def dynamics(
    self,
    t: float,
    y: Float[Array, "neurons 2"],
    args: Dict[str, Any],
) -> Float[Array, "neurons 2"]:
    """Return the time derivatives of the neuron state.

    Thin adapter that gives ``vector_field`` the ``(t, y, args)`` signature;
    the result is whatever ``self.vector_field(y)`` returns.

        - du/dt: Predator dynamics
        - dv/dt: WerePrey dynamics

    Args:
        t: Current simulation time (not read by this method).
        y: State array passed unchanged to ``vector_field``; annotated as
            shape (neurons, 2).
        args: Additional arguments (not read by this method).

    Returns:
        The output of ``vector_field``, annotated as shape (neurons, 2).
    """
    return self.vector_field(y)
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']

Compute spike condition for event detection.

A spike is triggered when the system reaches a fixpoint.

INFO

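has_spiked is meant to keep the system from detecting a continuous spike once it has reached a fixpoint; note that the (neurons, 2) state used by this class contains only [u, v] and does not store it.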

Parameters:

Name Type Description Default
t float

Current simulation time (unused but required by the framework).

required
y Float[Array, 'neurons 2']

State array of shape (neurons, 2) containing [u, v].

required
**kwargs Dict[str, Any]

Additional keyword arguments (unused).

{}

Returns:

Type Description
Float[Array, ' neurons']

Spike condition array of shape (neurons,). Positive values indicate spike.

Source code in felice/neuron_models/wererabbit.py
def spike_condition(
    self,
    t: float,
    y: Float[Array, "neurons 2"],
    **kwargs: Dict[str, Any],
) -> Float[Array, " neurons"]:
    """Compute spike condition for event detection.

    A spike is triggered when the system reaches a fixpoint, i.e. when the
    norm of the vector field drops below ``atol + rtol * norm(state)``.

    Args:
        t: Current simulation time (forwarded to ``dynamics``).
        y: State array of shape (neurons, 2) containing [u, v].
        **kwargs: Additional keyword arguments (unused).

    Returns:
        Spike condition array of shape (neurons,). Positive values indicate spike.
    """
    _atol = self.atol
    _rtol = self.rtol
    _norm = optx.rms_norm

    vf = self.dynamics(t, y, {})

    @jax.vmap
    def calculate_norm(vf, y):
        # Select the (u, v) pair explicitly. The former `[:-1]` slice dates
        # from a state with a trailing `has_spiked` column; on the current
        # two-column state it dropped v, so dv/dt was ignored.
        return _atol + _rtol * _norm(y[:2]) - _norm(vf[:2])

    base_cond = calculate_norm(vf, y)

    return base_cond