+
+
+
diff --git a/search/search_index.json b/search/search_index.json
index 12c4306..93e6176 100644
--- a/search/search_index.json
+++ b/search/search_index.json
@@ -1 +1 @@
-{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"],"fields":{"title":{"boost":1000.0},"text":{"boost":1.0},"tags":{"boost":1000000.0}}},"docs":[{"location":"","title":"Felice","text":"
This project provides a JAX implementation of the different neuron models in felice
"},{"location":"#overview","title":"Overview","text":"
The framework is built on top of EventPropJax and leverages JAX's automatic differentiation for efficient simulation and training of SNNs using event-based exact gradients.
"},{"location":"#key-features","title":"Key Features","text":"
- Delay learning
- Non-linear neuron models
- WereRabbit Neuron Model: Implementation of a dual-state oscillatory neuron model with bistable dynamics
- FHN Neuron Model
- Snowball Neuron Model
"},{"location":"#installation","title":"\ud83d\udce6 Installation","text":"
Felice uses uv for dependency management. To install:
uv sync\n
"},{"location":"#cuda-support-optional","title":"CUDA Support (Optional)","text":"
For GPU acceleration with CUDA 13:
uv sync --extra cuda\n
"},{"location":"#quick-start","title":"Quick Start","text":"
Here's a simple example using the WereRabbit neuron model:
import diffrax as dfx\nimport jax.numpy as jnp\nimport jax.random as jrand\nfrom eventpropjax.evnn import FFEvNN\nfrom felice.neuron_models import WereRabbit\n\n# Initialize random key and parameters\nkey = jrand.key(0)\nmax_time = 300e-3\n\n# Create a feedforward event-driven neural network\nsnn = FFEvNN(\n layers=[1],\n in_size=2,\n neuron_model=WereRabbit,\n solver=dfx.Tsit5(),\n max_solver_time=max_time,\n key=key,\n max_event_steps=1000000,\n solver_stepsize=1e-6,\n rtol=10.0,\n atol=0.0,\n Ibias=30e-12,\n)\n\n# Simulate with input spikes\nin_spikes = jnp.asarray([[0.0], [0.157]])\nspikes = snn.spikes_until_t(in_spikes, max_time)\n
See the examples directory for more detailed usage examples.
"},{"location":"api/","title":"API Reference","text":"
API documentation for Felice.
"},{"location":"api/#modules","title":"Modules","text":"
- Neuron Models - Neuron model implementations
- Solver - Zero-clipping solver
- Datasets - Built-in datasets
"},{"location":"api/datasets/","title":"Datasets","text":""},{"location":"api/datasets/#felice.datasets","title":"
felice.datasets","text":""},{"location":"api/neuron_models/","title":"Neuron Models","text":""},{"location":"api/neuron_models/#felice.neuron_models","title":"
felice.neuron_models","text":""},{"location":"api/neuron_models/#felice.neuron_models-classes","title":"Classes","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang","title":"
Boomerang","text":"
Bases: Module
Source code in
felice/neuron_models/boomerang.py class Boomerang(eqx.Module):\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n u0: float = eqx.field(static=True)\n v0: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n\n def vector_field(\n self, u: Float[Array, \"...\"], v: Float[Array, \"...\"]\n ) -> Tuple[Float[Array, \"...\"], Float[Array, \"...\"]]:\n alpha = self.alpha\n beta = self.beta\n gamma = self.gamma\n sigma = self.sigma\n rho = self.rho\n\n z = jax.nn.tanh(rho * (v - u))\n du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z\n dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z\n\n return du, dv\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.__init__","title":"
__init__(*, atol: float = 1e-06, rtol: float = 0.0001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 30.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)","text":"
Initialize the WereRabbit neuron model.
Parameters:
Name Type Description Default
key JAX random key for weight initialization.
required
n_neurons Number of neurons in this layer.
required
in_size Number of input connections (excluding recurrent connections).
required
wmask Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).
required
rtol float Relative tolerance for the spiking fixpoint calculation.
0.0001 atol float Absolute tolerance for the spiking fixpoint calculation.
1e-06 alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
0.0129 beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
15.6 gamma float Coupling parameter \\(\\gamma = 26e^{-2}\\)
0.26 rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
30.0 sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
0.6 wlim Limit for weight initialization. If None, uses init_weights.
required
wmean Mean value for weight initialization.
required
init_weights Optional initial weight values. If None, weights are randomly initialized.
required
fan_in_mode Mode for fan-in based weight initialization ('sqrt', 'linear').
required
dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in
felice/neuron_models/boomerang.py def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.init_state","title":"
init_state(n_neurons: int) -> Float[Array, 'neurons 2']","text":"
Initialize the neuron state variables.
Parameters:
Name Type Description Default
n_neurons int Number of neurons to initialize.
required
Returns:
Type Description
Float[Array, 'neurons 2'] Initial state array of shape (neurons, 3) containing [u, v],
Float[Array, 'neurons 2'] where u and v are the predator/prey membrane voltages.
Source code in
felice/neuron_models/boomerang.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.dynamics","title":"
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']","text":"
Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
required
args Dict[str, Any] Additional arguments (unused but required by framework).
required
Returns:
Type Description
Float[Array, 'neurons 2'] Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
Source code in
felice/neuron_models/boomerang.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.spike_condition","title":"
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"
Compute spike condition for event detection.
A spike is triggered when the system reach to a fixpoint.
INFO
has_spiked is use to the system don't detect a continuos spike when reach a fixpoint.
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by the framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 3) containing [u, v, has_spiked].
required
**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type Description
Float[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate spike.
Source code in
felice/neuron_models/boomerang.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS","title":"
FHNRS","text":"
Bases: Module
FitzHugh-Nagumo neuron model
Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast and slow currents to produce oscillatory spiking behavior.
The dynamics are governed by:
\\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\]
where the currents are:
- \\(I_{passive} = g_{max}(v - E_{rev})\\)
- \\(I_{fast} = a_{fast} \\tanh(v - v_{off,fast})\\)
- \\(I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})\\)
References
- Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.
Attributes:
Name Type Description
reset_grad_preserve Preserve the gradient when the neuron spikes by doing a soft reset.
gmax_pasive float Maximal conductance of the passive current.
Erev_pasive float Reversal potential for the passive current.
a_fast float Amplitude parameter for the fast current dynamics.
voff_fast float Voltage offset for the fast current activation.
tau_fast float Time constant for the fast current (typically zero for instantaneous).
a_slow float Amplitude parameter for the slow current dynamics.
voff_slow float Voltage offset for the slow current activation.
tau_slow float Time constant for the slow recovery variable.
vthr float Voltage threshold for spike generation.
C float Membrane capacitance.
tsyn float Synaptic time constant for input current decay.
weights float Synaptic weight matrix of shape (in_plus_neurons, neurons).
Source code in
felice/neuron_models/fhn.py class FHNRS(eqx.Module):\n r\"\"\"FitzHugh-Nagumo neuron model\n\n Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by\n Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast\n and slow currents to produce oscillatory spiking behavior.\n\n The dynamics are governed by:\n\n $$\n \\begin{align}\n C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\\n \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\\n \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}}\n \\end{align}\n $$\n\n where the currents are:\n\n - $I_{passive} = g_{max}(v - E_{rev})$\n - $I_{fast} = a_{fast} \\tanh(v - v_{off,fast})$\n - $I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})$\n\n References:\n - Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.\n\n Attributes:\n reset_grad_preserve: Preserve the gradient when the neuron spikes by doing a soft reset.\n gmax_pasive: Maximal conductance of the passive current.\n Erev_pasive: Reversal potential for the passive current.\n a_fast: Amplitude parameter for the fast current dynamics.\n voff_fast: Voltage offset for the fast current activation.\n tau_fast: Time constant for the fast current (typically zero for instantaneous).\n a_slow: Amplitude parameter for the slow current dynamics.\n voff_slow: Voltage offset for the slow current activation.\n tau_slow: Time constant for the slow recovery variable.\n vthr: Voltage threshold for spike generation.\n C: Membrane capacitance.\n tsyn: Synaptic time constant for input current decay.\n weights: Synaptic weight matrix of shape (in_plus_neurons, neurons).\n \"\"\"\n\n # Pasive parameters\n gmax_pasive: float = eqx.field(static=True)\n Erev_pasive: float = eqx.field(static=True)\n\n # Fast current\n a_fast: float = eqx.field(static=True)\n voff_fast: float = eqx.field(static=True)\n tau_fast: float = eqx.field(static=True)\n\n # Slow current\n a_slow: float = eqx.field(static=True)\n voff_slow: float = eqx.field(static=True)\n tau_slow: float = eqx.field(static=True)\n\n # Neuron threshold\n vthr: float = eqx.field(static=True)\n C: float = eqx.field(static=True, default=1.0)\n\n # Input synaptic time constant\n tsyn: float = eqx.field(static=True)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n ):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n\n def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.__init__","title":"
__init__(*, tsyn: Union[int, float, jnp.ndarray] = 1.0, C: Union[int, float, jnp.ndarray] = 1.0, gmax_pasive: Union[int, float, jnp.ndarray] = 1.0, Erev_pasive: Union[int, float, jnp.ndarray] = 0.0, a_fast: Union[int, float, jnp.ndarray] = -2.0, voff_fast: Union[int, float, jnp.ndarray] = 0.0, tau_fast: Union[int, float, jnp.ndarray] = 0.0, a_slow: Union[int, float, jnp.ndarray] = 2.0, voff_slow: Union[int, float, jnp.ndarray] = 0.0, tau_slow: Union[int, float, jnp.ndarray] = 50.0, vthr: Union[int, float, jnp.ndarray] = 2.0, dtype: DTypeLike = jnp.float32)","text":"
Initialize the FitzHugh-Nagumo neuron model.
Parameters:
Name Type Description Default
tsyn Union[int, float, ndarray] Synaptic time constant for input current decay. Can be scalar or per-neuron array.
1.0 C Union[int, float, ndarray] Membrane capacitance. Can be scalar or per-neuron array.
1.0 gmax_pasive Union[int, float, ndarray] Maximal conductance of passive current. Can be scalar or per-neuron array.
1.0 Erev_pasive Union[int, float, ndarray] Reversal potential for passive current. Can be scalar or per-neuron array.
0.0 a_fast Union[int, float, ndarray] Amplitude of fast current. Can be scalar or per-neuron array.
-2.0 voff_fast Union[int, float, ndarray] Voltage offset for fast current activation. Can be scalar or per-neuron array.
0.0 tau_fast Union[int, float, ndarray] Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.
0.0 a_slow Union[int, float, ndarray] Amplitude of slow current. Can be scalar or per-neuron array.
2.0 voff_slow Union[int, float, ndarray] Voltage offset for slow current activation. Can be scalar or per-neuron array.
0.0 tau_slow Union[int, float, ndarray] Time constant for slow recovery variable. Can be scalar or per-neuron array.
50.0 vthr Union[int, float, ndarray] Voltage threshold for spike generation. Can be scalar or per-neuron array.
2.0 dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in
felice/neuron_models/fhn.py def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.init_state","title":"
init_state(n_neurons: int) -> Float[Array, 'neurons 3']","text":"
Initialize the neuron state variables.
Parameters:
Name Type Description Default
n_neurons int Number of neurons to initialize.
required
Returns:
Type Description
Float[Array, 'neurons 3'] Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],
Float[Array, 'neurons 3'] where v is membrane voltage, v_slow is the slow recovery variable,
Float[Array, 'neurons 3'] and i_app is the applied synaptic current.
Source code in
felice/neuron_models/fhn.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_inst","title":"
IV_inst(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"
Compute instantaneous I-V relationship with fast and slow currents at rest.
Parameters:
Name Type Description Default
v Float[Array, ...] Membrane voltage.
required
Vrest float Resting voltage for both fast and slow currents (default: 0).
0 Returns:
Type Description
Float[Array, ...] Total current at voltage v with both fast and slow currents evaluated at Vrest.
Source code in
felice/neuron_models/fhn.py def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_fast","title":"
IV_fast(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"
Compute I-V relationship with fast current at voltage v and slow current at rest.
Parameters:
Name Type Description Default
v Float[Array, ...] Membrane voltage for passive and fast currents.
required
Vrest float Resting voltage for slow current (default: 0).
0 Returns:
Type Description
Float[Array, ...] Total current with fast dynamics responding to v and slow current at Vrest.
Source code in
felice/neuron_models/fhn.py def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_slow","title":"
IV_slow(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"
Compute steady-state I-V relationship with all currents at voltage v.
Parameters:
Name Type Description Default
v Float[Array, ...] Membrane voltage for all currents.
required
Vrest float Unused parameter for API consistency (default: 0).
0 Returns:
Type Description
Float[Array, ...] Total steady-state current with all currents responding to v.
Source code in
felice/neuron_models/fhn.py def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.dynamics","title":"
dynamics(t: float, y: Float[Array, 'neurons 3'], args: Dict[str, Any]) -> Float[Array, 'neurons 3']","text":"
Compute time derivatives of the neuron state variables.
This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents: - dv/dt: Fast membrane voltage dynamics - dv_slow/dt: Slow recovery variable dynamics - di_app/dt: Synaptic current decay
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by framework).
required
y Float[Array, 'neurons 3'] State array of shape (neurons, 3) containing [v, v_slow, i_app].
required
args Dict[str, Any] Additional arguments (unused but required by framework).
required
Returns:
Type Description
Float[Array, 'neurons 3'] Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].
Source code in
felice/neuron_models/fhn.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.spike_condition","title":"
spike_condition(t: float, y: Float[Array, 'neurons 3'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"
Compute spike condition for event detection.
A spike is triggered when this function crosses zero (v >= vthr).
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by event detection).
required
y Float[Array, 'neurons 3'] State array of shape (neurons, 3) containing [v, v_slow, i_app].
required
**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type Description
Float[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate v > vthr.
Source code in
felice/neuron_models/fhn.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit","title":"
WereRabbit","text":"
Bases: Module
WereRabbit Neuron Model
The WereRabbit model implements a predator-prey dynamic with bistable switching behavior controlled by a \"moon phase\" parameter \\(z\\).
The dynamics are governed by:
\\[ \\begin{align} z &= tanh(\\rho (u-v)) \\\\ \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{dt} &= -z - z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\]
where \\(z\\) represents the \"moon phase\" that switches the predator-prey roles.
Attributes:
Name Type Description
alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
gamma float Coupling parameter \\(\\gamma = 26e^{-2}\\)
rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
rtol float Relative tolerance for the spiking fixpoint calculation.
atol float Absolute tolerance for the spiking fixpoint calculation.
weight_u float Input weight for the predator.
weight_v float Input weight for the prey.
Source code in
felice/neuron_models/wererabbit.py class WereRabbit(eqx.Module):\n r\"\"\"\n WereRabbit Neuron Model\n\n The WereRabbit model implements a predator-prey dynamic with bistable \n switching behavior controlled by a \"moon phase\" parameter $z$.\n\n The dynamics are governed by:\n\n $$\n \\begin{align}\n z &= tanh(\\rho (u-v)) \\\\\n \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\\n \\frac{dv}{dt} &= -z - z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma\n \\end{align}\n $$\n\n where $z$ represents the \"moon phase\" that switches the predator-prey roles.\n\n Attributes:\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n\n weight_u: Input weight for the predator.\n weight_v: Input weight for the prey.\n \"\"\"\n\n dtype: DTypeLike = eqx.field(static=True)\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v, has_spiked],\n where u and v are the predator/prey membrane voltages, has_spiked is a\n variable that is 1 whenever the neuron spike and 0 otherwise .\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n\n def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.__init__","title":"
__init__(*, atol: float = 0.001, rtol: float = 0.001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 5.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)","text":"
Initialize the WereRabbit neuron model.
Parameters:
Name Type Description Default
rtol float Relative tolerance for the spiking fixpoint calculation.
0.001 atol float Absolute tolerance for the spiking fixpoint calculation.
0.001 alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
0.0129 beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
15.6 gamma float Coupling parameter \\(\\gamma = 26e^{-2}\\)
0.26 rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
5.0 sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
0.6 dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in
felice/neuron_models/wererabbit.py def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.init_state","title":"
init_state(n_neurons: int) -> Float[Array, 'neurons 2']","text":"
Initialize the neuron state variables.
Parameters:
Name Type Description Default
n_neurons int Number of neurons to initialize.
required
Returns:
Type Description
Float[Array, 'neurons 2'] Initial state array of shape (neurons, 3) containing [u, v, has_spiked],
Float[Array, 'neurons 2'] where u and v are the predator/prey membrane voltages, has_spiked is a
Float[Array, 'neurons 2'] variable that is 1 whenever the neuron spike and 0 otherwise .
Source code in
felice/neuron_models/wererabbit.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v, has_spiked],\n where u and v are the predator/prey membrane voltages, has_spiked is a\n variable that is 1 whenever the neuron spike and 0 otherwise .\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.vector_field","title":"
vector_field(y: Float[Array, 'neurons 2']) -> Float[Array, 'neurons 2']","text":"
Compute vector field of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n
Parameters:
Name Type Description Default
y Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
required
Returns:
Type Description
Float[Array, 'neurons 2'] Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
Source code in
felice/neuron_models/wererabbit.py def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.dynamics","title":"
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']","text":"
Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 3) containing [u, v, has_spiked].
required
args Dict[str, Any] Additional arguments (unused but required by framework).
required
Returns:
Type Description
Float[Array, 'neurons 2'] Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].
Source code in
felice/neuron_models/wererabbit.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.spike_condition","title":"
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"
Compute spike condition for event detection.
A spike is triggered when the system reach to a fixpoint.
INFO
has_spiked is use to the system don't detect a continuos spike when reach a fixpoint.
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by the framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 3) containing [u, v, has_spiked].
required
**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type Description
Float[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate spike.
Source code in
felice/neuron_models/wererabbit.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n
"},{"location":"api/solver/","title":"Solver","text":""},{"location":"api/solver/#felice.solver","title":"
felice.solver","text":""},{"location":"api/solver/#felice.solver-classes","title":"Classes","text":""},{"location":"api/solver/#felice.solver.ClipSolver","title":"
ClipSolver","text":"
Bases: Module
Source code in
felice/solver.py class ClipSolver(eqx.Module):\n solver: AbstractSolver\n\n def __getattr__(self, name):\n return getattr(self.solver, name)\n\n def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. (Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n
"},{"location":"api/solver/#felice.solver.ClipSolver-functions","title":"Functions","text":""},{"location":"api/solver/#felice.solver.ClipSolver.step","title":"
step(terms: PyTree[AbstractTerm], t0: RealScalarLike, t1: RealScalarLike, y0: Y, args: Args, solver_state: _SolverState, made_jump: BoolScalarLike) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]","text":"
Make a single step of the solver.
Each step is made over the specified interval \\([t_0, t_1]\\).
Arguments:
terms: The PyTree of terms representing the vector fields and controls. t0: The start of the interval that the step is made over. t1: The end of the interval that the step is made over. y0: The current value of the solution at t0. args: Any extra arguments passed to the vector field. solver_state: Any evolving state for the solver itself, at t0. made_jump: Whether there was a discontinuity in the vector field at t0. Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there are no jumps and for efficiency re-use information between steps; this indicates that a jump has just occurred and this assumption is not true.
Returns:
A tuple of several objects:
- The value of the solution at
t1. - A local error estimate made during the step. (Used by adaptive step size controllers to change the step size.) May be
None if no estimate was made. - Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. (Used with
SaveAt(ts=...) or SaveAt(dense=...).) - The value of the solver state at
t1. - An integer (corresponding to
diffrax.RESULTS) indicating whether the step happened successfully, or if (unusually) it failed for some reason.
Source code in
felice/solver.py def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. (Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n
"},{"location":"neuron_models/","title":"Neuron Models","text":"
Felice implements several non-linear neuron models for spiking neural networks.
"},{"location":"neuron_models/#available-models","title":"Available Models","text":"Model Type Key Features WereRabbit Dual-state oscillatory Bistable dynamics, predator-prey FitzHugh-Nagumo ... ... Snowball Exponential Integrate-and-Fire neuron model ..."},{"location":"neuron_models/fhn/","title":"FitzHugh-Nagumo","text":""},{"location":"neuron_models/fhn/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\]
where the currents are: - \\(I_{passive} = g_{max}(v - E_{rev})\\) - \\(I_{fast} = a_{fast} \\tanh(v - v_{off,fast})\\) - \\(I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})\\)
"},{"location":"neuron_models/fhn/#examples","title":"Examples","text":"
See the following interactive notebook for a practical example:
- Basic Usage Example - Introduction to the FitzHugh-Nagumo model
"},{"location":"neuron_models/fhn/fhn/","title":"Example","text":"In\u00a0[\u00a0]: Copied!
import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\nfrom eventpropjax.evnn import FFEvNN\n\nfrom felice.neuron_models import FHN\n
import diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from eventpropjax.evnn import FFEvNN from felice.neuron_models import FHN In\u00a0[2]: Copied!
key = jrand.key(0)\nmax_time = 200\n\nsnn = FFEvNN(\n layers=[1],\n in_size=1,\n neuron_model=FHN,\n solver=dfx.Dopri5(),\n max_solver_time=max_time,\n key=key,\n max_event_steps=1000000,\n solver_stepsize=0.1,\n init_weights=2.0,\n)\n
key = jrand.key(0) max_time = 200 snn = FFEvNN( layers=[1], in_size=1, neuron_model=FHN, solver=dfx.Dopri5(), max_solver_time=max_time, key=key, max_event_steps=1000000, solver_stepsize=0.1, init_weights=2.0, ) In\u00a0[3]: Copied!
v_range = jnp.arange(-3.1, 3, 0.1)\nVI_inst = jax.vmap(snn.neuron_model.IV_inst)(v_range)\nVI_fast = jax.vmap(snn.neuron_model.IV_fast)(v_range)\nVI_slow = jax.vmap(snn.neuron_model.IV_slow)(v_range)\n\nwith mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True)\n ax[0].plot(v_range, VI_inst)\n ax[1].plot(v_range, VI_fast)\n ax[2].plot(v_range, VI_slow)\n plt.show()\n
v_range = jnp.arange(-3.1, 3, 0.1) VI_inst = jax.vmap(snn.neuron_model.IV_inst)(v_range) VI_fast = jax.vmap(snn.neuron_model.IV_fast)(v_range) VI_slow = jax.vmap(snn.neuron_model.IV_slow)(v_range) with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True) ax[0].plot(v_range, VI_inst) ax[1].plot(v_range, VI_fast) ax[2].plot(v_range, VI_slow) plt.show() In\u00a0[4]: Copied!
in_spikes = jnp.asarray([[0.00]])\ncomp_times = jnp.linspace(0.0, max_time, 500)\nstate = snn.state_at_t(in_spikes, comp_times)\nspikes = snn.spikes_until_t(in_spikes, max_time)\n
in_spikes = jnp.asarray([[0.00]]) comp_times = jnp.linspace(0.0, max_time, 500) state = snn.state_at_t(in_spikes, comp_times) spikes = snn.spikes_until_t(in_spikes, max_time) In\u00a0[5]: Copied!
with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[0, :, 0])\n ax[0].plot(comp_times, state[0, :, 1], \"--\")\n # ax[0].plot(comp_times, state[0, :, 2], \"-.\")\n [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)]\n ax[0].set_xlabel(\"Time (ms)\")\n ax[0].legend([\"v\", \"vslow\", \"syn\", \"Spike\"])\n\n ax[1].plot(state[0, :, 0], state[0, :, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\")\n ax[1].set_xlabel(\"v\")\n ax[1].set_ylabel(\"v fast\")\n ax[1].legend()\n plt.show()\n
with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[0, :, 0]) ax[0].plot(comp_times, state[0, :, 1], \"--\") # ax[0].plot(comp_times, state[0, :, 2], \"-.\") [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)] ax[0].set_xlabel(\"Time (ms)\") ax[0].legend([\"v\", \"vslow\", \"syn\", \"Spike\"]) ax[1].plot(state[0, :, 0], state[0, :, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\") ax[1].set_xlabel(\"v\") ax[1].set_ylabel(\"v fast\") ax[1].legend() plt.show() In\u00a0[\u00a0]: Copied!
\n
"},{"location":"neuron_models/snowball/","title":"Snowball","text":""},{"location":"neuron_models/snowball/#circuit-description","title":"Circuit description","text":"
The circuit implemented for exponential integrate and fire neuron has been used from [1]. Part (a) in Fig.2 in [1] implements the exponential integrate and fire neuron. The neuron receives input currents using the input DPI filter [2]. This input current is integrated on the node Vmem by the membrane capacitance. The membrane potential leaks in the absence of an input spike which can be set by the bias Vleak. The Vmem potential node is connected to a cascoded source follower formed by the P14-15 and N5-6. A threshold voltage of the neuron can be set by the bias Vthr which is compared to the membrane potential. When the membrane potential is just near the threshold voltage, it starts the positive feedback block which exponentially increases membrane potential and causes the neuron to spike. As the neuron spikes, the membrane potential gets reset to ground and the refractory bias helps to stop the neuron from spiking during the refractory period as similar to a biological neuron. The circuit implemented for this experiment does not exercise either adaptability or needs a pulse extender as implemented in [1]. The Vdd used in the simulation is 1V. The neuron receives 5nA input pulses with a pulse width of 100\u03bcs.
Input current mirror W/l = 0.2 All other transistors W/L = 4/3
"},{"location":"neuron_models/snowball/#circuit-simulation","title":"Circuit Simulation","text":"
Fig.1 The dynamics of Exponential integrate and fire neuron. The light blue signal is the input spikes, the yellow signal is the membrane potential and the dark blue is the output spikes from the neuron.
"},{"location":"neuron_models/snowball/#references","title":"References","text":"
- Rubino, Arianna, Melika Payvand, and Giacomo Indiveri. \"Ultra-low power silicon neuron circuit for extreme-edge neuromorphic intelligence.\" 2019 26th IEEE International Conference on Electronics, Circuits and Systems (ICECS). IEEE, 2019.
- Bartolozzi, Chiara, Srinjoy Mitra, and Giacomo Indiveri. \"An ultra low power current-mode filter for neuromorphic systems and biomedical signal processing.\" 2006 IEEE Biomedical Circuits and Systems Conference. IEEE, 2006.
"},{"location":"neuron_models/wererabbit/","title":"WereRabbit","text":"
The wererabbit neuron model is a two coupled oscillator that follows a predator- prey dynamic with a switching in the diagonal of the phaseplane. When the z in equation 1c represents the \u201cmoon phase\u201d, when ever it cross that threshold, the rabbit (prey) becomes the predator.
"},{"location":"neuron_models/wererabbit/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{du}{dt} &= z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - I_a \\\\ C\\frac{dv}{dt} &= -z I_{bias} + I_{n0} e^{\\kappa u / U_t} [z + 26e^{-2} (0.5 - v) z] - I_a \\\\ z &= tanh(\\rho (u-v))\\\\ I_a &= \\sigma I_{bias} \\\\ \\end{align} \\] Parameter Symbol Definition Value Capacitance C Circuit capacitance \\(0.1\\,pF\\) Bias current \\(I_{bias}\\) DC bias current for the fixpoint location \\(100\\,pA\\) Leakage current \\(I_{n0}\\) Transistor leakage current \\(0.129\\,pA\\) Subthreshold slope \\(\\kappa\\) Transistor subthreshold slope factor \\(0.39\\) Thermal voltage \\(U_t\\) Thermal voltage at room temperature \\(25\\,mV\\) Bias scale \\(\\sigma\\) Scaling factor for the distance between fixpoints \\(0.6\\) Steepness \\(\\rho\\) Tanh steepness for the moonphase \\(5\\)s"},{"location":"neuron_models/wererabbit/#abstraction","title":"Abstraction","text":"
To simplify the analysis of the model for simulation purposes, we can introduce a dimensionless time variable \\(\\tau=tI_{bias}/C\\), transforming the derivate of the equations in \\(\\frac{d}{dt}=\\frac{I_{bias}}{C}\\frac{d}{d\\tau}\\). Substituting this time transformation on equation~\\ref{eq:wererabbit:circ}
\\[ \\begin{equation} C\\frac{I_{bias}}{C}\\frac{du}{d\\tau} = z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma I_{bias} \\end{equation} \\]
And dividing by \\(I_{bias}\\) on both sides:
\\[ \\begin{equation} \\frac{du}{d\\tau} = z - \\frac{I_{n0}}{I_{bias}} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma \\end{equation} \\]
Obtaining the following set of equations:
\\[ \\begin{align} z &= tanh(\\kappa (u-v)) \\\\ \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{dt} &= -z - z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\] Parameter Definition Value \\(\\tau\\) \\(tI_{bias}/C\\) -- \\(\\alpha\\) \\(I_{n0}/I_{bias}\\) \\(0.0129\\) \\(\\beta\\) \\(\\kappa/U_t\\) 15.6 \\(\\gamma\\) -- \\(26e^{-2}\\) \\(\\rho\\) Tanh steepness for the moonphase 5 \\(\\sigma\\) Scaling factor for the distance between fixpoints 0.6"},{"location":"neuron_models/wererabbit/#examples","title":"Examples","text":"
See the following interactive notebook for a practical example:
- Basic Usage Example - Introduction to the WereRabbit model
"},{"location":"neuron_models/wererabbit/wererabbit/","title":"Basic example","text":"In\u00a0[1]: Copied!
import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\nfrom eventpropjax.evnn import FFEvNN\n\nfrom felice.neuron_models import WereRabbit\n\njax.config.update(\"jax_enable_x64\", True)\n
import diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from eventpropjax.evnn import FFEvNN from felice.neuron_models import WereRabbit jax.config.update(\"jax_enable_x64\", True) In\u00a0[2]: Copied!
key = jrand.key(0)\nmax_time = 100\n\ninit_weights = jnp.asarray([[0.0], [0.2], [-0.1]])\nsnn = FFEvNN(\n layers=[1],\n in_size=2,\n neuron_model=WereRabbit,\n solver=dfx.Tsit5(),\n # stepsize_controller=PIDController(\n # rtol=1e-6, atol=1e-3, pcoeff=0.1, icoeff=0.3, dcoeff=0.0\n # ),\n max_solver_time=max_time,\n key=key,\n max_event_steps=10000,\n solver_stepsize=0.001,\n init_weights=init_weights,\n dtype=jnp.float64,\n)\n
key = jrand.key(0) max_time = 100 init_weights = jnp.asarray([[0.0], [0.2], [-0.1]]) snn = FFEvNN( layers=[1], in_size=2, neuron_model=WereRabbit, solver=dfx.Tsit5(), # stepsize_controller=PIDController( # rtol=1e-6, atol=1e-3, pcoeff=0.1, icoeff=0.3, dcoeff=0.0 # ), max_solver_time=max_time, key=key, max_event_steps=10000, solver_stepsize=0.001, init_weights=init_weights, dtype=jnp.float64, ) In\u00a0[3]: Copied!
in_spikes = jnp.asarray([[10, 30], [20, 21]])\ncomp_times = jnp.linspace(0.0, max_time, 2000)\nstate = snn.state_at_t(in_spikes, comp_times)\nspikes = snn.spikes_until_t(in_spikes, max_time)\n
in_spikes = jnp.asarray([[10, 30], [20, 21]]) comp_times = jnp.linspace(0.0, max_time, 2000) state = snn.state_at_t(in_spikes, comp_times) spikes = snn.spikes_until_t(in_spikes, max_time) In\u00a0[4]: Copied!
def compute_nullclines(snn, u_range, v_range, resolution=200):\n \"\"\"\n Compute nullclines\n du/dt = 0 (u-nullcline)\n dv/dt = 0 (v-nullcline)\n \"\"\"\n u_vals = jnp.linspace(u_range[0], u_range[1], resolution)\n v_vals = jnp.linspace(v_range[0], v_range[1], resolution)\n U, V = jnp.meshgrid(u_vals, v_vals)\n\n UV = jnp.stack(\n [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1\n )\n dS = snn.neuron_model.vector_field(UV)\n dU = dS[:, 0].reshape(U.shape)\n dV = dS[:, 1].reshape(V.shape)\n\n return U, V, dU, dV\n\n\ndef plot_vf(ax, snn, u_range, v_range):\n import numpy as np\n\n u_sparse = jnp.linspace(u_range[0], u_range[1], 20)\n v_sparse = jnp.linspace(v_range[0], v_range[1], 20)\n\n Us, Vs = jnp.meshgrid(u_sparse, v_sparse)\n\n U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200)\n\n UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1)\n dS = snn.neuron_model.vector_field(UVs)\n dUs = dS[:, 0].reshape(Us.shape)\n dVs = dS[:, 1].reshape(Vs.shape)\n\n # Normalize for visualization\n magnitude = np.sqrt(dUs**2 + dVs**2)\n magnitude[magnitude == 0] = 1\n dUs_norm = dUs / magnitude\n dVs_norm = dVs / magnitude\n\n # Nullclines\n ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\")\n ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\")\n\n # Vector field\n ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6)\n\n ax.set_xlabel(\"u (Prey)\")\n ax.set_ylabel(\"v (Predator)\")\n ax.set_title(\"Wererabbit: Phase Portrait\")\n ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\")\n ax.set_xlim(u_range)\n ax.set_ylim(v_range)\n ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n
def compute_nullclines(snn, u_range, v_range, resolution=200): \"\"\" Compute nullclines du/dt = 0 (u-nullcline) dv/dt = 0 (v-nullcline) \"\"\" u_vals = jnp.linspace(u_range[0], u_range[1], resolution) v_vals = jnp.linspace(v_range[0], v_range[1], resolution) U, V = jnp.meshgrid(u_vals, v_vals) UV = jnp.stack( [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1 ) dS = snn.neuron_model.vector_field(UV) dU = dS[:, 0].reshape(U.shape) dV = dS[:, 1].reshape(V.shape) return U, V, dU, dV def plot_vf(ax, snn, u_range, v_range): import numpy as np u_sparse = jnp.linspace(u_range[0], u_range[1], 20) v_sparse = jnp.linspace(v_range[0], v_range[1], 20) Us, Vs = jnp.meshgrid(u_sparse, v_sparse) U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200) UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1) dS = snn.neuron_model.vector_field(UVs) dUs = dS[:, 0].reshape(Us.shape) dVs = dS[:, 1].reshape(Vs.shape) # Normalize for visualization magnitude = np.sqrt(dUs**2 + dVs**2) magnitude[magnitude == 0] = 1 dUs_norm = dUs / magnitude dVs_norm = dVs / magnitude # Nullclines ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\") ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\") # Vector field ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6) ax.set_xlabel(\"u (Prey)\") ax.set_ylabel(\"v (Predator)\") ax.set_title(\"Wererabbit: Phase Portrait\") ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\") ax.set_xlim(u_range) ax.set_ylim(v_range) ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3) ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3) In\u00a0[5]: Copied!
with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[0, :, 0], label=\"x1\")\n ax[0].plot(comp_times, state[0, :, 1], label=\"x2\")\n [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)]\n ax[0].legend([\"x1\", \"x2\", \"Spike\"])\n\n plot_vf(ax[1], snn, [-0.2, 0.5], [-0.2, 0.5])\n ax[1].plot(state[0, :, 0], state[0, :, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\")\n ax[1].legend()\n plt.show()\n
with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[0, :, 0], label=\"x1\") ax[0].plot(comp_times, state[0, :, 1], label=\"x2\") [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)] ax[0].legend([\"x1\", \"x2\", \"Spike\"]) plot_vf(ax[1], snn, [-0.2, 0.5], [-0.2, 0.5]) ax[1].plot(state[0, :, 0], state[0, :, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\") ax[1].legend() plt.show() In\u00a0[6]: Copied!
import equinox as eqx\n\n\ndef loss_fn(model, spike_times, comp_times):\n out_states = model.state_at_t(spike_times, comp_times)\n logits = out_states[:, :, 0]\n\n return jnp.sum(logits)\n\n\nin_spikes = jnp.asarray([[0.01], [0.157]])\ncomp_times = jnp.linspace(0.0, max_time, 10)\n\nloss, gradients = eqx.filter_value_and_grad(loss_fn)(snn, in_spikes, comp_times)\nprint(\"Loss \", loss)\nprint(\"Gradients \", gradients.neuron_model.weight_u)\n
import equinox as eqx def loss_fn(model, spike_times, comp_times): out_states = model.state_at_t(spike_times, comp_times) logits = out_states[:, :, 0] return jnp.sum(logits) in_spikes = jnp.asarray([[0.01], [0.157]]) comp_times = jnp.linspace(0.0, max_time, 10) loss, gradients = eqx.filter_value_and_grad(loss_fn)(snn, in_spikes, comp_times) print(\"Loss \", loss) print(\"Gradients \", gradients.neuron_model.weight_u)
Loss 2.78585239490824\nGradients [[ 0. ]\n [-1.81404788]\n [-1.42144198]]\n
In\u00a0[\u00a0]: Copied!
\n
"}]}
\ No newline at end of file
+{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"],"fields":{"title":{"boost":1000.0},"text":{"boost":1.0},"tags":{"boost":1000000.0}}},"docs":[{"location":"","title":"Felice","text":"
This project provides a JAX implementation of the different neuron models in felice
"},{"location":"#overview","title":"Overview","text":"
The framework is built on top of diffrax and leverages JAX's automatic differentiation for efficient simulation and training of analogue models.
"},{"location":"#key-features","title":"Key Features","text":"
- Delay learning
- Non-linear neuron models
- WereRabbit Neuron Model: Implementation of a dual-state oscillatory neuron model with bistable dynamics
- FHN Neuron Model
- Snowball Neuron Model
"},{"location":"#installation","title":"\ud83d\udce6 Installation","text":"
Felice uses uv for dependency management. To install:
uv sync\n
"},{"location":"#cuda-support-optional","title":"CUDA Support (Optional)","text":"
For GPU acceleration with CUDA 13:
uv sync --extra cuda\n
See the examples directory for more detailed usage examples.
"},{"location":"api/","title":"API Reference","text":"
API documentation for Felice.
"},{"location":"api/#modules","title":"Modules","text":"
- Neuron Models - Neuron model implementations
- Solver - Zero-clipping solver
- Datasets - Built-in datasets
"},{"location":"api/datasets/","title":"Datasets","text":""},{"location":"api/datasets/#felice.datasets","title":"
felice.datasets","text":""},{"location":"api/neuron_models/","title":"Neuron Models","text":""},{"location":"api/neuron_models/#felice.neuron_models","title":"
felice.neuron_models","text":""},{"location":"api/neuron_models/#felice.neuron_models-classes","title":"Classes","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang","title":"
Boomerang","text":"
Bases: Module
Source code in
felice/neuron_models/boomerang.py class Boomerang(eqx.Module):\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n u0: float = eqx.field(static=True)\n v0: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n\n def vector_field(\n self, u: Float[Array, \"...\"], v: Float[Array, \"...\"]\n ) -> Tuple[Float[Array, \"...\"], Float[Array, \"...\"]]:\n alpha = self.alpha\n beta = self.beta\n gamma = self.gamma\n sigma = self.sigma\n rho = self.rho\n\n z = jax.nn.tanh(rho * (v - u))\n du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z\n dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z\n\n return du, dv\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.__init__","title":"
__init__(*, atol: float = 1e-06, rtol: float = 0.0001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 30.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)","text":"
Initialize the WereRabbit neuron model.
Parameters:
Name Type Description Default
key JAX random key for weight initialization.
required
n_neurons Number of neurons in this layer.
required
in_size Number of input connections (excluding recurrent connections).
required
wmask Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).
required
rtol float Relative tolerance for the spiking fixpoint calculation.
0.0001 atol float Absolute tolerance for the spiking fixpoint calculation.
1e-06 alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
0.0129 beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
15.6 gamma float Coupling parameter \\(\\gamma = 26e^{-2}\\)
0.26 rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
30.0 sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
0.6 wlim Limit for weight initialization. If None, uses init_weights.
required
wmean Mean value for weight initialization.
required
init_weights Optional initial weight values. If None, weights are randomly initialized.
required
fan_in_mode Mode for fan-in based weight initialization ('sqrt', 'linear').
required
dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in
felice/neuron_models/boomerang.py def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.init_state","title":"
init_state(n_neurons: int) -> Float[Array, 'neurons 2']","text":"
Initialize the neuron state variables.
Parameters:
Name Type Description Default
n_neurons int Number of neurons to initialize.
required
Returns:
Type Description
Float[Array, 'neurons 2'] Initial state array of shape (neurons, 3) containing [u, v],
Float[Array, 'neurons 2'] where u and v are the predator/prey membrane voltages.
Source code in
felice/neuron_models/boomerang.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.dynamics","title":"
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']","text":"
Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
required
args Dict[str, Any] Additional arguments (unused but required by framework).
required
Returns:
Type Description
Float[Array, 'neurons 2'] Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
Source code in
felice/neuron_models/boomerang.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n
"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.spike_condition","title":"
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"
Compute spike condition for event detection.
A spike is triggered when the system reach to a fixpoint.
INFO
has_spiked is use to the system don't detect a continuos spike when reach a fixpoint.
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by the framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 3) containing [u, v, has_spiked].
required
**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type Description
Float[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate spike.
Source code in
felice/neuron_models/boomerang.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS","title":"
FHNRS","text":"
Bases: Module
FitzHugh-Nagumo neuron model
Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast and slow currents to produce oscillatory spiking behavior.
The dynamics are governed by:
\\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\]
where the currents are:
- \\(I_{passive} = g_{max}(v - E_{rev})\\)
- \\(I_{fast} = a_{fast} \\tanh(v - v_{off,fast})\\)
- \\(I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})\\)
References
- Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.
Attributes:
Name Type Description
reset_grad_preserve Preserve the gradient when the neuron spikes by doing a soft reset.
gmax_pasive float Maximal conductance of the passive current.
Erev_pasive float Reversal potential for the passive current.
a_fast float Amplitude parameter for the fast current dynamics.
voff_fast float Voltage offset for the fast current activation.
tau_fast float Time constant for the fast current (typically zero for instantaneous).
a_slow float Amplitude parameter for the slow current dynamics.
voff_slow float Voltage offset for the slow current activation.
tau_slow float Time constant for the slow recovery variable.
vthr float Voltage threshold for spike generation.
C float Membrane capacitance.
tsyn float Synaptic time constant for input current decay.
weights float Synaptic weight matrix of shape (in_plus_neurons, neurons).
Source code in
felice/neuron_models/fhn.py class FHNRS(eqx.Module):\n r\"\"\"FitzHugh-Nagumo neuron model\n\n Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by\n Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast\n and slow currents to produce oscillatory spiking behavior.\n\n The dynamics are governed by:\n\n $$\n \\begin{align}\n C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\\n \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\\n \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}}\n \\end{align}\n $$\n\n where the currents are:\n\n - $I_{passive} = g_{max}(v - E_{rev})$\n - $I_{fast} = a_{fast} \\tanh(v - v_{off,fast})$\n - $I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})$\n\n References:\n - Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.\n\n Attributes:\n reset_grad_preserve: Preserve the gradient when the neuron spikes by doing a soft reset.\n gmax_pasive: Maximal conductance of the passive current.\n Erev_pasive: Reversal potential for the passive current.\n a_fast: Amplitude parameter for the fast current dynamics.\n voff_fast: Voltage offset for the fast current activation.\n tau_fast: Time constant for the fast current (typically zero for instantaneous).\n a_slow: Amplitude parameter for the slow current dynamics.\n voff_slow: Voltage offset for the slow current activation.\n tau_slow: Time constant for the slow recovery variable.\n vthr: Voltage threshold for spike generation.\n C: Membrane capacitance.\n tsyn: Synaptic time constant for input current decay.\n weights: Synaptic weight matrix of shape (in_plus_neurons, neurons).\n \"\"\"\n\n # Pasive parameters\n gmax_pasive: float = eqx.field(static=True)\n Erev_pasive: float = eqx.field(static=True)\n\n # Fast current\n a_fast: float = eqx.field(static=True)\n voff_fast: float = eqx.field(static=True)\n tau_fast: float = eqx.field(static=True)\n\n # Slow current\n a_slow: float = eqx.field(static=True)\n voff_slow: float = eqx.field(static=True)\n tau_slow: float = eqx.field(static=True)\n\n # Neuron threshold\n vthr: float = eqx.field(static=True)\n C: float = eqx.field(static=True, default=1.0)\n\n # Input synaptic time constant\n tsyn: float = eqx.field(static=True)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n ):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n\n def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.__init__","title":"
__init__(*, tsyn: Union[int, float, jnp.ndarray] = 1.0, C: Union[int, float, jnp.ndarray] = 1.0, gmax_pasive: Union[int, float, jnp.ndarray] = 1.0, Erev_pasive: Union[int, float, jnp.ndarray] = 0.0, a_fast: Union[int, float, jnp.ndarray] = -2.0, voff_fast: Union[int, float, jnp.ndarray] = 0.0, tau_fast: Union[int, float, jnp.ndarray] = 0.0, a_slow: Union[int, float, jnp.ndarray] = 2.0, voff_slow: Union[int, float, jnp.ndarray] = 0.0, tau_slow: Union[int, float, jnp.ndarray] = 50.0, vthr: Union[int, float, jnp.ndarray] = 2.0, dtype: DTypeLike = jnp.float32)","text":"
Initialize the FitzHugh-Nagumo neuron model.
Parameters:
Name Type Description Default
tsyn Union[int, float, ndarray] Synaptic time constant for input current decay. Can be scalar or per-neuron array.
1.0 C Union[int, float, ndarray] Membrane capacitance. Can be scalar or per-neuron array.
1.0 gmax_pasive Union[int, float, ndarray] Maximal conductance of passive current. Can be scalar or per-neuron array.
1.0 Erev_pasive Union[int, float, ndarray] Reversal potential for passive current. Can be scalar or per-neuron array.
0.0 a_fast Union[int, float, ndarray] Amplitude of fast current. Can be scalar or per-neuron array.
-2.0 voff_fast Union[int, float, ndarray] Voltage offset for fast current activation. Can be scalar or per-neuron array.
0.0 tau_fast Union[int, float, ndarray] Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.
0.0 a_slow Union[int, float, ndarray] Amplitude of slow current. Can be scalar or per-neuron array.
2.0 voff_slow Union[int, float, ndarray] Voltage offset for slow current activation. Can be scalar or per-neuron array.
0.0 tau_slow Union[int, float, ndarray] Time constant for slow recovery variable. Can be scalar or per-neuron array.
50.0 vthr Union[int, float, ndarray] Voltage threshold for spike generation. Can be scalar or per-neuron array.
2.0 dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in
felice/neuron_models/fhn.py def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.init_state","title":"
init_state(n_neurons: int) -> Float[Array, 'neurons 3']","text":"
Initialize the neuron state variables.
Parameters:
Name Type Description Default
n_neurons int Number of neurons to initialize.
required
Returns:
Type Description
Float[Array, 'neurons 3'] Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],
Float[Array, 'neurons 3'] where v is membrane voltage, v_slow is the slow recovery variable,
Float[Array, 'neurons 3'] and i_app is the applied synaptic current.
Source code in
felice/neuron_models/fhn.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_inst","title":"
IV_inst(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"
Compute instantaneous I-V relationship with fast and slow currents at rest.
Parameters:
Name Type Description Default
v Float[Array, ...] Membrane voltage.
required
Vrest float Resting voltage for both fast and slow currents (default: 0).
0 Returns:
Type Description
Float[Array, ...] Total current at voltage v with both fast and slow currents evaluated at Vrest.
Source code in
felice/neuron_models/fhn.py def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_fast","title":"
IV_fast(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"
Compute I-V relationship with fast current at voltage v and slow current at rest.
Parameters:
Name Type Description Default
v Float[Array, ...] Membrane voltage for passive and fast currents.
required
Vrest float Resting voltage for slow current (default: 0).
0 Returns:
Type Description
Float[Array, ...] Total current with fast dynamics responding to v and slow current at Vrest.
Source code in
felice/neuron_models/fhn.py def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_slow","title":"
IV_slow(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"
Compute steady-state I-V relationship with all currents at voltage v.
Parameters:
Name Type Description Default
v Float[Array, ...] Membrane voltage for all currents.
required
Vrest float Unused parameter for API consistency (default: 0).
0 Returns:
Type Description
Float[Array, ...] Total steady-state current with all currents responding to v.
Source code in
felice/neuron_models/fhn.py def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.dynamics","title":"
dynamics(t: float, y: Float[Array, 'neurons 3'], args: Dict[str, Any]) -> Float[Array, 'neurons 3']","text":"
Compute time derivatives of the neuron state variables.
This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents: - dv/dt: Fast membrane voltage dynamics - dv_slow/dt: Slow recovery variable dynamics - di_app/dt: Synaptic current decay
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by framework).
required
y Float[Array, 'neurons 3'] State array of shape (neurons, 3) containing [v, v_slow, i_app].
required
args Dict[str, Any] Additional arguments (unused but required by framework).
required
Returns:
Type Description
Float[Array, 'neurons 3'] Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].
Source code in
felice/neuron_models/fhn.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n
"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.spike_condition","title":"
spike_condition(t: float, y: Float[Array, 'neurons 3'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"
Compute spike condition for event detection.
A spike is triggered when this function crosses zero (v >= vthr).
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by event detection).
required
y Float[Array, 'neurons 3'] State array of shape (neurons, 3) containing [v, v_slow, i_app].
required
**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type Description
Float[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate v > vthr.
Source code in
felice/neuron_models/fhn.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit","title":"
WereRabbit","text":"
Bases: Module
WereRabbit Neuron Model
The WereRabbit model implements a predator-prey dynamic with bistable switching behavior controlled by a \"moon phase\" parameter \\(z\\).
The dynamics are governed by:
\\[ \\begin{align} z &= tanh(\\rho (u-v)) \\\\ \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{dt} &= -z - z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\]
where \\(z\\) represents the \"moon phase\" that switches the predator-prey roles.
Attributes:
Name Type Description
alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
gamma float Coupling parameter \\(\\gamma = 26e^{-2}\\)
rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
rtol float Relative tolerance for the spiking fixpoint calculation.
atol float Absolute tolerance for the spiking fixpoint calculation.
weight_u float Input weight for the predator.
weight_v float Input weight for the prey.
Source code in
felice/neuron_models/wererabbit.py class WereRabbit(eqx.Module):\n r\"\"\"\n WereRabbit Neuron Model\n\n The WereRabbit model implements a predator-prey dynamic with bistable \n switching behavior controlled by a \"moon phase\" parameter $z$.\n\n The dynamics are governed by:\n\n $$\n \\begin{align}\n z &= tanh(\\rho (u-v)) \\\\\n \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\\n \\frac{dv}{dt} &= -z - z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma\n \\end{align}\n $$\n\n where $z$ represents the \"moon phase\" that switches the predator-prey roles.\n\n Attributes:\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n\n weight_u: Input weight for the predator.\n weight_v: Input weight for the prey.\n \"\"\"\n\n dtype: DTypeLike = eqx.field(static=True)\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v, has_spiked],\n where u and v are the predator/prey membrane voltages, has_spiked is a\n variable that is 1 whenever the neuron spike and 0 otherwise .\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n\n def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.__init__","title":"
__init__(*, atol: float = 0.001, rtol: float = 0.001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 5.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)","text":"
Initialize the WereRabbit neuron model.
Parameters:
Name Type Description Default
rtol float Relative tolerance for the spiking fixpoint calculation.
0.001 atol float Absolute tolerance for the spiking fixpoint calculation.
0.001 alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
0.0129 beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
15.6 gamma float Coupling parameter \\(\\gamma = 26e^{-2}\\)
0.26 rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
5.0 sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
0.6 dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in
felice/neuron_models/wererabbit.py def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.init_state","title":"
init_state(n_neurons: int) -> Float[Array, 'neurons 2']","text":"
Initialize the neuron state variables.
Parameters:
Name Type Description Default
n_neurons int Number of neurons to initialize.
required
Returns:
Type Description
Float[Array, 'neurons 2'] Initial state array of shape (neurons, 3) containing [u, v, has_spiked],
Float[Array, 'neurons 2'] where u and v are the predator/prey membrane voltages, has_spiked is a
Float[Array, 'neurons 2'] variable that is 1 whenever the neuron spike and 0 otherwise .
Source code in
felice/neuron_models/wererabbit.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v, has_spiked],\n where u and v are the predator/prey membrane voltages, has_spiked is a\n variable that is 1 whenever the neuron spike and 0 otherwise .\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.vector_field","title":"
vector_field(y: Float[Array, 'neurons 2']) -> Float[Array, 'neurons 2']","text":"
Compute vector field of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n
Parameters:
Name Type Description Default
y Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
required
Returns:
Type Description
Float[Array, 'neurons 2'] Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
Source code in
felice/neuron_models/wererabbit.py def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.dynamics","title":"
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']","text":"
Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 3) containing [u, v, has_spiked].
required
args Dict[str, Any] Additional arguments (unused but required by framework).
required
Returns:
Type Description
Float[Array, 'neurons 2'] Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].
Source code in
felice/neuron_models/wererabbit.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n
"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.spike_condition","title":"
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"
Compute spike condition for event detection.
A spike is triggered when the system reach to a fixpoint.
INFO
has_spiked is use to the system don't detect a continuos spike when reach a fixpoint.
Parameters:
Name Type Description Default
t float Current simulation time (unused but required by the framework).
required
y Float[Array, 'neurons 2'] State array of shape (neurons, 3) containing [u, v, has_spiked].
required
**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type Description
Float[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate spike.
Source code in
felice/neuron_models/wererabbit.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n
"},{"location":"api/solver/","title":"Solver","text":""},{"location":"api/solver/#felice.solver","title":"
felice.solver","text":""},{"location":"api/solver/#felice.solver-classes","title":"Classes","text":""},{"location":"api/solver/#felice.solver.ClipSolver","title":"
ClipSolver","text":"
Bases: Module
Source code in
felice/solver.py class ClipSolver(eqx.Module):\n solver: AbstractSolver\n\n def __getattr__(self, name):\n return getattr(self.solver, name)\n\n def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. (Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n
"},{"location":"api/solver/#felice.solver.ClipSolver-functions","title":"Functions","text":""},{"location":"api/solver/#felice.solver.ClipSolver.step","title":"
step(terms: PyTree[AbstractTerm], t0: RealScalarLike, t1: RealScalarLike, y0: Y, args: Args, solver_state: _SolverState, made_jump: BoolScalarLike) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]","text":"
Make a single step of the solver.
Each step is made over the specified interval \\([t_0, t_1]\\).
Arguments:
terms: The PyTree of terms representing the vector fields and controls. t0: The start of the interval that the step is made over. t1: The end of the interval that the step is made over. y0: The current value of the solution at t0. args: Any extra arguments passed to the vector field. solver_state: Any evolving state for the solver itself, at t0. made_jump: Whether there was a discontinuity in the vector field at t0. Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there are no jumps and for efficiency re-use information between steps; this indicates that a jump has just occurred and this assumption is not true.
Returns:
A tuple of several objects:
- The value of the solution at
t1. - A local error estimate made during the step. (Used by adaptive step size controllers to change the step size.) May be
None if no estimate was made. - Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. (Used with
SaveAt(ts=...) or SaveAt(dense=...).) - The value of the solver state at
t1. - An integer (corresponding to
diffrax.RESULTS) indicating whether the step happened successfully, or if (unusually) it failed for some reason.
Source code in
felice/solver.py def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. (Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n
"},{"location":"neuron_models/","title":"Neuron Models","text":"
Felice implements several non-linear neuron models for spiking neural networks.
"},{"location":"neuron_models/#available-models","title":"Available Models","text":"Model Type Key Features WereRabbit Dual-state oscillatory Bistable dynamics, predator-prey FitzHugh-Nagumo ... ... Snowball Exponential Integrate-and-Fire neuron model ... LIF Leaky Integrate-and-Fire neuron model ..."},{"location":"neuron_models/fhn/","title":"FitzHugh-Nagumo","text":""},{"location":"neuron_models/fhn/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\]
where the currents are: - \\(I_{passive} = g_{max}(v - E_{rev})\\) - \\(I_{fast} = a_{fast} \\tanh(v - v_{off,fast})\\) - \\(I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})\\)
"},{"location":"neuron_models/fhn/#examples","title":"Examples","text":"
See the following interactive notebook for a practical example:
- Basic Usage Example - Introduction to the FitzHugh-Nagumo model
"},{"location":"neuron_models/fhn/fhn/","title":"Example","text":"In\u00a0[2]: Copied!
import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nfrom felice.neuron_models import FHNRS\n
import diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from felice.neuron_models import FHNRS In\u00a0[23]: Copied!
key = jrand.key(0)\nmax_time = 200\n\nneuron_model = FHNRS(\n gmax_pasive=2.0,\n Erev_pasive=0.0,\n a_fast=-2.0,\n voff_fast=0.0,\n tau_fast=0.0,\n a_slow=0.5,\n voff_slow=1.0,\n tau_slow=50.0,\n vthr=jnp.inf,\n)\n\n\ndef state_at_t(comp_times):\n sol = dfx.diffeqsolve(\n terms=dfx.ODETerm(neuron_model.dynamics),\n solver=dfx.Tsit5(),\n t0=0.0,\n t1=max_time,\n dt0=1e-3,\n y0=neuron_model.init_state(1)\n + jrand.uniform(key, shape=(1, 3), minval=0.1, maxval=0.5),\n saveat=dfx.SaveAt(ts=comp_times),\n max_steps=200000,\n )\n\n return sol.ts, sol.ys\n
key = jrand.key(0) max_time = 200 neuron_model = FHNRS( gmax_pasive=2.0, Erev_pasive=0.0, a_fast=-2.0, voff_fast=0.0, tau_fast=0.0, a_slow=0.5, voff_slow=1.0, tau_slow=50.0, vthr=jnp.inf, ) def state_at_t(comp_times): sol = dfx.diffeqsolve( terms=dfx.ODETerm(neuron_model.dynamics), solver=dfx.Tsit5(), t0=0.0, t1=max_time, dt0=1e-3, y0=neuron_model.init_state(1) + jrand.uniform(key, shape=(1, 3), minval=0.1, maxval=0.5), saveat=dfx.SaveAt(ts=comp_times), max_steps=200000, ) return sol.ts, sol.ys In\u00a0[24]: Copied!
v_range = jnp.arange(-3.1, 3, 0.1)\nVI_inst = jax.vmap(neuron_model.IV_inst)(v_range)\nVI_fast = jax.vmap(neuron_model.IV_fast)(v_range)\nVI_slow = jax.vmap(neuron_model.IV_slow)(v_range)\n\nwith mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True)\n ax[0].plot(v_range, VI_inst)\n ax[1].plot(v_range, VI_fast)\n ax[2].plot(v_range, VI_slow)\n plt.show()\n
v_range = jnp.arange(-3.1, 3, 0.1) VI_inst = jax.vmap(neuron_model.IV_inst)(v_range) VI_fast = jax.vmap(neuron_model.IV_fast)(v_range) VI_slow = jax.vmap(neuron_model.IV_slow)(v_range) with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True) ax[0].plot(v_range, VI_inst) ax[1].plot(v_range, VI_fast) ax[2].plot(v_range, VI_slow) plt.show() In\u00a0[25]: Copied!
comp_times = jnp.linspace(0.0, max_time, 500)\n_, state = state_at_t(comp_times)\n
comp_times = jnp.linspace(0.0, max_time, 500) _, state = state_at_t(comp_times) In\u00a0[26]: Copied!
def compute_nullclines(neuron_model, u_range, v_range, resolution=200):\n \"\"\"\n Compute nullclines\n du/dt = 0 (u-nullcline)\n dv/dt = 0 (v-nullcline)\n \"\"\"\n u_vals = jnp.linspace(u_range[0], u_range[1], resolution)\n v_vals = jnp.linspace(v_range[0], v_range[1], resolution)\n U, V = jnp.meshgrid(u_vals, v_vals)\n\n UV = jnp.stack(\n [U.reshape(-1), V.reshape(-1), jnp.zeros((resolution * resolution,))], axis=1\n )\n dS = neuron_model.dynamics(0, UV, {})\n dU = dS[:, 0].reshape(U.shape)\n dV = dS[:, 1].reshape(V.shape)\n\n return U, V, dU, dV\n\n\ndef plot_vf(ax, neuron_model, u_range, v_range):\n import numpy as np\n\n u_sparse = jnp.linspace(u_range[0], u_range[1], 30)\n v_sparse = jnp.linspace(v_range[0], v_range[1], 30)\n\n Us, Vs = jnp.meshgrid(u_sparse, v_sparse)\n\n U, V, dU, dV = compute_nullclines(neuron_model, u_range, v_range, 200)\n\n UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((30 * 30,))], axis=1)\n dS = neuron_model.dynamics(0, UVs, {})\n dUs = dS[:, 0].reshape(Us.shape)\n dVs = dS[:, 1].reshape(Vs.shape)\n\n # Normalize for visualization\n magnitude = np.sqrt(dUs**2 + dVs**2)\n magnitude[magnitude == 0] = 1\n dUs_norm = dUs / magnitude\n dVs_norm = dVs / magnitude\n\n # Nullclines\n ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\")\n ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\")\n\n # Vector field\n ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6)\n\n ax.set_xlabel(\"v\")\n ax.set_ylabel(\"w\")\n ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\")\n ax.set_xlim(u_range)\n ax.set_ylim(v_range)\n ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n def compute_nullclines(neuron_model, u_range, v_range, resolution=200): \"\"\" Compute nullclines du/dt = 0 (u-nullcline) dv/dt = 0 (v-nullcline) \"\"\" u_vals = jnp.linspace(u_range[0], u_range[1], resolution) v_vals = jnp.linspace(v_range[0], v_range[1], resolution) U, V = jnp.meshgrid(u_vals, v_vals) UV = jnp.stack( [U.reshape(-1), V.reshape(-1), jnp.zeros((resolution * resolution,))], axis=1 ) dS = neuron_model.dynamics(0, UV, {}) dU = dS[:, 0].reshape(U.shape) dV = dS[:, 1].reshape(V.shape) return U, V, dU, dV def plot_vf(ax, neuron_model, u_range, v_range): import numpy as np u_sparse = jnp.linspace(u_range[0], u_range[1], 30) v_sparse = jnp.linspace(v_range[0], v_range[1], 30) Us, Vs = jnp.meshgrid(u_sparse, v_sparse) U, V, dU, dV = compute_nullclines(neuron_model, u_range, v_range, 200) UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((30 * 30,))], axis=1) dS = neuron_model.dynamics(0, UVs, {}) dUs = dS[:, 0].reshape(Us.shape) dVs = dS[:, 1].reshape(Vs.shape) # Normalize for visualization magnitude = np.sqrt(dUs**2 + dVs**2) magnitude[magnitude == 0] = 1 dUs_norm = dUs / magnitude dVs_norm = dVs / magnitude # Nullclines ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\") ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\") # Vector field ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6) ax.set_xlabel(\"v\") ax.set_ylabel(\"w\") ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\") ax.set_xlim(u_range) ax.set_ylim(v_range) ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3) ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3) In\u00a0[27]: Copied!
with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[:, 0, 0])\n ax[0].plot(comp_times, state[:, 0, 1], \"--\")\n # ax[0].plot(comp_times, state[0, :, 2], \"-.\")\n ax[0].set_xlabel(\"Time (ms)\")\n ax[0].legend([\"v\", \"vslow\", \"syn\"])\n\n plot_vf(ax[1], neuron_model, [-2, 2], [-2, 2])\n\n ax[1].plot(state[:, 0, 0], state[:, 0, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\")\n ax[1].set_xlabel(\"v\")\n ax[1].set_ylabel(\"v fast\")\n ax[1].legend()\n plt.show()\n
with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[:, 0, 0]) ax[0].plot(comp_times, state[:, 0, 1], \"--\") # ax[0].plot(comp_times, state[0, :, 2], \"-.\") ax[0].set_xlabel(\"Time (ms)\") ax[0].legend([\"v\", \"vslow\", \"syn\"]) plot_vf(ax[1], neuron_model, [-2, 2], [-2, 2]) ax[1].plot(state[:, 0, 0], state[:, 0, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\") ax[1].set_xlabel(\"v\") ax[1].set_ylabel(\"v fast\") ax[1].legend() plt.show() In\u00a0[\u00a0]: Copied!
\n
In\u00a0[\u00a0]: Copied!
\n
"},{"location":"neuron_models/lif/","title":"LIF","text":""},{"location":"neuron_models/lif/#circuit-design","title":"Circuit Design","text":"
W/L = 4/3
"},{"location":"neuron_models/lif/#circuit-simulation","title":"Circuit Simulation","text":"
Fig.1 The dynamics of leaky integrate and fire neuron. The grey signal is the input spikes, the yellow signal is the membrane potential and the dark blue is the output spikes from the neuron.
"},{"location":"neuron_models/lif/#referennces","title":"Referennces","text":"
- Sourikopoulos I, Hedayat S, Loyez C, Danneville F, Hoel V, Mercier E and Cappy A (2017) A 4-fJ/Spike Artificial Neuron in 65 nm CMOS Technology. Front. Neurosci. 11:123. doi: 10.3389/fnins.2017.00123
"},{"location":"neuron_models/snowball/","title":"Snowball","text":""},{"location":"neuron_models/snowball/#circuit-description","title":"Circuit description","text":"
The circuit implemented for exponential integrate and fire neuron has been used from [1]. Part (a) in Fig.2 in [1] implements the exponential integrate and fire neuron. The neuron receives input currents using the input DPI filter [2]. This input current is integrated on the node Vmem by the membrane capacitance. The membrane potential leaks in the absence of an input spike which can be set by the bias Vleak. The Vmem potential node is connected to a cascoded source follower formed by the P14-15 and N5-6. A threshold voltage of the neuron can be set by the bias Vthr which is compared to the membrane potential. When the membrane potential is just near the threshold voltage, it starts the positive feedback block which exponentially increases membrane potential and causes the neuron to spike. As the neuron spikes, the membrane potential gets reset to ground and the refractory bias helps to stop the neuron from spiking during the refractory period as similar to a biological neuron. The circuit implemented for this experiment does not exercise either adaptability or needs a pulse extender as implemented in [1]. The Vdd used in the simulation is 1V. The neuron receives 5nA input pulses with a pulse width of 100\u03bcs.
Input current mirror W/l = 0.2 All other transistors W/L = 4/3
"},{"location":"neuron_models/snowball/#circuit-simulation","title":"Circuit Simulation","text":"
Fig.1 The dynamics of Exponential integrate and fire neuron. The light blue signal is the input spikes, the yellow signal is the membrane potential and the dark blue is the output spikes from the neuron.
"},{"location":"neuron_models/snowball/#references","title":"References","text":"
- Rubino, Arianna, Melika Payvand, and Giacomo Indiveri. \"Ultra-low power silicon neuron circuit for extreme-edge neuromorphic intelligence.\" 2019 26th IEEE International Conference on Electronics, Circuits and Systems (ICECS). IEEE, 2019.
- Bartolozzi, Chiara, Srinjoy Mitra, and Giacomo Indiveri. \"An ultra low power current-mode filter for neuromorphic systems and biomedical signal processing.\" 2006 IEEE Biomedical Circuits and Systems Conference. IEEE, 2006.
"},{"location":"neuron_models/wererabbit/","title":"WereRabbit","text":"
The wererabbit neuron model is a two coupled oscillator that follows a predator- prey dynamic with a switching in the diagonal of the phaseplane. When the z in equation 1c represents the \u201cmoon phase\u201d, when ever it cross that threshold, the rabbit (prey) becomes the predator.
"},{"location":"neuron_models/wererabbit/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{du}{dt} &= z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - I_a \\\\ C\\frac{dv}{dt} &= -z I_{bias} + I_{n0} e^{\\kappa u / U_t} [z + 26e^{-2} (0.5 - v) z] - I_a \\\\ z &= tanh(\\rho (u-v))\\\\ I_a &= \\sigma I_{bias} \\\\ \\end{align} \\] Parameter Symbol Definition Value Capacitance C Circuit capacitance \\(0.1\\,pF\\) Bias current \\(I_{bias}\\) DC bias current for the fixpoint location \\(100\\,pA\\) Leakage current \\(I_{n0}\\) Transistor leakage current \\(0.129\\,pA\\) Subthreshold slope \\(\\kappa\\) Transistor subthreshold slope factor \\(0.39\\) Thermal voltage \\(U_t\\) Thermal voltage at room temperature \\(25\\,mV\\) Bias scale \\(\\sigma\\) Scaling factor for the distance between fixpoints \\(0.6\\) Steepness \\(\\rho\\) Tanh steepness for the moonphase \\(5\\)s"},{"location":"neuron_models/wererabbit/#abstraction","title":"Abstraction","text":"
To simplify the analysis of the model for simulation purposes, we can introduce a dimensionless time variable \\(\\tau=tI_{bias}/C\\), transforming the derivate of the equations in \\(\\frac{d}{dt}=\\frac{I_{bias}}{C}\\frac{d}{d\\tau}\\). Substituting this time transformation on equation~\\ref{eq:wererabbit:circ}
\\[ \\begin{equation} C\\frac{I_{bias}}{C}\\frac{du}{d\\tau} = z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma I_{bias} \\end{equation} \\]
And dividing by \\(I_{bias}\\) on both sides:
\\[ \\begin{equation} \\frac{du}{d\\tau} = z - \\frac{I_{n0}}{I_{bias}} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma \\end{equation} \\]
Obtaining the following set of equations:
\\[ \\begin{align} z &= tanh(\\kappa (u-v)) \\\\ \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{dt} &= -z - z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\] Parameter Definition Value \\(\\tau\\) \\(tI_{bias}/C\\) -- \\(\\alpha\\) \\(I_{n0}/I_{bias}\\) \\(0.0129\\) \\(\\beta\\) \\(\\kappa/U_t\\) 15.6 \\(\\gamma\\) -- \\(26e^{-2}\\) \\(\\rho\\) Tanh steepness for the moonphase 5 \\(\\sigma\\) Scaling factor for the distance between fixpoints 0.6"},{"location":"neuron_models/wererabbit/#examples","title":"Examples","text":"
See the following interactive notebook for a practical example:
- Basic Usage Example - Introduction to the WereRabbit model
"},{"location":"neuron_models/wererabbit/wererabbit/","title":"Basic example","text":"In\u00a0[10]: Copied!
import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nfrom felice.neuron_models import WereRabbit\n\njax.config.update(\"jax_enable_x64\", True)\n
import diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from felice.neuron_models import WereRabbit jax.config.update(\"jax_enable_x64\", True) In\u00a0[11]: Copied!
key = jrand.key(0)\nmax_time = 40\n\nmodel = WereRabbit(dtype=jnp.float64)\n\n\ndef state_at_t(comp_times):\n sol = dfx.diffeqsolve(\n terms=dfx.ODETerm(model.dynamics),\n solver=dfx.Tsit5(),\n t0=0.0,\n t1=max_time,\n dt0=1e-3,\n y0=model.init_state(1)\n + jrand.uniform(key, shape=(1, 2), minval=0.1, maxval=0.5),\n saveat=dfx.SaveAt(ts=comp_times),\n max_steps=100000,\n )\n\n return sol.ts, sol.ys\n
key = jrand.key(0) max_time = 40 model = WereRabbit(dtype=jnp.float64) def state_at_t(comp_times): sol = dfx.diffeqsolve( terms=dfx.ODETerm(model.dynamics), solver=dfx.Tsit5(), t0=0.0, t1=max_time, dt0=1e-3, y0=model.init_state(1) + jrand.uniform(key, shape=(1, 2), minval=0.1, maxval=0.5), saveat=dfx.SaveAt(ts=comp_times), max_steps=100000, ) return sol.ts, sol.ys In\u00a0[12]: Copied!
comp_times = jnp.linspace(0.0, max_time, 2000)\n_, state = state_at_t(comp_times)\n
comp_times = jnp.linspace(0.0, max_time, 2000) _, state = state_at_t(comp_times) In\u00a0[13]: Copied!
def compute_nullclines(snn, u_range, v_range, resolution=200):\n \"\"\"\n Compute nullclines\n du/dt = 0 (u-nullcline)\n dv/dt = 0 (v-nullcline)\n \"\"\"\n u_vals = jnp.linspace(u_range[0], u_range[1], resolution)\n v_vals = jnp.linspace(v_range[0], v_range[1], resolution)\n U, V = jnp.meshgrid(u_vals, v_vals)\n\n UV = jnp.stack(\n [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1\n )\n dS = snn.vector_field(UV)\n dU = dS[:, 0].reshape(U.shape)\n dV = dS[:, 1].reshape(V.shape)\n\n return U, V, dU, dV\n\n\ndef plot_vf(ax, snn, u_range, v_range):\n import numpy as np\n\n u_sparse = jnp.linspace(u_range[0], u_range[1], 20)\n v_sparse = jnp.linspace(v_range[0], v_range[1], 20)\n\n Us, Vs = jnp.meshgrid(u_sparse, v_sparse)\n\n U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200)\n\n UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1)\n dS = snn.vector_field(UVs)\n dUs = dS[:, 0].reshape(Us.shape)\n dVs = dS[:, 1].reshape(Vs.shape)\n\n # Normalize for visualization\n magnitude = np.sqrt(dUs**2 + dVs**2)\n magnitude[magnitude == 0] = 1\n dUs_norm = dUs / magnitude\n dVs_norm = dVs / magnitude\n\n # Nullclines\n ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\")\n ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\")\n\n # Vector field\n ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6)\n\n ax.set_xlabel(\"u (Prey)\")\n ax.set_ylabel(\"v (Predator)\")\n ax.set_title(\"Wererabbit: Phase Portrait\")\n ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\")\n ax.set_xlim(u_range)\n ax.set_ylim(v_range)\n ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n
def compute_nullclines(snn, u_range, v_range, resolution=200): \"\"\" Compute nullclines du/dt = 0 (u-nullcline) dv/dt = 0 (v-nullcline) \"\"\" u_vals = jnp.linspace(u_range[0], u_range[1], resolution) v_vals = jnp.linspace(v_range[0], v_range[1], resolution) U, V = jnp.meshgrid(u_vals, v_vals) UV = jnp.stack( [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1 ) dS = snn.vector_field(UV) dU = dS[:, 0].reshape(U.shape) dV = dS[:, 1].reshape(V.shape) return U, V, dU, dV def plot_vf(ax, snn, u_range, v_range): import numpy as np u_sparse = jnp.linspace(u_range[0], u_range[1], 20) v_sparse = jnp.linspace(v_range[0], v_range[1], 20) Us, Vs = jnp.meshgrid(u_sparse, v_sparse) U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200) UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1) dS = snn.vector_field(UVs) dUs = dS[:, 0].reshape(Us.shape) dVs = dS[:, 1].reshape(Vs.shape) # Normalize for visualization magnitude = np.sqrt(dUs**2 + dVs**2) magnitude[magnitude == 0] = 1 dUs_norm = dUs / magnitude dVs_norm = dVs / magnitude # Nullclines ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\") ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\") # Vector field ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6) ax.set_xlabel(\"u (Prey)\") ax.set_ylabel(\"v (Predator)\") ax.set_title(\"Wererabbit: Phase Portrait\") ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\") ax.set_xlim(u_range) ax.set_ylim(v_range) ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3) ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3) In\u00a0[14]: Copied!
with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[:, 0, 0], label=\"x1\")\n ax[0].plot(comp_times, state[:, 0, 1], label=\"x2\")\n ax[0].legend([\"x1\", \"x2\"])\n\n plot_vf(ax[1], model, [-0.2, 0.5], [-0.2, 0.5])\n ax[1].plot(state[:, 0, 0], state[:, 0, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[-1, 0, 0], state[-1, 0, 1], \".\", label=\"end\")\n ax[1].legend()\n plt.show()\n
with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[:, 0, 0], label=\"x1\") ax[0].plot(comp_times, state[:, 0, 1], label=\"x2\") ax[0].legend([\"x1\", \"x2\"]) plot_vf(ax[1], model, [-0.2, 0.5], [-0.2, 0.5]) ax[1].plot(state[:, 0, 0], state[:, 0, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[-1, 0, 0], state[-1, 0, 1], \".\", label=\"end\") ax[1].legend() plt.show() In\u00a0[\u00a0]: Copied!
\n
"}]}
\ No newline at end of file