Felice
This project provides JAX implementations of the neuron models in Felice.
Overview
The framework is built on top of diffrax and leverages JAX's automatic differentiation for efficient simulation and training of analogue models.
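Here is a minimal sketch of differentiating through a simulation, using only JAX. A toy leaky membrane dv/dt = -v/tau (not one of Felice's models) is integrated with forward Euler inside `jax.lax.scan`, and `jax.grad` returns the gradient of a final-state loss with respect to `tau`. The names `simulate` and `loss` are illustrative. diffrax solvers are differentiable in the same spirit, so a loss computed from a `diffeqsolve` result can also be passed to `jax.grad`.

```python
import jax
import jax.numpy as jnp


def simulate(tau, v0=1.0, dt=1e-2, n_steps=500):
    """Forward-Euler integration of the toy leaky membrane dv/dt = -v / tau.

    Returns the voltage trajectory. Using lax.scan keeps the whole loop a
    single traced JAX computation, so it can be differentiated end to end.
    """

    def step(v, _):
        v_next = v + dt * (-v / tau)
        return v_next, v_next

    init = jnp.asarray(v0, dtype=jnp.float32)
    _, trajectory = jax.lax.scan(step, init, None, length=n_steps)
    return trajectory


def loss(tau, target=0.2):
    """Squared error between the final voltage and a target value."""
    return (simulate(tau)[-1] - target) ** 2


# Reverse-mode gradient of the loss with respect to the time constant.
grad_tau = jax.grad(loss)(1.0)
```

Increasing `tau` slows the decay and moves the final voltage toward the target, so the gradient is negative. A training loop would follow this gradient to update the parameter.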
Key Features
- Delay learning
- Non-linear neuron models:
  - WereRabbit Neuron Model: a dual-state oscillatory neuron model with bistable dynamics
  - FitzHugh-Nagumo (FHN) Neuron Model
  - Snowball Neuron Model
📦 Installation
Felice uses uv for dependency management. To install:
CUDA Support (Optional)
For GPU acceleration with CUDA 13:
See the examples directory for more detailed usage examples.
Neuron Models
Felice implements several non-linear neuron models for spiking neural networks.
Available Models
| Model | Type | Key Features |
|---|---|---|
| WereRabbit | Dual-state oscillatory | Bistable dynamics, predator-prey |
| FitzHugh-Nagumo | ... | ... |
| Snowball | Exponential Integrate-and-Fire neuron model | ... |
| LIF | Leaky Integrate-and-Fire neuron model | ... |
WereRabbit
The WereRabbit neuron model consists of two coupled oscillators that follow predator-prey dynamics, with a switch along the diagonal of the phase plane. The variable \(z\) in equation 1c represents the “moon phase”: whenever it crosses the threshold, the rabbit (prey) becomes the predator.
Circuit equation
| Parameter | Symbol | Definition | Value |
|---|---|---|---|
| Capacitance | C | Circuit capacitance | \(0.1\,pF\) |
| Bias current | \(I_{bias}\) | DC bias current for the fixpoint location | \(100\,pA\) |
| Leakage current | \(I_{n0}\) | Transistor leakage current | \(0.129\,pA\) |
| Subthreshold slope | \(\kappa\) | Transistor subthreshold slope factor | \(0.39\) |
| Thermal voltage | \(U_t\) | Thermal voltage at room temperature | \(25\,mV\) |
| Bias scale | \(\sigma\) | Scaling factor for the distance between fixpoints | \(0.6\) |
| Steepness | \(\rho\) | Tanh steepness for the moon phase | \(5\) |
Abstraction
To simplify the analysis of the model for simulation purposes, we introduce a rescaled time variable \(\tau=tI_{bias}/C\), so that time derivatives transform as \(\frac{d}{dt}=\frac{I_{bias}}{C}\frac{d}{d\tau}\). Substituting this transformation into the circuit equation and dividing both sides by \(I_{bias}\) gives the following set of equations:
| Parameter | Definition | Value |
|---|---|---|
| \(\tau\) | \(tI_{bias}/C\) | -- |
| \(\alpha\) | \(I_{n0}/I_{bias}\) | \(0.0129\) |
| \(\beta\) | \(\kappa/U_t\) | 15.6 |
| \(\gamma\) | Coupling parameter | \(0.26\) |
| \(\rho\) | Tanh steepness for the moon phase | 5 |
| \(\sigma\) | Scaling factor for the distance between fixpoints | 0.6 |
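The sketch below makes the rescaling concrete. It plugs the circuit values from the first table into \(\beta = \kappa/U_t\) and into the time conversion \(t = \tau C/I_{bias}\). It then checks the chain rule \(\frac{d}{dt}=\frac{I_{bias}}{C}\frac{d}{d\tau}\) numerically with `jax.grad`, using an arbitrary test trajectory \(x(\tau) = \sin\tau\). The helper names are illustrative and not part of Felice.

```python
import jax
import jax.numpy as jnp

# Circuit values from the parameter table (SI units).
C = 0.1e-12       # capacitance (F)
I_BIAS = 100e-12  # bias current (A)
KAPPA = 0.39      # subthreshold slope factor
U_T = 25e-3       # thermal voltage (V)

# Exponential slope of the rescaled model.
beta = KAPPA / U_T

# tau = t * I_bias / C, hence t = tau * C / I_bias.
tau_to_t = C / I_BIAS


def x_of_tau(tau):
    """Arbitrary smooth test trajectory written in rescaled time."""
    return jnp.sin(tau)


def x_of_t(t):
    """The same trajectory expressed in physical time."""
    return x_of_tau(t * I_BIAS / C)


# Chain rule: dx/dt should equal (I_bias / C) * dx/dtau at tau = t * I_bias / C.
t_check = 2e-3
dx_dt = jax.grad(x_of_t)(t_check)
dx_dtau_scaled = (I_BIAS / C) * jax.grad(x_of_tau)(t_check * I_BIAS / C)
```

With these values, \(C/I_{bias}\) evaluates to \(10^{-3}\), so an interval of 40 in \(\tau\) corresponds to roughly 40 ms of circuit time. \(\beta\) evaluates to 15.6, matching the table.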
Examples
See the following interactive notebook for a practical example:
- Basic Usage Example - Introduction to the WereRabbit model
```python
import diffrax as dfx
import jax
import jax.numpy as jnp
import jax.random as jrand
import matplotlib as mpl
import matplotlib.pyplot as plt

from felice.neuron_models import WereRabbit

# Enable double precision before any arrays are created.
jax.config.update("jax_enable_x64", True)

key = jrand.key(0)
max_time = 40
model = WereRabbit(dtype=jnp.float64)


def state_at_t(comp_times):
    """Integrate one neuron from a randomly perturbed initial state."""
    sol = dfx.diffeqsolve(
        terms=dfx.ODETerm(model.dynamics),
        solver=dfx.Tsit5(),
        t0=0.0,
        t1=max_time,
        dt0=1e-3,
        y0=model.init_state(1)
        + jrand.uniform(key, shape=(1, 2), minval=0.1, maxval=0.5),
        saveat=dfx.SaveAt(ts=comp_times),
        max_steps=100000,
    )
    return sol.ts, sol.ys


comp_times = jnp.linspace(0.0, max_time, 2000)
_, state = state_at_t(comp_times)  # shape: (time, neurons, state)


def compute_nullclines(snn, u_range, v_range, resolution=200):
    """
    Evaluate the vector field on a dense grid. The zero-level contours of its
    two components are the nullclines:
    du/dt = 0 (u-nullcline)
    dv/dt = 0 (v-nullcline)
    """
    u_vals = jnp.linspace(u_range[0], u_range[1], resolution)
    v_vals = jnp.linspace(v_range[0], v_range[1], resolution)
    U, V = jnp.meshgrid(u_vals, v_vals)
    UV = jnp.stack(
        [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1
    )
    dS = snn.vector_field(UV)
    dU = dS[:, 0].reshape(U.shape)
    dV = dS[:, 1].reshape(V.shape)
    return U, V, dU, dV


def plot_vf(ax, snn, u_range, v_range):
    """Draw the nullclines and a normalized vector field on a phase-plane axis."""
    u_sparse = jnp.linspace(u_range[0], u_range[1], 20)
    v_sparse = jnp.linspace(v_range[0], v_range[1], 20)
    Us, Vs = jnp.meshgrid(u_sparse, v_sparse)
    U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200)
    UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1)
    dS = snn.vector_field(UVs)
    dUs = dS[:, 0].reshape(Us.shape)
    dVs = dS[:, 1].reshape(Vs.shape)
    # Normalize arrow lengths; avoid division by zero at fixpoints
    magnitude = jnp.sqrt(dUs**2 + dVs**2)
    magnitude = jnp.where(magnitude == 0, 1.0, magnitude)
    dUs_norm = dUs / magnitude
    dVs_norm = dVs / magnitude
    # Nullclines
    ax.contour(U, V, dU, levels=[0], colors="blue", linewidths=1, linestyles="-")
    ax.contour(U, V, dV, levels=[0], colors="red", linewidths=1, linestyles="-")
    # Vector field, colored by magnitude
    ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap="viridis", alpha=0.6)
    ax.set_xlabel("u (Predator)")
    ax.set_ylabel("v (Prey)")
    ax.set_title("WereRabbit: Phase Portrait")
    ax.legend(["u-nullcline (du/dt=0)", "v-nullcline (dv/dt=0)"], loc="upper right")
    ax.set_xlim(u_range)
    ax.set_ylim(v_range)
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
    ax.axvline(x=0, color="gray", linestyle="--", alpha=0.3)


# "boilerplot.ieeetran" is a custom Matplotlib style; use "default" if it is
# not installed.
with mpl.style.context("boilerplot.ieeetran"):
    fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)
    # Time traces of both state variables
    ax[0].plot(comp_times, state[:, 0, 0], label="u")
    ax[0].plot(comp_times, state[:, 0, 1], label="v")
    ax[0].legend()
    # Phase portrait with the simulated trajectory
    plot_vf(ax[1], model, [-0.2, 0.5], [-0.2, 0.5])
    ax[1].plot(state[:, 0, 0], state[:, 0, 1])
    ax[1].plot(state[0, 0, 0], state[0, 0, 1], ".", label="start")
    ax[1].plot(state[-1, 0, 0], state[-1, 0, 1], ".", label="end")
    ax[1].legend()
    plt.show()
```
FitzHugh-Nagumo
Circuit equation
where the currents are:
- \(I_{passive} = g_{max}(v - E_{rev})\)
- \(I_{fast} = a_{fast} \tanh(v - v_{off,fast})\)
- \(I_{slow} = a_{slow} \tanh(v_{slow} - v_{off,slow})\)
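The three currents can be sketched in plain JAX, together with the three I-V curves that FHNRS exposes (IV_inst, IV_fast and IV_slow in the API reference). This illustrates the idea under stated assumptions and is not Felice's code:

- The currents are assumed to add up to the total membrane current.
- The parameter values are the FHNRS defaults listed in the API reference.
- The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# FHNRS default parameters (from the API reference).
GMAX_PASIVE, EREV_PASIVE = 1.0, 0.0
A_FAST, VOFF_FAST = -2.0, 0.0
A_SLOW, VOFF_SLOW = 2.0, 0.0


def i_passive(v):
    return GMAX_PASIVE * (v - EREV_PASIVE)


def i_fast(v_fast):
    return A_FAST * jnp.tanh(v_fast - VOFF_FAST)


def i_slow(v_slow):
    return A_SLOW * jnp.tanh(v_slow - VOFF_SLOW)


def iv_inst(v, v_rest=0.0):
    """Instantaneous curve: fast and slow currents frozen at rest."""
    return i_passive(v) + i_fast(v_rest) + i_slow(v_rest)


def iv_fast(v, v_rest=0.0):
    """Fast curve: fast current follows v, slow current frozen at rest."""
    return i_passive(v) + i_fast(v) + i_slow(v_rest)


def iv_slow(v, v_rest=0.0):
    """Steady-state curve: every current follows v (v_rest is unused)."""
    return i_passive(v) + i_fast(v) + i_slow(v)


v_range = jnp.linspace(-3.0, 3.0, 61)
curves = {
    name: jax.vmap(f)(v_range)
    for name, f in [("inst", iv_inst), ("fast", iv_fast), ("slow", iv_slow)]
}
```

With these defaults, the fast curve has slope \(g_{max} + a_{fast} = -1\) at rest. This local negative conductance is what, in the Ribar-Sepulchre picture, makes the neuron excitable. The slow curve has slope \(+1\). Because \(a_{fast} = -a_{slow}\) with equal offsets, the two tanh terms cancel, and the steady-state curve reduces to the passive line.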
Examples
See the following interactive notebook for a practical example:
- Basic Usage Example - Introduction to the FitzHugh-Nagumo model
```python
import diffrax as dfx
import jax
import jax.numpy as jnp
import jax.random as jrand
import matplotlib as mpl
import matplotlib.pyplot as plt

from felice.neuron_models import FHNRS

key = jrand.key(0)
max_time = 200
neuron_model = FHNRS(
    gmax_pasive=2.0,
    Erev_pasive=0.0,
    a_fast=-2.0,
    voff_fast=0.0,
    tau_fast=0.0,
    a_slow=0.5,
    voff_slow=1.0,
    tau_slow=50.0,
    vthr=jnp.inf,  # never trigger a spike event: observe the free dynamics
)


def state_at_t(comp_times):
    """Integrate one neuron from a randomly perturbed initial state."""
    sol = dfx.diffeqsolve(
        terms=dfx.ODETerm(neuron_model.dynamics),
        solver=dfx.Tsit5(),
        t0=0.0,
        t1=max_time,
        dt0=1e-3,
        y0=neuron_model.init_state(1)
        + jrand.uniform(key, shape=(1, 3), minval=0.1, maxval=0.5),
        saveat=dfx.SaveAt(ts=comp_times),
        # 200 / 1e-3 = 200000 fixed steps; leave some headroom.
        max_steps=250000,
    )
    return sol.ts, sol.ys


# I-V curves on the instantaneous, fast and slow timescales
v_range = jnp.arange(-3.1, 3, 0.1)
VI_inst = jax.vmap(neuron_model.IV_inst)(v_range)
VI_fast = jax.vmap(neuron_model.IV_fast)(v_range)
VI_slow = jax.vmap(neuron_model.IV_slow)(v_range)

# "boilerplot.ieeetran" is a custom Matplotlib style; use "default" if it is
# not installed.
with mpl.style.context("boilerplot.ieeetran"):
    fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True)
    ax[0].plot(v_range, VI_inst)
    ax[0].set_title("Instantaneous")
    ax[1].plot(v_range, VI_fast)
    ax[1].set_title("Fast")
    ax[2].plot(v_range, VI_slow)
    ax[2].set_title("Slow")
    plt.show()

comp_times = jnp.linspace(0.0, max_time, 500)
_, state = state_at_t(comp_times)  # shape: (time, neurons, 3) = [v, v_slow, i_app]


def compute_nullclines(neuron_model, u_range, v_range, resolution=200):
    """
    Evaluate the vector field on a dense grid with i_app = 0. The zero-level
    contours of its first two components are the nullclines:
    dv/dt = 0 (v-nullcline)
    dv_slow/dt = 0 (v_slow-nullcline)
    """
    u_vals = jnp.linspace(u_range[0], u_range[1], resolution)
    v_vals = jnp.linspace(v_range[0], v_range[1], resolution)
    U, V = jnp.meshgrid(u_vals, v_vals)
    UV = jnp.stack(
        [U.reshape(-1), V.reshape(-1), jnp.zeros((resolution * resolution,))], axis=1
    )
    dS = neuron_model.dynamics(0, UV, {})
    dU = dS[:, 0].reshape(U.shape)
    dV = dS[:, 1].reshape(V.shape)
    return U, V, dU, dV


def plot_vf(ax, neuron_model, u_range, v_range):
    """Draw the nullclines and a normalized vector field on a phase-plane axis."""
    u_sparse = jnp.linspace(u_range[0], u_range[1], 30)
    v_sparse = jnp.linspace(v_range[0], v_range[1], 30)
    Us, Vs = jnp.meshgrid(u_sparse, v_sparse)
    U, V, dU, dV = compute_nullclines(neuron_model, u_range, v_range, 200)
    # Use i_app = 0 here as well, so arrows and nullclines describe the same field
    UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.zeros((30 * 30,))], axis=1)
    dS = neuron_model.dynamics(0, UVs, {})
    dUs = dS[:, 0].reshape(Us.shape)
    dVs = dS[:, 1].reshape(Vs.shape)
    # Normalize arrow lengths; avoid division by zero at fixpoints
    magnitude = jnp.sqrt(dUs**2 + dVs**2)
    magnitude = jnp.where(magnitude == 0, 1.0, magnitude)
    dUs_norm = dUs / magnitude
    dVs_norm = dVs / magnitude
    # Nullclines
    ax.contour(U, V, dU, levels=[0], colors="blue", linewidths=1, linestyles="-")
    ax.contour(U, V, dV, levels=[0], colors="red", linewidths=1, linestyles="-")
    # Vector field, colored by magnitude
    ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap="viridis", alpha=0.6)
    ax.set_xlabel("v")
    ax.set_ylabel("v_slow")
    ax.legend(
        ["v-nullcline (dv/dt=0)", "v_slow-nullcline (dv_slow/dt=0)"], loc="upper right"
    )
    ax.set_xlim(u_range)
    ax.set_ylim(v_range)
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
    ax.axvline(x=0, color="gray", linestyle="--", alpha=0.3)


with mpl.style.context("boilerplot.ieeetran"):
    fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)
    # Time traces of the membrane voltage and the slow variable
    ax[0].plot(comp_times, state[:, 0, 0])
    ax[0].plot(comp_times, state[:, 0, 1], "--")
    ax[0].set_xlabel("Time (ms)")
    ax[0].legend(["v", "v_slow"])
    # Phase portrait with the simulated trajectory
    plot_vf(ax[1], neuron_model, [-2, 2], [-2, 2])
    ax[1].plot(state[:, 0, 0], state[:, 0, 1])
    ax[1].plot(state[0, 0, 0], state[0, 0, 1], ".", label="start")
    ax[1].plot(state[-1, 0, 0], state[-1, 0, 1], ".", label="end")
    ax[1].set_xlabel("v")
    ax[1].set_ylabel("v_slow")
    ax[1].legend()
    plt.show()
```
Snowball
Circuit description
The exponential integrate-and-fire neuron circuit is taken from [1], where part (a) of Fig. 2 implements the neuron. The neuron receives input currents through the differential pair integrator (DPI) input filter [2]. The membrane capacitance integrates this input current on the node Vmem. In the absence of input spikes, the membrane potential leaks at a rate set by the bias Vleak. The Vmem node is connected to a cascoded source follower formed by transistors P14-P15 and N5-N6. The bias Vthr sets the neuron's threshold voltage, to which the membrane potential is compared. When the membrane potential approaches the threshold voltage, it activates the positive-feedback block. This block increases the membrane potential exponentially and causes the neuron to spike. After a spike, the membrane potential is reset to ground. The refractory bias then prevents the neuron from spiking again during the refractory period, similar to a biological neuron. Unlike the circuit in [1], the circuit used for this experiment includes neither adaptation nor a pulse extender. The simulation uses a supply voltage Vdd of 1 V, and the neuron receives 5 nA input pulses with a pulse width of 100 μs.
- Input current mirror: W/L = 0.2
- All other transistors: W/L = 4/3
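The circuit described above combines a leak, exponential positive feedback near threshold, a reset to ground and a refractory period. This is the mechanism of the textbook exponential integrate-and-fire model, \(\tau_m \dot V = -V + \Delta_T e^{(V - V_T)/\Delta_T} + I\). The sketch below integrates a dimensionless version of it with forward Euler in JAX. It is a behavioral abstraction, not the Snowball implementation or the transistor circuit. All constants are illustrative assumptions, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Illustrative dimensionless constants (assumptions, not circuit values).
TAU_M = 10.0    # membrane time constant
DELTA_T = 0.1   # sharpness of the exponential positive feedback
V_T = 0.5       # soft threshold (plays the role of the bias Vthr)
V_PEAK = 1.0    # voltage at which a spike is registered
V_RESET = 0.0   # reset to ground after a spike
T_REF = 2.0     # refractory period
DT = 0.01       # Euler step


def eif_step(carry, i_in):
    """One Euler step of tau_m dV/dt = -V + DELTA_T * exp((V - V_T) / DELTA_T) + I."""
    v, refractory_left = carry
    dv = (-v + DELTA_T * jnp.exp((v - V_T) / DELTA_T) + i_in) / TAU_M
    # Hold V at the reset value while the neuron is refractory.
    v = jnp.where(refractory_left > 0, V_RESET, v + DT * dv)
    spiked = v >= V_PEAK
    v = jnp.where(spiked, V_RESET, v)
    refractory_left = jnp.where(spiked, T_REF, jnp.maximum(refractory_left - DT, 0.0))
    return (v, refractory_left), (v, spiked)


def simulate_eif(i_in):
    """Run the neuron over an input-current trace sampled every DT."""
    init = (jnp.float32(V_RESET), jnp.float32(0.0))
    _, (v_trace, spikes) = jax.lax.scan(eif_step, init, i_in)
    return v_trace, spikes


# A constant drive above rheobase produces regular spiking.
v_trace, spikes = simulate_eif(jnp.full((5000,), 1.0, dtype=jnp.float32))
```

Replacing the constant drive with a pulse train approximates the pulsed input used in the circuit simulation. With zero input, the membrane settles near rest and never spikes.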
Circuit Simulation
Fig. 1: Dynamics of the exponential integrate-and-fire neuron. The light blue trace shows the input spikes, the yellow trace the membrane potential, and the dark blue trace the output spikes of the neuron.
References
- Rubino, Arianna, Melika Payvand, and Giacomo Indiveri. "Ultra-low power silicon neuron circuit for extreme-edge neuromorphic intelligence." 2019 26th IEEE International Conference on Electronics, Circuits and Systems (ICECS). IEEE, 2019.
- Bartolozzi, Chiara, Srinjoy Mitra, and Giacomo Indiveri. "An ultra low power current-mode filter for neuromorphic systems and biomedical signal processing." 2006 IEEE Biomedical Circuits and Systems Conference. IEEE, 2006.
API Reference
API documentation for Felice.
Modules
- Neuron Models - Neuron model implementations
- Solver - Zero-clipping solver
- Datasets - Built-in datasets
Neuron Models
felice.neuron_models
Classes
Boomerang
Bases: Module
Source code in felice/neuron_models/boomerang.py
Functions
__init__(*, atol: float = 1e-06, rtol: float = 0.0001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 30.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)
Initialize the Boomerang neuron model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rtol | float | Relative tolerance for the spiking fixpoint calculation. | 0.0001 |
| atol | float | Absolute tolerance for the spiking fixpoint calculation. | 1e-06 |
| alpha | float | Current scaling parameter \(\alpha = I_{n0}/I_{bias}\) (default: 0.0129) | 0.0129 |
| beta | float | Exponential slope \(\beta = \kappa/U_t\) (default: 15.6) | 15.6 |
| gamma | float | Coupling parameter \(\gamma\) (default: 0.26) | 0.26 |
| rho | float | Steepness of the tanh function \(\rho\) (default: 30) | 30.0 |
| sigma | float | Fixpoint distance scaling \(\sigma\) (default: 0.6) | 0.6 |
| dtype | DTypeLike | Data type for arrays (default: float32). | float32 |
init_state(n_neurons: int) -> Float[Array, 'neurons 2']
Initialize the neuron state variables.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_neurons | int | Number of neurons to initialize. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'neurons 2'] | Initial state array of shape (neurons, 2) containing [u, v], where u and v are the predator/prey membrane voltages. |
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']
Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics:
- du/dt: Predator dynamics
- dv/dt: WerePrey dynamics
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | float | Current simulation time (unused but required by the framework). | required |
| y | Float[Array, 'neurons 2'] | State array of shape (neurons, 2) containing [u, v]. | required |
| args | Dict[str, Any] | Additional arguments (unused but required by the framework). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'neurons 2'] | Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt]. |
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']
Compute spike condition for event detection.
A spike is triggered when the system reaches a fixpoint.
INFO
has_spiked prevents the system from detecting a continuous spike while it remains at a fixpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | float | Current simulation time (unused but required by the framework). | required |
| y | Float[Array, 'neurons 2'] | State array of shape (neurons, 3) containing [u, v, has_spiked]. | required |
| **kwargs | Dict[str, Any] | Additional keyword arguments (unused). | {} |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' neurons'] | Spike condition array of shape (neurons,). Positive values indicate a spike. |
FHNRS
Bases: Module
FitzHugh-Nagumo neuron model
Model of the FitzHugh-Nagumo neuron, based on the hardware implementation proposed by Ribar and Sepulchre. This implementation uses dual-timescale dynamics with fast and slow currents to produce oscillatory spiking behavior.
The dynamics are governed by:
where the currents are:
- \(I_{passive} = g_{max}(v - E_{rev})\)
- \(I_{fast} = a_{fast} \tanh(v - v_{off,fast})\)
- \(I_{slow} = a_{slow} \tanh(v_{slow} - v_{off,slow})\)
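The sketch below assumes the form used by Ribar and Sepulchre:

- a membrane equation \(C\,\dot v = i_{app} - I_{passive} - I_{fast} - I_{slow}\);
- a first-order slow variable \(\tau_{slow}\,\dot v_{slow} = v - v_{slow}\);
- an input current that decays with time constant \(t_{syn}\).

Under these assumptions, the state update looks like the code below. The signs, the handling of tau_fast = 0 and the decay law for i_app are assumptions, not a transcription of fhn.py. FHNParams and fhn_rhs are hypothetical names.

```python
from typing import NamedTuple

import jax.numpy as jnp


class FHNParams(NamedTuple):
    """Hypothetical container mirroring the FHNRS defaults."""

    C: float = 1.0
    tsyn: float = 1.0
    gmax_pasive: float = 1.0
    Erev_pasive: float = 0.0
    a_fast: float = -2.0
    voff_fast: float = 0.0
    a_slow: float = 2.0
    voff_slow: float = 0.0
    tau_slow: float = 50.0


def fhn_rhs(y, p=FHNParams()):
    """Assumed vector field for y = [v, v_slow, i_app] of shape (neurons, 3).

    With tau_fast = 0 the fast variable is taken to equal v, so it is not
    a separate state.
    """
    v, v_slow, i_app = y[:, 0], y[:, 1], y[:, 2]
    i_passive = p.gmax_pasive * (v - p.Erev_pasive)
    i_fast = p.a_fast * jnp.tanh(v - p.voff_fast)
    i_slow = p.a_slow * jnp.tanh(v_slow - p.voff_slow)
    dv = (i_app - i_passive - i_fast - i_slow) / p.C
    dv_slow = (v - v_slow) / p.tau_slow
    di_app = -i_app / p.tsyn
    return jnp.stack([dv, dv_slow, di_app], axis=1)
```

At rest (all states zero), this vector field vanishes. A small positive voltage perturbation grows, because near rest the negative slope of the fast current outweighs the passive conductance.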
References
- Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.
Attributes:
| Name | Type | Description |
|---|---|---|
| reset_grad_preserve | | Preserve the gradient when the neuron spikes by doing a soft reset. |
| gmax_pasive | float | Maximal conductance of the passive current. |
| Erev_pasive | float | Reversal potential for the passive current. |
| a_fast | float | Amplitude parameter for the fast current dynamics. |
| voff_fast | float | Voltage offset for the fast current activation. |
| tau_fast | float | Time constant for the fast current (typically zero for instantaneous). |
| a_slow | float | Amplitude parameter for the slow current dynamics. |
| voff_slow | float | Voltage offset for the slow current activation. |
| tau_slow | float | Time constant for the slow recovery variable. |
| vthr | float | Voltage threshold for spike generation. |
| C | float | Membrane capacitance. |
| tsyn | float | Synaptic time constant for input current decay. |
| weights | float | Synaptic weight matrix of shape (in_plus_neurons, neurons). |
Source code in felice/neuron_models/fhn.py
Functions
__init__(*, tsyn: Union[int, float, jnp.ndarray] = 1.0, C: Union[int, float, jnp.ndarray] = 1.0, gmax_pasive: Union[int, float, jnp.ndarray] = 1.0, Erev_pasive: Union[int, float, jnp.ndarray] = 0.0, a_fast: Union[int, float, jnp.ndarray] = -2.0, voff_fast: Union[int, float, jnp.ndarray] = 0.0, tau_fast: Union[int, float, jnp.ndarray] = 0.0, a_slow: Union[int, float, jnp.ndarray] = 2.0, voff_slow: Union[int, float, jnp.ndarray] = 0.0, tau_slow: Union[int, float, jnp.ndarray] = 50.0, vthr: Union[int, float, jnp.ndarray] = 2.0, dtype: DTypeLike = jnp.float32)
Initialize the FitzHugh-Nagumo neuron model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tsyn | Union[int, float, ndarray] | Synaptic time constant for input current decay. Can be scalar or per-neuron array. | 1.0 |
| C | Union[int, float, ndarray] | Membrane capacitance. Can be scalar or per-neuron array. | 1.0 |
| gmax_pasive | Union[int, float, ndarray] | Maximal conductance of passive current. Can be scalar or per-neuron array. | 1.0 |
| Erev_pasive | Union[int, float, ndarray] | Reversal potential for passive current. Can be scalar or per-neuron array. | 0.0 |
| a_fast | Union[int, float, ndarray] | Amplitude of fast current. Can be scalar or per-neuron array. | -2.0 |
| voff_fast | Union[int, float, ndarray] | Voltage offset for fast current activation. Can be scalar or per-neuron array. | 0.0 |
| tau_fast | Union[int, float, ndarray] | Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array. | 0.0 |
| a_slow | Union[int, float, ndarray] | Amplitude of slow current. Can be scalar or per-neuron array. | 2.0 |
| voff_slow | Union[int, float, ndarray] | Voltage offset for slow current activation. Can be scalar or per-neuron array. | 0.0 |
| tau_slow | Union[int, float, ndarray] | Time constant for slow recovery variable. Can be scalar or per-neuron array. | 50.0 |
| vthr | Union[int, float, ndarray] | Voltage threshold for spike generation. Can be scalar or per-neuron array. | 2.0 |
| dtype | DTypeLike | Data type for arrays (default: float32). | float32 |
init_state(n_neurons: int) -> Float[Array, 'neurons 3']
Initialize the neuron state variables.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_neurons | int | Number of neurons to initialize. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'neurons 3'] | Initial state array of shape (neurons, 3) containing [v, v_slow, i_app], where v is the membrane voltage, v_slow is the slow recovery variable, and i_app is the applied synaptic current. |
IV_inst(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]
Compute the instantaneous I-V relationship, with the fast and slow currents held at rest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| v | Float[Array, ...] | Membrane voltage. | required |
| Vrest | float | Resting voltage for both the fast and slow currents (default: 0). | 0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, ...] | Total current at voltage v, with both the fast and slow currents evaluated at Vrest. |
IV_fast(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]
Compute the I-V relationship with the fast current at voltage v and the slow current at rest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| v | Float[Array, ...] | Membrane voltage for the passive and fast currents. | required |
| Vrest | float | Resting voltage for the slow current (default: 0). | 0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, ...] | Total current, with the fast dynamics responding to v and the slow current at Vrest. |
IV_slow(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]
Compute the steady-state I-V relationship with all currents at voltage v.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| v | Float[Array, ...] | Membrane voltage for all currents. | required |
| Vrest | float | Unused parameter, kept for API consistency (default: 0). | 0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, ...] | Total steady-state current, with all currents responding to v. |
dynamics(t: float, y: Float[Array, 'neurons 3'], args: Dict[str, Any]) -> Float[Array, 'neurons 3']
Compute time derivatives of the neuron state variables.
This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:
- dv/dt: Fast membrane voltage dynamics
- dv_slow/dt: Slow recovery variable dynamics
- di_app/dt: Synaptic current decay
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | float | Current simulation time (unused but required by the framework). | required |
| y | Float[Array, 'neurons 3'] | State array of shape (neurons, 3) containing [v, v_slow, i_app]. | required |
| args | Dict[str, Any] | Additional arguments (unused but required by the framework). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'neurons 3'] | Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt]. |
spike_condition(t: float, y: Float[Array, 'neurons 3'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']
Compute spike condition for event detection.
A spike is triggered when this function crosses zero (v >= vthr).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | float | Current simulation time (unused but required by event detection). | required |
| y | Float[Array, 'neurons 3'] | State array of shape (neurons, 3) containing [v, v_slow, i_app]. | required |
| **kwargs | Dict[str, Any] | Additional keyword arguments (unused). | {} |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' neurons'] | Spike condition array of shape (neurons,). Positive values indicate v > vthr. |
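The event condition is a signed function whose upward zero crossing marks a spike. On an already-sampled voltage trace, the same idea works as follows:

1. Take \(g = v - v_{thr}\).
2. Find the samples where g changes from negative to non-negative.
3. Refine each crossing time by linear interpolation.

During integration, an adaptive solver with event handling locates the root directly instead. This post-hoc version and its function names are illustrative only.

```python
import jax.numpy as jnp


def spike_condition(v, vthr):
    """Signed condition: positive when v is above the threshold vthr."""
    return v - vthr


def spike_times(ts, vs, vthr):
    """Upward zero crossings of spike_condition on a sampled trace.

    Each crossing time is refined by linear interpolation between the two
    samples that bracket the root.
    """
    g = spike_condition(vs, vthr)
    upward = (g[:-1] < 0) & (g[1:] >= 0)
    # Fraction of the sampling interval at which the linear interpolant hits 0.
    frac = g[:-1] / (g[:-1] - g[1:])
    t_cross = ts[:-1] + frac * (ts[1:] - ts[:-1])
    return t_cross[upward]


ts = jnp.linspace(0.0, 2 * jnp.pi, 1001)
crossings = spike_times(ts, jnp.sin(ts), vthr=0.5)  # sin(t) rises through 0.5 once
```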
WereRabbit
Bases: Module
WereRabbit Neuron Model
The WereRabbit model implements a predator-prey dynamic with bistable switching behavior controlled by a "moon phase" parameter \(z\).
The dynamics are governed by:
where \(z\) represents the "moon phase" that switches the predator-prey roles.
Attributes:
| Name | Type | Description |
|---|---|---|
| alpha | float | Current scaling parameter \(\alpha = I_{n0}/I_{bias}\) (default: 0.0129) |
| beta | float | Exponential slope \(\beta = \kappa/U_t\) (default: 15.6) |
| gamma | float | Coupling parameter \(\gamma\) (default: 0.26) |
| rho | float | Steepness of the tanh function \(\rho\) (default: 5) |
| sigma | float | Fixpoint distance scaling \(\sigma\) (default: 0.6) |
| rtol | float | Relative tolerance for the spiking fixpoint calculation. |
| atol | float | Absolute tolerance for the spiking fixpoint calculation. |
| weight_u | float | Input weight for the predator. |
| weight_v | float | Input weight for the prey. |
Source code in felice/neuron_models/wererabbit.py
Functions
__init__(*, atol: float = 0.001, rtol: float = 0.001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 5.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)
Initialize the WereRabbit neuron model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rtol | float | Relative tolerance for the spiking fixpoint calculation. | 0.001 |
| atol | float | Absolute tolerance for the spiking fixpoint calculation. | 0.001 |
| alpha | float | Current scaling parameter \(\alpha = I_{n0}/I_{bias}\) (default: 0.0129) | 0.0129 |
| beta | float | Exponential slope \(\beta = \kappa/U_t\) (default: 15.6) | 15.6 |
| gamma | float | Coupling parameter \(\gamma\) (default: 0.26) | 0.26 |
| rho | float | Steepness of the tanh function \(\rho\) (default: 5) | 5.0 |
| sigma | float | Fixpoint distance scaling \(\sigma\) (default: 0.6) | 0.6 |
| dtype | DTypeLike | Data type for arrays (default: float32). | float32 |
init_state(n_neurons: int) -> Float[Array, 'neurons 2']
Initialize the neuron state variables.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_neurons | int | Number of neurons to initialize. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'neurons 2'] | Initial state array of shape (neurons, 3) containing [u, v, has_spiked], where u and v are the predator/prey membrane voltages and has_spiked is a variable that is 1 whenever the neuron spikes and 0 otherwise. |
vector_field(y: Float[Array, 'neurons 2']) -> Float[Array, 'neurons 2']
Compute the vector field of the neuron state variables.
This implements the WereRabbit dynamics:
- du/dt: Predator dynamics
- dv/dt: WerePrey dynamics
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | Float[Array, 'neurons 2'] | State array of shape (neurons, 2) containing [u, v]. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'neurons 2'] | Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt]. |
dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']
Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics:
- du/dt: Predator dynamics
- dv/dt: WerePrey dynamics
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | float | Current simulation time (unused but required by the framework). | required |
| y | Float[Array, 'neurons 2'] | State array of shape (neurons, 3) containing [u, v, has_spiked]. | required |
| args | Dict[str, Any] | Additional arguments (unused but required by the framework). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'neurons 2'] | Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0]. |
spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']
Compute spike condition for event detection.
A spike is triggered when the system reaches a fixpoint.
INFO
has_spiked prevents the system from detecting a continuous spike while it remains at a fixpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | float | Current simulation time (unused but required by the framework). | required |
| y | Float[Array, 'neurons 2'] | State array of shape (neurons, 3) containing [u, v, has_spiked]. | required |
| **kwargs | Dict[str, Any] | Additional keyword arguments (unused). | {} |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' neurons'] | Spike condition array of shape (neurons,). Positive values indicate a spike. |
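The spiking fixpoint calculation is controlled by rtol and atol. One plausible reading of these parameters is a steady-state test, sketched below as an assumption:

- A neuron is at a fixpoint when the norm of its time derivative drops below atol + rtol times the norm of its state.
- The condition is forced negative for neurons whose has_spiked flag is already set, so a neuron resting at a fixpoint does not trigger repeatedly.

The function name and the choice of norm are hypothetical.

```python
import jax.numpy as jnp


def fixpoint_spike_condition(y, dy, has_spiked, rtol=1e-3, atol=1e-3):
    """Hypothetical fixpoint-based spike condition.

    y, dy: arrays of shape (neurons, 2) holding [u, v] and [du/dt, dv/dt].
    has_spiked: array of shape (neurons,), 1 if the neuron already spiked.
    Returns an array of shape (neurons,); positive values indicate a spike.
    """
    speed = jnp.linalg.norm(dy, axis=-1)
    tolerance = atol + rtol * jnp.linalg.norm(y, axis=-1)
    condition = tolerance - speed  # > 0 when the state is (almost) stationary
    # Suppress repeated detections while the neuron stays at the fixpoint.
    return jnp.where(has_spiked > 0, -1.0, condition)


y = jnp.array([[0.3, 0.3], [0.1, 0.4], [0.3, 0.3]])
dy = jnp.array([[0.0, 0.0], [0.5, -0.2], [0.0, 0.0]])
has_spiked = jnp.array([0.0, 0.0, 1.0])
cond = fixpoint_spike_condition(y, dy, has_spiked)
```

The three neurons in the example behave differently. The first is stationary and not yet flagged, so its condition is positive. The second is still moving. The third is stationary but already flagged, so it is suppressed.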
Solver
felice.solver
Classes
ClipSolver
Bases: Module
Source code in felice/solver.py
Functions
step(terms: PyTree[AbstractTerm], t0: RealScalarLike, t1: RealScalarLike, y0: Y, args: Args, solver_state: _SolverState, made_jump: BoolScalarLike) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]
Make a single step of the solver.
Each step is made over the specified interval \([t_0, t_1]\).
Arguments:
- terms: The PyTree of terms representing the vector fields and controls.
- t0: The start of the interval that the step is made over.
- t1: The end of the interval that the step is made over.
- y0: The current value of the solution at t0.
- args: Any extra arguments passed to the vector field.
- solver_state: Any evolving state for the solver itself, at t0.
- made_jump: Whether there was a discontinuity in the vector field at t0. Some solvers (notably FSAL Runge-Kutta solvers) usually assume that there are no jumps and for efficiency re-use information between steps; this indicates that a jump has just occurred and this assumption is not true.

Returns:
A tuple of several objects:
- The value of the solution at t1.
- A local error estimate made during the step. (Used by adaptive step size controllers to change the step size.) May be None if no estimate was made.
- Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. (Used with SaveAt(ts=...) or SaveAt(dense=...).)
- The value of the solver state at t1.
- An integer (corresponding to diffrax.RESULTS) indicating whether the step happened successfully, or if (unusually) it failed for some reason.
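The module list describes ClipSolver as a zero-clipping solver. One interpretation is that the state is clamped at zero after every step, and the sketch below illustrates that idea with a forward-Euler base step. This is an interpretation only, since the API entry documents just step. euler_step and clipped_step are hypothetical helpers, not Felice or diffrax API.

```python
import jax.numpy as jnp


def euler_step(f, t0, t1, y0):
    """Plain forward-Euler step over [t0, t1]."""
    return y0 + (t1 - t0) * f(t0, y0)


def clipped_step(f, t0, t1, y0, base_step=euler_step):
    """Advance with a base step, then clip negative components to zero."""
    y1 = base_step(f, t0, t1, y0)
    return jnp.maximum(y1, 0.0)


def decay(t, y):
    """Fast linear decay that overshoots below zero with large Euler steps."""
    return -10.0 * y


y0 = jnp.array([1.0, 0.01])
y_plain = euler_step(decay, 0.0, 0.2, y0)      # both components overshoot below zero
y_clipped = clipped_step(decay, 0.0, 0.2, y0)  # negative components clipped to zero
```

A full solver would also need to decide how clipping interacts with the error estimate and the dense-output information returned by step; this sketch ignores both.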