{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"],"fields":{"title":{"boost":1000.0},"text":{"boost":1.0},"tags":{"boost":1000000.0}}},"docs":[{"location":"","title":"Felice","text":"
This project provides a JAX implementation of the different neuron models in felice
"},{"location":"#overview","title":"Overview","text":"The framework is built on top of diffrax and leverages JAX's automatic differentiation for efficient simulation and training of analogue models.
"},{"location":"#key-features","title":"Key Features","text":"Felice uses uv for dependency management. To install:
uv sync\n"},{"location":"#cuda-support-optional","title":"CUDA Support (Optional)","text":"For GPU acceleration with CUDA 13:
uv sync --extra cuda\n See the examples directory for more detailed usage examples.
"},{"location":"api/","title":"API Reference","text":"API documentation for Felice.
"},{"location":"api/#modules","title":"Modules","text":"felice.datasets"},{"location":"api/neuron_models/","title":"Neuron Models","text":""},{"location":"api/neuron_models/#felice.neuron_models","title":"felice.neuron_models","text":""},{"location":"api/neuron_models/#felice.neuron_models-classes","title":"Classes","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang","title":"Boomerang","text":" Bases: Module
felice/neuron_models/boomerang.py class Boomerang(eqx.Module):\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n u0: float = eqx.field(static=True)\n v0: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. 
If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n\n def vector_field(\n self, u: Float[Array, \"...\"], v: Float[Array, \"...\"]\n ) -> Tuple[Float[Array, \"...\"], Float[Array, \"...\"]]:\n alpha = self.alpha\n beta = self.beta\n gamma = self.gamma\n sigma = self.sigma\n rho = self.rho\n\n z = jax.nn.tanh(rho * (v - u))\n du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z\n dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z\n\n return du, dv\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused 
but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.__init__","title":"__init__(*, atol: float = 1e-06, rtol: float = 0.0001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 30.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)","text":"Initialize the WereRabbit neuron model.
Parameters:
Name Type Description Defaultkey JAX random key for weight initialization.
requiredn_neurons Number of neurons in this layer.
requiredin_size Number of input connections (excluding recurrent connections).
requiredwmask Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).
requiredrtol float Relative tolerance for the spiking fixpoint calculation.
0.0001 atol float Absolute tolerance for the spiking fixpoint calculation.
1e-06 alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
0.0129 beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
15.6 gamma float Coupling parameter \\(\\gamma = 26 \\times 10^{-2}\\) (default: 0.26)
0.26 rho float Steepness of the tanh function \\(\\rho\\) (default: 30.0)
30.0 sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
0.6 wlim Limit for weight initialization. If None, uses init_weights.
requiredwmean Mean value for weight initialization.
requiredinit_weights Optional initial weight values. If None, weights are randomly initialized.
requiredfan_in_mode Mode for fan-in based weight initialization ('sqrt', 'linear').
requireddtype DTypeLike Data type for arrays (default: float32).
float32 Source code in felice/neuron_models/boomerang.py def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.init_state","title":"init_state(n_neurons: int) -> Float[Array, 'neurons 2']","text":"Initialize the neuron state variables.
Parameters:
Name Type Description Defaultn_neurons int Number of neurons to initialize.
requiredReturns:
Type DescriptionFloat[Array, 'neurons 2'] Initial state array of shape (neurons, 2) containing [u, v],
Float[Array, 'neurons 2'] where u and v are the predator/prey membrane voltages.
Source code infelice/neuron_models/boomerang.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.dynamics","title":"dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']","text":"Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n Parameters:
Name Type Description Defaultt float Current simulation time (unused but required by framework).
requiredy Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
requiredargs Dict[str, Any] Additional arguments (unused but required by framework).
requiredReturns:
Type DescriptionFloat[Array, 'neurons 2'] Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
Source code infelice/neuron_models/boomerang.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.spike_condition","title":"spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"Compute spike condition for event detection.
A spike is triggered when the system reaches a fixpoint.
INFOhas_spiked is used so that the system does not detect a continuous spike when it reaches a fixpoint.
Parameters:
Name Type Description Defaultt float Current simulation time (unused but required by the framework).
requiredy Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
required**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type DescriptionFloat[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate spike.
Source code infelice/neuron_models/boomerang.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS","title":"FHNRS","text":" Bases: Module
FitzHugh-Nagumo neuron model
Model for the FitzHugh-Nagumo neuron, with a hardware implementation proposed by Ribar-Sepulchre. This implementation uses dual-timescale dynamics with fast and slow currents to produce oscillatory spiking behavior.
The dynamics are governed by:
\\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\]where the currents are:
Attributes:
Name Type Descriptionreset_grad_preserve Preserve the gradient when the neuron spikes by doing a soft reset.
gmax_pasive float Maximal conductance of the passive current.
Erev_pasive float Reversal potential for the passive current.
a_fast float Amplitude parameter for the fast current dynamics.
voff_fast float Voltage offset for the fast current activation.
tau_fast float Time constant for the fast current (typically zero for instantaneous).
a_slow float Amplitude parameter for the slow current dynamics.
voff_slow float Voltage offset for the slow current activation.
tau_slow float Time constant for the slow recovery variable.
vthr float Voltage threshold for spike generation.
C float Membrane capacitance.
tsyn float Synaptic time constant for input current decay.
weights float Synaptic weight matrix of shape (in_plus_neurons, neurons).
Source code infelice/neuron_models/fhn.py class FHNRS(eqx.Module):\n r\"\"\"FitzHugh-Nagumo neuron model\n\n Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by\n Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast\n and slow currents to produce oscillatory spiking behavior.\n\n The dynamics are governed by:\n\n $$\n \\begin{align}\n C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\\n \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\\n \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}}\n \\end{align}\n $$\n\n where the currents are:\n\n - $I_{passive} = g_{max}(v - E_{rev})$\n - $I_{fast} = a_{fast} \\tanh(v - v_{off,fast})$\n - $I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})$\n\n References:\n - Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.\n\n Attributes:\n reset_grad_preserve: Preserve the gradient when the neuron spikes by doing a soft reset.\n gmax_pasive: Maximal conductance of the passive current.\n Erev_pasive: Reversal potential for the passive current.\n a_fast: Amplitude parameter for the fast current dynamics.\n voff_fast: Voltage offset for the fast current activation.\n tau_fast: Time constant for the fast current (typically zero for instantaneous).\n a_slow: Amplitude parameter for the slow current dynamics.\n voff_slow: Voltage offset for the slow current activation.\n tau_slow: Time constant for the slow recovery variable.\n vthr: Voltage threshold for spike generation.\n C: Membrane capacitance.\n tsyn: Synaptic time constant for input current decay.\n weights: Synaptic weight matrix of shape (in_plus_neurons, neurons).\n \"\"\"\n\n # Pasive parameters\n gmax_pasive: float = eqx.field(static=True)\n Erev_pasive: float = eqx.field(static=True)\n\n # Fast current\n a_fast: float = eqx.field(static=True)\n voff_fast: float = eqx.field(static=True)\n 
tau_fast: float = eqx.field(static=True)\n\n # Slow current\n a_slow: float = eqx.field(static=True)\n voff_slow: float = eqx.field(static=True)\n tau_slow: float = eqx.field(static=True)\n\n # Neuron threshold\n vthr: float = eqx.field(static=True)\n C: float = eqx.field(static=True, default=1.0)\n\n # Input synaptic time constant\n tsyn: float = eqx.field(static=True)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n ):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. 
Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n\n def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v 
- self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike 
condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.__init__","title":"__init__(*, tsyn: Union[int, float, jnp.ndarray] = 1.0, C: Union[int, float, jnp.ndarray] = 1.0, gmax_pasive: Union[int, float, jnp.ndarray] = 1.0, Erev_pasive: Union[int, float, jnp.ndarray] = 0.0, a_fast: Union[int, float, jnp.ndarray] = -2.0, voff_fast: Union[int, float, jnp.ndarray] = 0.0, tau_fast: Union[int, float, jnp.ndarray] = 0.0, a_slow: Union[int, float, jnp.ndarray] = 2.0, voff_slow: Union[int, float, jnp.ndarray] = 0.0, tau_slow: Union[int, float, jnp.ndarray] = 50.0, vthr: Union[int, float, jnp.ndarray] = 2.0, dtype: DTypeLike = jnp.float32)","text":"Initialize the FitzHugh-Nagumo neuron model.
Parameters:
Name Type Description Defaulttsyn Union[int, float, ndarray] Synaptic time constant for input current decay. Can be scalar or per-neuron array.
1.0 C Union[int, float, ndarray] Membrane capacitance. Can be scalar or per-neuron array.
1.0 gmax_pasive Union[int, float, ndarray] Maximal conductance of passive current. Can be scalar or per-neuron array.
1.0 Erev_pasive Union[int, float, ndarray] Reversal potential for passive current. Can be scalar or per-neuron array.
0.0 a_fast Union[int, float, ndarray] Amplitude of fast current. Can be scalar or per-neuron array.
-2.0 voff_fast Union[int, float, ndarray] Voltage offset for fast current activation. Can be scalar or per-neuron array.
0.0 tau_fast Union[int, float, ndarray] Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.
0.0 a_slow Union[int, float, ndarray] Amplitude of slow current. Can be scalar or per-neuron array.
2.0 voff_slow Union[int, float, ndarray] Voltage offset for slow current activation. Can be scalar or per-neuron array.
0.0 tau_slow Union[int, float, ndarray] Time constant for slow recovery variable. Can be scalar or per-neuron array.
50.0 vthr Union[int, float, ndarray] Voltage threshold for spike generation. Can be scalar or per-neuron array.
2.0 dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in felice/neuron_models/fhn.py def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. 
Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.init_state","title":"init_state(n_neurons: int) -> Float[Array, 'neurons 3']","text":"Initialize the neuron state variables.
Parameters:
Name Type Description Defaultn_neurons int Number of neurons to initialize.
requiredReturns:
Type DescriptionFloat[Array, 'neurons 3'] Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],
Float[Array, 'neurons 3'] where v is membrane voltage, v_slow is the slow recovery variable,
Float[Array, 'neurons 3'] and i_app is the applied synaptic current.
Source code infelice/neuron_models/fhn.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_inst","title":"IV_inst(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"Compute instantaneous I-V relationship with fast and slow currents at rest.
Parameters:
Name Type Description Defaultv Float[Array, ...] Membrane voltage.
requiredVrest float Resting voltage for both fast and slow currents (default: 0).
0 Returns:
Type DescriptionFloat[Array, ...] Total current at voltage v with both fast and slow currents evaluated at Vrest.
Source code infelice/neuron_models/fhn.py def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_fast","title":"IV_fast(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"Compute I-V relationship with fast current at voltage v and slow current at rest.
Parameters:
Name Type Description Defaultv Float[Array, ...] Membrane voltage for passive and fast currents.
requiredVrest float Resting voltage for slow current (default: 0).
0 Returns:
Type DescriptionFloat[Array, ...] Total current with fast dynamics responding to v and slow current at Vrest.
Source code infelice/neuron_models/fhn.py def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_slow","title":"IV_slow(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]","text":"Compute steady-state I-V relationship with all currents at voltage v.
Parameters:
Name Type Description Defaultv Float[Array, ...] Membrane voltage for all currents.
requiredVrest float Unused parameter for API consistency (default: 0).
0 Returns:
Type DescriptionFloat[Array, ...] Total steady-state current with all currents responding to v.
Source code infelice/neuron_models/fhn.py def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.dynamics","title":"dynamics(t: float, y: Float[Array, 'neurons 3'], args: Dict[str, Any]) -> Float[Array, 'neurons 3']","text":"Compute time derivatives of the neuron state variables.
This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents: - dv/dt: Fast membrane voltage dynamics - dv_slow/dt: Slow recovery variable dynamics - di_app/dt: Synaptic current decay
Parameters:
Name Type Description Defaultt float Current simulation time (unused but required by framework).
requiredy Float[Array, 'neurons 3'] State array of shape (neurons, 3) containing [v, v_slow, i_app].
requiredargs Dict[str, Any] Additional arguments (unused but required by framework).
requiredReturns:
Type DescriptionFloat[Array, 'neurons 3'] Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].
Source code infelice/neuron_models/fhn.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.spike_condition","title":"spike_condition(t: float, y: Float[Array, 'neurons 3'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"Compute spike condition for event detection.
A spike is triggered when this function crosses zero (v >= vthr).
Parameters:
Name Type Description Defaultt float Current simulation time (unused but required by event detection).
requiredy Float[Array, 'neurons 3'] State array of shape (neurons, 3) containing [v, v_slow, i_app].
required**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type DescriptionFloat[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate v > vthr.
Source code infelice/neuron_models/fhn.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit","title":"WereRabbit","text":" Bases: Module
WereRabbit Neuron Model
The WereRabbit model implements a predator-prey dynamic with bistable switching behavior controlled by a \"moon phase\" parameter \\(z\\).
The dynamics are governed by:
\\[ \\begin{align} z &= \\tanh(\\rho (u-v)) \\\\ \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{dt} &= -z + z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\] where \\(z\\) represents the \"moon phase\" that switches the predator-prey roles.
Attributes:
Name Type Descriptionalpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
gamma float Coupling parameter \\(\\gamma = 26 \\times 10^{-2}\\) (default: 0.26)
rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
rtol float Relative tolerance for the spiking fixpoint calculation.
atol float Absolute tolerance for the spiking fixpoint calculation.
weight_u float Input weight for the predator.
weight_v float Input weight for the prey.
Source code infelice/neuron_models/wererabbit.py class WereRabbit(eqx.Module):\n r\"\"\"\n WereRabbit Neuron Model\n\n The WereRabbit model implements a predator-prey dynamic with bistable \n switching behavior controlled by a \"moon phase\" parameter $z$.\n\n The dynamics are governed by:\n\n $$\n \\begin{align}\n z &= tanh(\\rho (u-v)) \\\\\n \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\\n \\frac{dv}{dt} &= -z - z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma\n \\end{align}\n $$\n\n where $z$ represents the \"moon phase\" that switches the predator-prey roles.\n\n Attributes:\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n\n weight_u: Input weight for the predator.\n weight_v: Input weight for the prey.\n \"\"\"\n\n dtype: DTypeLike = eqx.field(static=True)\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for 
the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v, has_spiked],\n where u and v are the predator/prey membrane voltages, has_spiked is a\n variable that is 1 whenever the neuron spike and 0 otherwise .\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n\n def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n\n def dynamics(\n 
self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.__init__","title":"__init__(*, atol: float = 0.001, rtol: float = 0.001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 5.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)","text":"Initialize the WereRabbit neuron model.
Parameters:
Name Type Description Defaultrtol float Relative tolerance for the spiking fixpoint calculation.
0.001 atol float Absolute tolerance for the spiking fixpoint calculation.
0.001 alpha float Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)
0.0129 beta float Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)
15.6 gamma float Coupling parameter \\(\\gamma = 26 \\times 10^{-2}\\) (default: 0.26)
0.26 rho float Steepness of the tanh function \\(\\rho\\) (default: 5)
5.0 sigma float Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)
0.6 dtype DTypeLike Data type for arrays (default: float32).
float32 Source code in felice/neuron_models/wererabbit.py def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma = 26e^{-2}$\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.init_state","title":"init_state(n_neurons: int) -> Float[Array, 'neurons 2']","text":"Initialize the neuron state variables.
Parameters:
Name Type Description Defaultn_neurons int Number of neurons to initialize.
requiredReturns:
Type DescriptionFloat[Array, 'neurons 2'] Initial state array of shape (neurons, 2) containing [u, v], where u and v are the predator/prey membrane voltages, both initialized to zero.
Source code infelice/neuron_models/wererabbit.py def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [u, v, has_spiked],\n where u and v are the predator/prey membrane voltages, has_spiked is a\n variable that is 1 whenever the neuron spike and 0 otherwise .\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.vector_field","title":"vector_field(y: Float[Array, 'neurons 2']) -> Float[Array, 'neurons 2']","text":"Compute vector field of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n Parameters:
Name Type Description Defaulty Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
requiredReturns:
Type DescriptionFloat[Array, 'neurons 2'] Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
Source code infelice/neuron_models/wererabbit.py def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.dynamics","title":"dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']","text":"Compute time derivatives of the neuron state variables.
This implements the WereRabbit dynamics
- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n Parameters:
Name Type Description Defaultt float Current simulation time (unused but required by framework).
requiredy Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
requiredargs Dict[str, Any] Additional arguments (unused but required by framework).
requiredReturns:
Type DescriptionFloat[Array, 'neurons 2'] Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
Source code infelice/neuron_models/wererabbit.py def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [du/dt, dv/dt, 0].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.spike_condition","title":"spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']","text":"Compute spike condition for event detection.
A spike is triggered when the system reaches a fixpoint.
INFO: has_spiked is used so that the system does not detect a continuous spike when it reaches a fixpoint.
Parameters:
Name Type Description Defaultt float Current simulation time (unused but required by the framework).
requiredy Float[Array, 'neurons 2'] State array of shape (neurons, 2) containing [u, v].
required**kwargs Dict[str, Any] Additional keyword arguments (unused).
{} Returns:
Type DescriptionFloat[Array, ' neurons'] Spike condition array of shape (neurons,). Positive values indicate spike.
Source code infelice/neuron_models/wererabbit.py def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reach to a fixpoint.\n\n INFO:\n `has_spiked` is use to the system don't detect a continuos\n spike when reach a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n"},{"location":"api/solver/","title":"Solver","text":""},{"location":"api/solver/#felice.solver","title":"felice.solver","text":""},{"location":"api/solver/#felice.solver-classes","title":"Classes","text":""},{"location":"api/solver/#felice.solver.ClipSolver","title":"ClipSolver","text":" Bases: Module
felice/solver.py class ClipSolver(eqx.Module):\n solver: AbstractSolver\n\n def __getattr__(self, name):\n return getattr(self.solver, name)\n\n def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. 
(Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n"},{"location":"api/solver/#felice.solver.ClipSolver-functions","title":"Functions","text":""},{"location":"api/solver/#felice.solver.ClipSolver.step","title":"step(terms: PyTree[AbstractTerm], t0: RealScalarLike, t1: RealScalarLike, y0: Y, args: Args, solver_state: _SolverState, made_jump: BoolScalarLike) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]","text":"Make a single step of the solver.
Each step is made over the specified interval \\([t_0, t_1]\\).
Arguments:
terms: The PyTree of terms representing the vector fields and controls. t0: The start of the interval that the step is made over. t1: The end of the interval that the step is made over. y0: The current value of the solution at t0. args: Any extra arguments passed to the vector field. solver_state: Any evolving state for the solver itself, at t0. made_jump: Whether there was a discontinuity in the vector field at t0. Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there are no jumps and for efficiency re-use information between steps; this indicates that a jump has just occurred and this assumption is not true. Returns:
A tuple of several objects:
The value of the solution at t1. A local error estimate made during the step. (Used by adaptive step size controllers to change the step size.) May be None if no estimate was made. Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. (Used with SaveAt(ts=...) or SaveAt(dense=...).) The value of the solver state at t1. An integer (corresponding to diffrax.RESULTS) indicating whether the step happened successfully, or if (unusually) it failed for some reason. Source code in felice/solver.py def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. 
(Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n"},{"location":"neuron_models/","title":"Neuron Models","text":"Felice implements several non-linear neuron models for spiking neural networks.
"},{"location":"neuron_models/#available-models","title":"Available Models","text":"Model Type Key Features WereRabbit Dual-state oscillatory Bistable dynamics, predator-prey FitzHugh-Nagumo ... ... Snowball Exponential Integrate-and-Fire neuron model ... LIF Leaky Integrate-and-Fire neuron model ..."},{"location":"neuron_models/fhn/","title":"FitzHugh-Nagumo","text":""},{"location":"neuron_models/fhn/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\]where the currents are: - \\(I_{passive} = g_{max}(v - E_{rev})\\) - \\(I_{fast} = a_{fast} \\tanh(v - v_{off,fast})\\) - \\(I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})\\)
"},{"location":"neuron_models/fhn/#examples","title":"Examples","text":"See the following interactive notebook for a practical example:
import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nfrom felice.neuron_models import FHNRS\nimport diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from felice.neuron_models import FHNRS In\u00a0[23]: Copied!
key = jrand.key(0)\nmax_time = 200\n\nneuron_model = FHNRS(\n gmax_pasive=2.0,\n Erev_pasive=0.0,\n a_fast=-2.0,\n voff_fast=0.0,\n tau_fast=0.0,\n a_slow=0.5,\n voff_slow=1.0,\n tau_slow=50.0,\n vthr=jnp.inf,\n)\n\n\ndef state_at_t(comp_times):\n sol = dfx.diffeqsolve(\n terms=dfx.ODETerm(neuron_model.dynamics),\n solver=dfx.Tsit5(),\n t0=0.0,\n t1=max_time,\n dt0=1e-3,\n y0=neuron_model.init_state(1)\n + jrand.uniform(key, shape=(1, 3), minval=0.1, maxval=0.5),\n saveat=dfx.SaveAt(ts=comp_times),\n max_steps=200000,\n )\n\n return sol.ts, sol.ys\nkey = jrand.key(0) max_time = 200 neuron_model = FHNRS( gmax_pasive=2.0, Erev_pasive=0.0, a_fast=-2.0, voff_fast=0.0, tau_fast=0.0, a_slow=0.5, voff_slow=1.0, tau_slow=50.0, vthr=jnp.inf, ) def state_at_t(comp_times): sol = dfx.diffeqsolve( terms=dfx.ODETerm(neuron_model.dynamics), solver=dfx.Tsit5(), t0=0.0, t1=max_time, dt0=1e-3, y0=neuron_model.init_state(1) + jrand.uniform(key, shape=(1, 3), minval=0.1, maxval=0.5), saveat=dfx.SaveAt(ts=comp_times), max_steps=200000, ) return sol.ts, sol.ys In\u00a0[24]: Copied!
v_range = jnp.arange(-3.1, 3, 0.1)\nVI_inst = jax.vmap(neuron_model.IV_inst)(v_range)\nVI_fast = jax.vmap(neuron_model.IV_fast)(v_range)\nVI_slow = jax.vmap(neuron_model.IV_slow)(v_range)\n\nwith mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True)\n ax[0].plot(v_range, VI_inst)\n ax[1].plot(v_range, VI_fast)\n ax[2].plot(v_range, VI_slow)\n plt.show()\nv_range = jnp.arange(-3.1, 3, 0.1) VI_inst = jax.vmap(neuron_model.IV_inst)(v_range) VI_fast = jax.vmap(neuron_model.IV_fast)(v_range) VI_slow = jax.vmap(neuron_model.IV_slow)(v_range) with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True) ax[0].plot(v_range, VI_inst) ax[1].plot(v_range, VI_fast) ax[2].plot(v_range, VI_slow) plt.show() In\u00a0[25]: Copied!
comp_times = jnp.linspace(0.0, max_time, 500)\n_, state = state_at_t(comp_times)\ncomp_times = jnp.linspace(0.0, max_time, 500) _, state = state_at_t(comp_times) In\u00a0[26]: Copied!
def compute_nullclines(neuron_model, u_range, v_range, resolution=200):\n \"\"\"\n Compute nullclines\n du/dt = 0 (u-nullcline)\n dv/dt = 0 (v-nullcline)\n \"\"\"\n u_vals = jnp.linspace(u_range[0], u_range[1], resolution)\n v_vals = jnp.linspace(v_range[0], v_range[1], resolution)\n U, V = jnp.meshgrid(u_vals, v_vals)\n\n UV = jnp.stack(\n [U.reshape(-1), V.reshape(-1), jnp.zeros((resolution * resolution,))], axis=1\n )\n dS = neuron_model.dynamics(0, UV, {})\n dU = dS[:, 0].reshape(U.shape)\n dV = dS[:, 1].reshape(V.shape)\n\n return U, V, dU, dV\n\n\ndef plot_vf(ax, neuron_model, u_range, v_range):\n import numpy as np\n\n u_sparse = jnp.linspace(u_range[0], u_range[1], 30)\n v_sparse = jnp.linspace(v_range[0], v_range[1], 30)\n\n Us, Vs = jnp.meshgrid(u_sparse, v_sparse)\n\n U, V, dU, dV = compute_nullclines(neuron_model, u_range, v_range, 200)\n\n UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((30 * 30,))], axis=1)\n dS = neuron_model.dynamics(0, UVs, {})\n dUs = dS[:, 0].reshape(Us.shape)\n dVs = dS[:, 1].reshape(Vs.shape)\n\n # Normalize for visualization\n magnitude = np.sqrt(dUs**2 + dVs**2)\n magnitude[magnitude == 0] = 1\n dUs_norm = dUs / magnitude\n dVs_norm = dVs / magnitude\n\n # Nullclines\n ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\")\n ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\")\n\n # Vector field\n ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6)\n\n ax.set_xlabel(\"v\")\n ax.set_ylabel(\"w\")\n ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\")\n ax.set_xlim(u_range)\n ax.set_ylim(v_range)\n ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n def compute_nullclines(neuron_model, u_range, v_range, resolution=200): \"\"\" Compute nullclines du/dt = 0 (u-nullcline) dv/dt = 0 (v-nullcline) \"\"\" u_vals = jnp.linspace(u_range[0], 
u_range[1], resolution) v_vals = jnp.linspace(v_range[0], v_range[1], resolution) U, V = jnp.meshgrid(u_vals, v_vals) UV = jnp.stack( [U.reshape(-1), V.reshape(-1), jnp.zeros((resolution * resolution,))], axis=1 ) dS = neuron_model.dynamics(0, UV, {}) dU = dS[:, 0].reshape(U.shape) dV = dS[:, 1].reshape(V.shape) return U, V, dU, dV def plot_vf(ax, neuron_model, u_range, v_range): import numpy as np u_sparse = jnp.linspace(u_range[0], u_range[1], 30) v_sparse = jnp.linspace(v_range[0], v_range[1], 30) Us, Vs = jnp.meshgrid(u_sparse, v_sparse) U, V, dU, dV = compute_nullclines(neuron_model, u_range, v_range, 200) UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((30 * 30,))], axis=1) dS = neuron_model.dynamics(0, UVs, {}) dUs = dS[:, 0].reshape(Us.shape) dVs = dS[:, 1].reshape(Vs.shape) # Normalize for visualization magnitude = np.sqrt(dUs**2 + dVs**2) magnitude[magnitude == 0] = 1 dUs_norm = dUs / magnitude dVs_norm = dVs / magnitude # Nullclines ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\") ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\") # Vector field ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6) ax.set_xlabel(\"v\") ax.set_ylabel(\"w\") ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\") ax.set_xlim(u_range) ax.set_ylim(v_range) ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3) ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3) In\u00a0[27]: Copied! 
with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[:, 0, 0])\n ax[0].plot(comp_times, state[:, 0, 1], \"--\")\n # ax[0].plot(comp_times, state[0, :, 2], \"-.\")\n ax[0].set_xlabel(\"Time (ms)\")\n ax[0].legend([\"v\", \"vslow\", \"syn\"])\n\n plot_vf(ax[1], neuron_model, [-2, 2], [-2, 2])\n\n ax[1].plot(state[:, 0, 0], state[:, 0, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\")\n ax[1].set_xlabel(\"v\")\n ax[1].set_ylabel(\"v fast\")\n ax[1].legend()\n plt.show()\nwith mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[:, 0, 0]) ax[0].plot(comp_times, state[:, 0, 1], \"--\") # ax[0].plot(comp_times, state[0, :, 2], \"-.\") ax[0].set_xlabel(\"Time (ms)\") ax[0].legend([\"v\", \"vslow\", \"syn\"]) plot_vf(ax[1], neuron_model, [-2, 2], [-2, 2]) ax[1].plot(state[:, 0, 0], state[:, 0, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\") ax[1].set_xlabel(\"v\") ax[1].set_ylabel(\"v fast\") ax[1].legend() plt.show() In\u00a0[\u00a0]: Copied!
\nIn\u00a0[\u00a0]: Copied!
\n"},{"location":"neuron_models/lif/","title":"LIF","text":""},{"location":"neuron_models/lif/#circuit-design","title":"Circuit Design","text":"
W/L = 4/3
"},{"location":"neuron_models/lif/#circuit-simulation","title":"Circuit Simulation","text":"Fig.1 The dynamics of the leaky integrate-and-fire neuron. The grey signal is the input spikes, the yellow signal is the membrane potential, and the dark blue signal is the output spikes from the neuron.
"},{"location":"neuron_models/lif/#referennces","title":"References","text":"The circuit implemented for the exponential integrate-and-fire neuron is taken from [1]. Part (a) in Fig.2 in [1] implements the exponential integrate-and-fire neuron. The neuron receives input currents using the input DPI filter [2]. This input current is integrated on the node Vmem by the membrane capacitance. The membrane potential leaks in the absence of an input spike; the leak can be set by the bias Vleak. The Vmem potential node is connected to a cascoded source follower formed by P14-15 and N5-6. The threshold voltage of the neuron can be set by the bias Vthr, which is compared to the membrane potential. When the membrane potential is close to the threshold voltage, it starts the positive feedback block, which exponentially increases the membrane potential and causes the neuron to spike. As the neuron spikes, the membrane potential gets reset to ground and the refractory bias helps to stop the neuron from spiking during the refractory period, similar to a biological neuron. The circuit implemented for this experiment neither exercises adaptability nor needs a pulse extender as implemented in [1]. The Vdd used in the simulation is 1V. The neuron receives 5nA input pulses with a pulse width of 100\u03bcs.
Input current mirror W/L = 0.2; all other transistors W/L = 4/3
"},{"location":"neuron_models/snowball/#circuit-simulation","title":"Circuit Simulation","text":"Fig.1 The dynamics of the exponential integrate-and-fire neuron. The light blue signal is the input spikes, the yellow signal is the membrane potential, and the dark blue signal is the output spikes from the neuron.
"},{"location":"neuron_models/snowball/#references","title":"References","text":"The wererabbit neuron model is a system of two coupled oscillators that follows a predator-prey dynamic with a switch on the diagonal of the phase plane. The z in equation 1c represents the \u201cmoon phase\u201d: whenever it crosses the threshold, the rabbit (prey) becomes the predator.
"},{"location":"neuron_models/wererabbit/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{du}{dt} &= z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - I_a \\\\ C\\frac{dv}{dt} &= -z I_{bias} + I_{n0} e^{\\kappa u / U_t} [z + 26e^{-2} (0.5 - v) z] - I_a \\\\ z &= \\tanh(\\rho (u-v))\\\\ I_a &= \\sigma I_{bias} \\\\ \\end{align} \\] Parameter Symbol Definition Value Capacitance C Circuit capacitance \\(0.1\\,pF\\) Bias current \\(I_{bias}\\) DC bias current for the fixpoint location \\(100\\,pA\\) Leakage current \\(I_{n0}\\) Transistor leakage current \\(0.129\\,pA\\) Subthreshold slope \\(\\kappa\\) Transistor subthreshold slope factor \\(0.39\\) Thermal voltage \\(U_t\\) Thermal voltage at room temperature \\(25\\,mV\\) Bias scale \\(\\sigma\\) Scaling factor for the distance between fixpoints \\(0.6\\) Steepness \\(\\rho\\) Tanh steepness for the moon phase \\(5\\)"},{"location":"neuron_models/wererabbit/#abstraction","title":"Abstraction","text":"To simplify the analysis of the model for simulation purposes, we can introduce a dimensionless time variable \\(\\tau=tI_{bias}/C\\), transforming the derivative in the equations into \\(\\frac{d}{dt}=\\frac{I_{bias}}{C}\\frac{d}{d\\tau}\\). Substituting this time transformation into the circuit equation for \\(u\\) gives:
\\[ \\begin{equation} C\\frac{I_{bias}}{C}\\frac{du}{d\\tau} = z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma I_{bias} \\end{equation} \\]And dividing by \\(I_{bias}\\) on both sides:
\\[ \\begin{equation} \\frac{du}{d\\tau} = z - \\frac{I_{n0}}{I_{bias}} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma \\end{equation} \\]Obtaining the following set of equations:
\\[ \\begin{align} z &= \\tanh(\\rho (u-v)) \\\\ \\frac{du}{d\\tau} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{d\\tau} &= -z + z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\] Parameter Definition Value \\(\\tau\\) \\(tI_{bias}/C\\) -- \\(\\alpha\\) \\(I_{n0}/I_{bias}\\) \\(0.0129\\) \\(\\beta\\) \\(\\kappa/U_t\\) 15.6 \\(\\gamma\\) -- \\(26e^{-2}\\) \\(\\rho\\) Tanh steepness for the moon phase 5 \\(\\sigma\\) Scaling factor for the distance between fixpoints 0.6"},{"location":"neuron_models/wererabbit/#examples","title":"Examples","text":"See the following interactive notebook for a practical example:
import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nfrom felice.neuron_models import WereRabbit\n\njax.config.update(\"jax_enable_x64\", True)\nimport diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from felice.neuron_models import WereRabbit jax.config.update(\"jax_enable_x64\", True) In\u00a0[11]: Copied!
key = jrand.key(0)\nmax_time = 40\n\nmodel = WereRabbit(dtype=jnp.float64)\n\n\ndef state_at_t(comp_times):\n sol = dfx.diffeqsolve(\n terms=dfx.ODETerm(model.dynamics),\n solver=dfx.Tsit5(),\n t0=0.0,\n t1=max_time,\n dt0=1e-3,\n y0=model.init_state(1)\n + jrand.uniform(key, shape=(1, 2), minval=0.1, maxval=0.5),\n saveat=dfx.SaveAt(ts=comp_times),\n max_steps=100000,\n )\n\n return sol.ts, sol.ys\nkey = jrand.key(0) max_time = 40 model = WereRabbit(dtype=jnp.float64) def state_at_t(comp_times): sol = dfx.diffeqsolve( terms=dfx.ODETerm(model.dynamics), solver=dfx.Tsit5(), t0=0.0, t1=max_time, dt0=1e-3, y0=model.init_state(1) + jrand.uniform(key, shape=(1, 2), minval=0.1, maxval=0.5), saveat=dfx.SaveAt(ts=comp_times), max_steps=100000, ) return sol.ts, sol.ys In\u00a0[12]: Copied!
comp_times = jnp.linspace(0.0, max_time, 2000)\n_, state = state_at_t(comp_times)\ncomp_times = jnp.linspace(0.0, max_time, 2000) _, state = state_at_t(comp_times) In\u00a0[13]: Copied!
def compute_nullclines(snn, u_range, v_range, resolution=200):\n \"\"\"\n Compute nullclines\n du/dt = 0 (u-nullcline)\n dv/dt = 0 (v-nullcline)\n \"\"\"\n u_vals = jnp.linspace(u_range[0], u_range[1], resolution)\n v_vals = jnp.linspace(v_range[0], v_range[1], resolution)\n U, V = jnp.meshgrid(u_vals, v_vals)\n\n UV = jnp.stack(\n [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1\n )\n dS = snn.vector_field(UV)\n dU = dS[:, 0].reshape(U.shape)\n dV = dS[:, 1].reshape(V.shape)\n\n return U, V, dU, dV\n\n\ndef plot_vf(ax, snn, u_range, v_range):\n import numpy as np\n\n u_sparse = jnp.linspace(u_range[0], u_range[1], 20)\n v_sparse = jnp.linspace(v_range[0], v_range[1], 20)\n\n Us, Vs = jnp.meshgrid(u_sparse, v_sparse)\n\n U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200)\n\n UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1)\n dS = snn.vector_field(UVs)\n dUs = dS[:, 0].reshape(Us.shape)\n dVs = dS[:, 1].reshape(Vs.shape)\n\n # Normalize for visualization\n magnitude = np.sqrt(dUs**2 + dVs**2)\n magnitude[magnitude == 0] = 1\n dUs_norm = dUs / magnitude\n dVs_norm = dVs / magnitude\n\n # Nullclines\n ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\")\n ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\")\n\n # Vector field\n ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6)\n\n ax.set_xlabel(\"u (Prey)\")\n ax.set_ylabel(\"v (Predator)\")\n ax.set_title(\"Wererabbit: Phase Portrait\")\n ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\")\n ax.set_xlim(u_range)\n ax.set_ylim(v_range)\n ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\ndef compute_nullclines(snn, u_range, v_range, resolution=200): \"\"\" Compute nullclines du/dt = 0 (u-nullcline) dv/dt = 0 (v-nullcline) \"\"\" u_vals = 
jnp.linspace(u_range[0], u_range[1], resolution) v_vals = jnp.linspace(v_range[0], v_range[1], resolution) U, V = jnp.meshgrid(u_vals, v_vals) UV = jnp.stack( [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1 ) dS = snn.vector_field(UV) dU = dS[:, 0].reshape(U.shape) dV = dS[:, 1].reshape(V.shape) return U, V, dU, dV def plot_vf(ax, snn, u_range, v_range): import numpy as np u_sparse = jnp.linspace(u_range[0], u_range[1], 20) v_sparse = jnp.linspace(v_range[0], v_range[1], 20) Us, Vs = jnp.meshgrid(u_sparse, v_sparse) U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200) UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1) dS = snn.vector_field(UVs) dUs = dS[:, 0].reshape(Us.shape) dVs = dS[:, 1].reshape(Vs.shape) # Normalize for visualization magnitude = np.sqrt(dUs**2 + dVs**2) magnitude[magnitude == 0] = 1 dUs_norm = dUs / magnitude dVs_norm = dVs / magnitude # Nullclines ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\") ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\") # Vector field ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6) ax.set_xlabel(\"u (Prey)\") ax.set_ylabel(\"v (Predator)\") ax.set_title(\"Wererabbit: Phase Portrait\") ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\") ax.set_xlim(u_range) ax.set_ylim(v_range) ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3) ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3) In\u00a0[14]: Copied!
with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[:, 0, 0], label=\"x1\")\n ax[0].plot(comp_times, state[:, 0, 1], label=\"x2\")\n ax[0].legend([\"x1\", \"x2\"])\n\n plot_vf(ax[1], model, [-0.2, 0.5], [-0.2, 0.5])\n ax[1].plot(state[:, 0, 0], state[:, 0, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[-1, 0, 0], state[-1, 0, 1], \".\", label=\"end\")\n ax[1].legend()\n plt.show()\nwith mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[:, 0, 0], label=\"x1\") ax[0].plot(comp_times, state[:, 0, 1], label=\"x2\") ax[0].legend([\"x1\", \"x2\"]) plot_vf(ax[1], model, [-0.2, 0.5], [-0.2, 0.5]) ax[1].plot(state[:, 0, 0], state[:, 0, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[-1, 0, 0], state[-1, 0, 1], \".\", label=\"end\") ax[1].legend() plt.show() In\u00a0[\u00a0]: Copied!
\n"}]}