mirror of
https://github.com/bics-rug/felice-models.git
synced 2026-03-10 21:14:15 +01:00
1 line
86 KiB
JSON
{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"],"fields":{"title":{"boost":1000.0},"text":{"boost":1.0},"tags":{"boost":1000000.0}}},"docs":[{"location":"","title":"Felice","text":"<p>This project provides a JAX implementation of the different neuron models in felice</p>"},{"location":"#overview","title":"Overview","text":"<p>The framework is built on top of EventPropJax and leverages JAX's automatic differentiation for efficient simulation and training of SNNs using event-based exact gradients.</p>"},{"location":"#key-features","title":"Key Features","text":"<ul> <li>Delay learning</li> <li>Non-linear neuron models<ul> <li>WereRabbit Neuron Model: Implementation of a dual-state oscillatory neuron model with bistable dynamics</li> <li>FHN Neuron Model</li> <li>Snowball Neuron Model</li> </ul> </li> </ul>"},{"location":"#installation","title":"\ud83d\udce6 Installation","text":"<p>Felice uses uv for dependency management. To install:</p> <pre><code>uv sync\n</code></pre>"},{"location":"#cuda-support-optional","title":"CUDA Support (Optional)","text":"<p>For GPU acceleration with CUDA 13:</p> <pre><code>uv sync --extra cuda\n</code></pre>"},{"location":"#quick-start","title":"Quick Start","text":"<p>Here's a simple example using the WereRabbit neuron model:</p> <pre><code>import diffrax as dfx\nimport jax.numpy as jnp\nimport jax.random as jrand\nfrom eventpropjax.evnn import FFEvNN\nfrom felice.neuron_models import WereRabbit\n\n# Initialize random key and parameters\nkey = jrand.key(0)\nmax_time = 300e-3\n\n# Create a feedforward event-driven neural network\nsnn = FFEvNN(\n layers=[1],\n in_size=2,\n neuron_model=WereRabbit,\n solver=dfx.Tsit5(),\n max_solver_time=max_time,\n key=key,\n max_event_steps=1000000,\n solver_stepsize=1e-6,\n rtol=10.0,\n atol=0.0,\n Ibias=30e-12,\n)\n\n# Simulate with input spikes\nin_spikes = jnp.asarray([[0.0], [0.157]])\nspikes = snn.spikes_until_t(in_spikes, max_time)\n</code></pre> <p>See the examples 
directory for more detailed usage examples.</p>"},{"location":"api/","title":"API Reference","text":"<p>API documentation for Felice.</p>"},{"location":"api/#modules","title":"Modules","text":"<ul> <li>Neuron Models - Neuron model implementations</li> <li>Solver - Zero-clipping solver</li> <li>Datasets - Built-in datasets</li> </ul>"},{"location":"api/datasets/","title":"Datasets","text":""},{"location":"api/datasets/#felice.datasets","title":"<code>felice.datasets</code>","text":""},{"location":"api/neuron_models/","title":"Neuron Models","text":""},{"location":"api/neuron_models/#felice.neuron_models","title":"<code>felice.neuron_models</code>","text":""},{"location":"api/neuron_models/#felice.neuron_models-classes","title":"Classes","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang","title":"<code>Boomerang</code>","text":"<p> Bases: <code>Module</code></p> Source code in <code>felice/neuron_models/boomerang.py</code> <pre><code>class Boomerang(eqx.Module):\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n u0: float = eqx.field(static=True)\n v0: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask 
defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma$ (default: 0.26)\n rho: Steepness of the tanh function $\\rho$ (default: 30)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 2) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n\n def vector_field(\n self, u: Float[Array, \"...\"], v: Float[Array, \"...\"]\n ) -> Tuple[Float[Array, \"...\"], Float[Array, \"...\"]]:\n alpha = self.alpha\n beta = self.beta\n gamma = self.gamma\n sigma = self.sigma\n rho = self.rho\n\n z 
= jax.nn.tanh(rho * (v - u))\n du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z\n dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z\n\n return du, dv\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reaches a fixpoint.\n\n INFO:\n `has_spiked` is used so that the system does not detect a continuous\n spike when it reaches a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). 
Positive values indicate a spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.__init__","title":"<code>__init__(*, atol: float = 1e-06, rtol: float = 0.0001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 30.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)</code>","text":"<p>Initialize the WereRabbit neuron model.</p> <p>Parameters:</p> Name Type Description Default <code>key</code> <p>JAX random key for weight initialization.</p> required <code>n_neurons</code> <p>Number of neurons in this layer.</p> required <code>in_size</code> <p>Number of input connections (excluding recurrent connections).</p> required <code>wmask</code> <p>Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).</p> required <code>rtol</code> <code>float</code> <p>Relative tolerance for the spiking fixpoint calculation.</p> <code>0.0001</code> <code>atol</code> <code>float</code> <p>Absolute tolerance for the spiking fixpoint calculation.</p> <code>1e-06</code> <code>alpha</code> <code>float</code> <p>Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)</p> <code>0.0129</code> <code>beta</code> <code>float</code> <p>Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)</p> <code>15.6</code> <code>gamma</code> <code>float</code> <p>Coupling parameter \\(\\gamma\\) (default: 0.26)</p> <code>0.26</code> <code>rho</code> <code>float</code> <p>Steepness of the tanh function \\(\\rho\\) (default: 30)</p> <code>30.0</code> <code>sigma</code> <code>float</code> <p>Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)</p> 
<code>0.6</code> <code>wlim</code> <p>Limit for weight initialization. If None, uses init_weights.</p> required <code>wmean</code> <p>Mean value for weight initialization.</p> required <code>init_weights</code> <p>Optional initial weight values. If None, weights are randomly initialized.</p> required <code>fan_in_mode</code> <p>Mode for fan-in based weight initialization ('sqrt', 'linear').</p> required <code>dtype</code> <code>DTypeLike</code> <p>Data type for arrays (default: float32).</p> <code>float32</code> Source code in <code>felice/neuron_models/boomerang.py</code> <pre><code>def __init__(\n self,\n *,\n atol: float = 1e-6,\n rtol: float = 1e-4,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 30.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n key: JAX random key for weight initialization.\n n_neurons: Number of neurons in this layer.\n in_size: Number of input connections (excluding recurrent connections).\n wmask: Binary mask defining connectivity pattern of shape (in_plus_neurons, neurons).\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma$ (default: 0.26)\n rho: Steepness of the tanh function $\\rho$ (default: 30)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n wlim: Limit for weight initialization. If None, uses init_weights.\n wmean: Mean value for weight initialization.\n init_weights: Optional initial weight values. 
If None, weights are randomly initialized.\n fan_in_mode: Mode for fan-in based weight initialization ('sqrt', 'linear').\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def fn(y, _):\n return self.vector_field(y[0], y[1])\n\n solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)\n y0 = (jnp.array(0.3), jnp.array(0.3))\n u0, v0 = optx.root_find(fn, solver, y0).value\n self.u0 = u0.item()\n self.v0 = v0.item()\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.init_state","title":"<code>init_state(n_neurons: int) -> Float[Array, 'neurons 2']</code>","text":"<p>Initialize the neuron state variables.</p> <p>Parameters:</p> Name Type Description Default <code>n_neurons</code> <code>int</code> <p>Number of neurons to initialize.</p> required <p>Returns:</p> Type Description <code>Float[Array, 'neurons 2']</code> <p>Initial state array of shape (neurons, 2) containing [u, v],</p> <code>Float[Array, 'neurons 2']</code> <p>where u and v are the predator/prey membrane voltages.</p> Source code in <code>felice/neuron_models/boomerang.py</code> <pre><code>def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 2) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n\n u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)\n v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)\n x = jnp.stack([u, v], axis=1)\n return x\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.dynamics","title":"<code>dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']</code>","text":"<p>Compute time 
derivatives of the neuron state variables.</p> <p>This implements the WereRabbit dynamics</p> <pre><code>- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n</code></pre> <p>Parameters:</p> Name Type Description Default <code>t</code> <code>float</code> <p>Current simulation time (unused but required by framework).</p> required <code>y</code> <code>Float[Array, 'neurons 2']</code> <p>State array of shape (neurons, 2) containing [u, v].</p> required <code>args</code> <code>Dict[str, Any]</code> <p>Additional arguments (unused but required by framework).</p> required <p>Returns:</p> Type Description <code>Float[Array, 'neurons 2']</code> <p>Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].</p> Source code in <code>felice/neuron_models/boomerang.py</code> <pre><code>def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n du, dv = self.vector_field(u, v)\n dxdt = jnp.stack([du, dv], axis=1)\n\n return dxdt\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.Boomerang.spike_condition","title":"<code>spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']</code>","text":"<p>Compute spike condition for event detection.</p> <p>A spike is triggered when the system reaches a fixpoint.</p> INFO <p><code>has_spiked</code> is used so that the system does not detect a continuous spike when it reaches a fixpoint.</p> <p>Parameters:</p> Name Type 
Description Default <code>t</code> <code>float</code> <p>Current simulation time (unused but required by the framework).</p> required <code>y</code> <code>Float[Array, 'neurons 2']</code> <p>State array of shape (neurons, 3) containing [u, v, has_spiked].</p> required <code>**kwargs</code> <code>Dict[str, Any]</code> <p>Additional keyword arguments (unused).</p> <code>{}</code> <p>Returns:</p> Type Description <code>Float[Array, ' neurons']</code> <p>Spike condition array of shape (neurons,). Positive values indicate a spike.</p> Source code in <code>felice/neuron_models/boomerang.py</code> <pre><code>def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reaches a fixpoint.\n\n INFO:\n `has_spiked` is used so that the system does not detect a continuous\n spike when it reaches a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 3) containing [u, v, has_spiked].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate a spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y) - _norm(vf)\n\n base_cond = calculate_norm(vf, y).repeat(2)\n\n return base_cond\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS","title":"<code>FHNRS</code>","text":"<p> Bases: <code>Module</code></p> <p>FitzHugh-Nagumo neuron model</p> <p>Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by Ribar-Sepulchre. 
This implementation uses a dual-timescale dynamics with fast and slow currents to produce oscillatory spiking behavior.</p> <p>The dynamics are governed by:</p> \\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\] <p>where the currents are:</p> <ul> <li>\\(I_{passive} = g_{max}(v - E_{rev})\\)</li> <li>\\(I_{fast} = a_{fast} \\tanh(v - v_{off,fast})\\)</li> <li>\\(I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})\\)</li> </ul> References <ul> <li>Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.</li> </ul> <p>Attributes:</p> Name Type Description <code>reset_grad_preserve</code> <p>Preserve the gradient when the neuron spikes by doing a soft reset.</p> <code>gmax_pasive</code> <code>float</code> <p>Maximal conductance of the passive current.</p> <code>Erev_pasive</code> <code>float</code> <p>Reversal potential for the passive current.</p> <code>a_fast</code> <code>float</code> <p>Amplitude parameter for the fast current dynamics.</p> <code>voff_fast</code> <code>float</code> <p>Voltage offset for the fast current activation.</p> <code>tau_fast</code> <code>float</code> <p>Time constant for the fast current (typically zero for instantaneous).</p> <code>a_slow</code> <code>float</code> <p>Amplitude parameter for the slow current dynamics.</p> <code>voff_slow</code> <code>float</code> <p>Voltage offset for the slow current activation.</p> <code>tau_slow</code> <code>float</code> <p>Time constant for the slow recovery variable.</p> <code>vthr</code> <code>float</code> <p>Voltage threshold for spike generation.</p> <code>C</code> <code>float</code> <p>Membrane capacitance.</p> <code>tsyn</code> <code>float</code> <p>Synaptic time constant for input current decay.</p> <code>weights</code> 
<code>float</code> <p>Synaptic weight matrix of shape (in_plus_neurons, neurons).</p> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>class FHNRS(eqx.Module):\n r\"\"\"FitzHugh-Nagumo neuron model\n\n Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by\n Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast\n and slow currents to produce oscillatory spiking behavior.\n\n The dynamics are governed by:\n\n $$\n \\begin{align}\n C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\\n \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\\n \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}}\n \\end{align}\n $$\n\n where the currents are:\n\n - $I_{passive} = g_{max}(v - E_{rev})$\n - $I_{fast} = a_{fast} \\tanh(v - v_{off,fast})$\n - $I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})$\n\n References:\n - Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.\n\n Attributes:\n reset_grad_preserve: Preserve the gradient when the neuron spikes by doing a soft reset.\n gmax_pasive: Maximal conductance of the passive current.\n Erev_pasive: Reversal potential for the passive current.\n a_fast: Amplitude parameter for the fast current dynamics.\n voff_fast: Voltage offset for the fast current activation.\n tau_fast: Time constant for the fast current (typically zero for instantaneous).\n a_slow: Amplitude parameter for the slow current dynamics.\n voff_slow: Voltage offset for the slow current activation.\n tau_slow: Time constant for the slow recovery variable.\n vthr: Voltage threshold for spike generation.\n C: Membrane capacitance.\n tsyn: Synaptic time constant for input current decay.\n weights: Synaptic weight matrix of shape (in_plus_neurons, neurons).\n \"\"\"\n\n # Pasive parameters\n gmax_pasive: float = eqx.field(static=True)\n Erev_pasive: float = 
eqx.field(static=True)\n\n # Fast current\n a_fast: float = eqx.field(static=True)\n voff_fast: float = eqx.field(static=True)\n tau_fast: float = eqx.field(static=True)\n\n # Slow current\n a_slow: float = eqx.field(static=True)\n voff_slow: float = eqx.field(static=True)\n tau_slow: float = eqx.field(static=True)\n\n # Neuron threshold\n vthr: float = eqx.field(static=True)\n C: float = eqx.field(static=True, default=1.0)\n\n # Input synaptic time constant\n tsyn: float = eqx.field(static=True)\n\n dtype: DTypeLike = eqx.field(static=True)\n\n def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n ):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. 
Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n\n def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n 
I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 
3\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.__init__","title":"<code>__init__(*, tsyn: Union[int, float, jnp.ndarray] = 1.0, C: Union[int, float, jnp.ndarray] = 1.0, gmax_pasive: Union[int, float, jnp.ndarray] = 1.0, Erev_pasive: Union[int, float, jnp.ndarray] = 0.0, a_fast: Union[int, float, jnp.ndarray] = -2.0, voff_fast: Union[int, float, jnp.ndarray] = 0.0, tau_fast: Union[int, float, jnp.ndarray] = 0.0, a_slow: Union[int, float, jnp.ndarray] = 2.0, voff_slow: Union[int, float, jnp.ndarray] = 0.0, tau_slow: Union[int, float, jnp.ndarray] = 50.0, vthr: Union[int, float, jnp.ndarray] = 2.0, dtype: DTypeLike = jnp.float32)</code>","text":"<p>Initialize the FitzHugh-Nagumo neuron model.</p> <p>Parameters:</p> Name Type Description Default <code>tsyn</code> <code>Union[int, float, ndarray]</code> <p>Synaptic time constant for input current decay. Can be scalar or per-neuron array.</p> <code>1.0</code> <code>C</code> <code>Union[int, float, ndarray]</code> <p>Membrane capacitance. Can be scalar or per-neuron array.</p> <code>1.0</code> <code>gmax_pasive</code> <code>Union[int, float, ndarray]</code> <p>Maximal conductance of passive current. Can be scalar or per-neuron array.</p> <code>1.0</code> <code>Erev_pasive</code> <code>Union[int, float, ndarray]</code> <p>Reversal potential for passive current. 
Can be scalar or per-neuron array.</p> <code>0.0</code> <code>a_fast</code> <code>Union[int, float, ndarray]</code> <p>Amplitude of fast current. Can be scalar or per-neuron array.</p> <code>-2.0</code> <code>voff_fast</code> <code>Union[int, float, ndarray]</code> <p>Voltage offset for fast current activation. Can be scalar or per-neuron array.</p> <code>0.0</code> <code>tau_fast</code> <code>Union[int, float, ndarray]</code> <p>Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.</p> <code>0.0</code> <code>a_slow</code> <code>Union[int, float, ndarray]</code> <p>Amplitude of slow current. Can be scalar or per-neuron array.</p> <code>2.0</code> <code>voff_slow</code> <code>Union[int, float, ndarray]</code> <p>Voltage offset for slow current activation. Can be scalar or per-neuron array.</p> <code>0.0</code> <code>tau_slow</code> <code>Union[int, float, ndarray]</code> <p>Time constant for slow recovery variable. Can be scalar or per-neuron array.</p> <code>50.0</code> <code>vthr</code> <code>Union[int, float, ndarray]</code> <p>Voltage threshold for spike generation. 
Can be scalar or per-neuron array.</p> <code>2.0</code> <code>dtype</code> <code>DTypeLike</code> <p>Data type for arrays (default: float32).</p> <code>float32</code> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>def __init__(\n self,\n *,\n tsyn: Union[int, float, jnp.ndarray] = 1.0,\n C: Union[int, float, jnp.ndarray] = 1.0,\n gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,\n Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,\n a_fast: Union[int, float, jnp.ndarray] = -2.0,\n voff_fast: Union[int, float, jnp.ndarray] = 0.0,\n tau_fast: Union[int, float, jnp.ndarray] = 0.0,\n a_slow: Union[int, float, jnp.ndarray] = 2.0,\n voff_slow: Union[int, float, jnp.ndarray] = 0.0,\n tau_slow: Union[int, float, jnp.ndarray] = 50.0,\n vthr: Union[int, float, jnp.ndarray] = 2.0,\n dtype: DTypeLike = jnp.float32,\n):\n \"\"\"Initialize the FitzHugh-Nagumo neuron model.\n\n Args:\n tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.\n C: Membrane capacitance. Can be scalar or per-neuron array.\n gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.\n Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.\n a_fast: Amplitude of fast current. Can be scalar or per-neuron array.\n voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.\n tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.\n a_slow: Amplitude of slow current. Can be scalar or per-neuron array.\n voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.\n tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.\n vthr: Voltage threshold for spike generation. 
Can be scalar or per-neuron array.\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n\n self.tsyn = tsyn\n self.C = C\n self.gmax_pasive = gmax_pasive\n self.Erev_pasive = Erev_pasive\n self.a_fast = a_fast\n self.voff_fast = voff_fast\n self.tau_fast = tau_fast\n self.a_slow = a_slow\n self.voff_slow = voff_slow\n self.tau_slow = tau_slow\n self.vthr = vthr\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.init_state","title":"<code>init_state(n_neurons: int) -> Float[Array, 'neurons 3']</code>","text":"<p>Initialize the neuron state variables.</p> <p>Parameters:</p> Name Type Description Default <code>n_neurons</code> <code>int</code> <p>Number of neurons to initialize.</p> required <p>Returns:</p> Type Description <code>Float[Array, 'neurons 3']</code> <p>Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],</p> <code>Float[Array, 'neurons 3']</code> <p>where v is membrane voltage, v_slow is the slow recovery variable,</p> <code>Float[Array, 'neurons 3']</code> <p>and i_app is the applied synaptic current.</p> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>def init_state(self, n_neurons: int) -> Float[Array, \"neurons 3\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],\n where v is membrane voltage, v_slow is the slow recovery variable,\n and i_app is the applied synaptic current.\n \"\"\"\n return jnp.zeros((n_neurons, 3), dtype=self.dtype)\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_inst","title":"<code>IV_inst(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]</code>","text":"<p>Compute instantaneous I-V relationship with fast and slow currents at rest.</p> <p>Parameters:</p> Name Type Description Default <code>v</code> <code>Float[Array, ...]</code> <p>Membrane voltage.</p> 
required <code>Vrest</code> <code>float</code> <p>Resting voltage for both fast and slow currents (default: 0).</p> <code>0</code> <p>Returns:</p> Type Description <code>Float[Array, ...]</code> <p>Total current at voltage v with both fast and slow currents evaluated at Vrest.</p> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>def IV_inst(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute instantaneous I-V relationship with fast and slow currents at rest.\n\n Args:\n v: Membrane voltage.\n Vrest: Resting voltage for both fast and slow currents (default: 0).\n\n Returns:\n Total current at voltage v with both fast and slow currents evaluated at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_fast","title":"<code>IV_fast(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]</code>","text":"<p>Compute I-V relationship with fast current at voltage v and slow current at rest.</p> <p>Parameters:</p> Name Type Description Default <code>v</code> <code>Float[Array, ...]</code> <p>Membrane voltage for passive and fast currents.</p> required <code>Vrest</code> <code>float</code> <p>Resting voltage for slow current (default: 0).</p> <code>0</code> <p>Returns:</p> Type Description <code>Float[Array, ...]</code> <p>Total current with fast dynamics responding to v and slow current at Vrest.</p> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>def IV_fast(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute I-V relationship with fast current at voltage v and slow current at rest.\n\n Args:\n v: Membrane voltage for passive and fast currents.\n Vrest: Resting voltage for slow current (default: 0).\n\n 
Returns:\n Total current with fast dynamics responding to v and slow current at Vrest.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.IV_slow","title":"<code>IV_slow(v: Float[Array, ...], Vrest: float = 0) -> Float[Array, ...]</code>","text":"<p>Compute steady-state I-V relationship with all currents at voltage v.</p> <p>Parameters:</p> Name Type Description Default <code>v</code> <code>Float[Array, ...]</code> <p>Membrane voltage for all currents.</p> required <code>Vrest</code> <code>float</code> <p>Unused parameter for API consistency (default: 0).</p> <code>0</code> <p>Returns:</p> Type Description <code>Float[Array, ...]</code> <p>Total steady-state current with all currents responding to v.</p> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>def IV_slow(self, v: Float[Array, \"...\"], Vrest: float = 0) -> Float[Array, \"...\"]:\n \"\"\"Compute steady-state I-V relationship with all currents at voltage v.\n\n Args:\n v: Membrane voltage for all currents.\n Vrest: Unused parameter for API consistency (default: 0).\n\n Returns:\n Total steady-state current with all currents responding to v.\n \"\"\"\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)\n\n return I_pasive + I_fast + I_slow\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.dynamics","title":"<code>dynamics(t: float, y: Float[Array, 'neurons 3'], args: Dict[str, Any]) -> Float[Array, 'neurons 3']</code>","text":"<p>Compute time derivatives of the neuron state variables.</p> <p>This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents: - dv/dt: Fast membrane voltage dynamics - 
dv_slow/dt: Slow recovery variable dynamics - di_app/dt: Synaptic current decay</p> <p>Parameters:</p> Name Type Description Default <code>t</code> <code>float</code> <p>Current simulation time (unused but required by framework).</p> required <code>y</code> <code>Float[Array, 'neurons 3']</code> <p>State array of shape (neurons, 3) containing [v, v_slow, i_app].</p> required <code>args</code> <code>Dict[str, Any]</code> <p>Additional arguments (unused but required by framework).</p> required <p>Returns:</p> Type Description <code>Float[Array, 'neurons 3']</code> <p>Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].</p> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 3\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:\n - dv/dt: Fast membrane voltage dynamics\n - dv_slow/dt: Slow recovery variable dynamics\n - di_app/dt: Synaptic current decay\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].\n \"\"\"\n v = y[:, 0]\n v_slow = y[:, 1]\n i_app = y[:, 2]\n\n I_pasive = self.gmax_pasive * (v - self.Erev_pasive)\n I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)\n I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)\n\n i_sum = I_pasive + I_fast + I_slow\n\n dv_dt = (i_app - i_sum) / self.C\n dvslow_dt = (v - v_slow) / self.tau_slow\n di_dt = -i_app / self.tsyn\n\n return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.FHNRS.spike_condition","title":"<code>spike_condition(t: 
float, y: Float[Array, 'neurons 3'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']</code>","text":"<p>Compute spike condition for event detection.</p> <p>A spike is triggered when this function crosses zero (v >= vthr).</p> <p>Parameters:</p> Name Type Description Default <code>t</code> <code>float</code> <p>Current simulation time (unused but required by event detection).</p> required <code>y</code> <code>Float[Array, 'neurons 3']</code> <p>State array of shape (neurons, 3) containing [v, v_slow, i_app].</p> required <code>**kwargs</code> <code>Dict[str, Any]</code> <p>Additional keyword arguments (unused).</p> <code>{}</code> <p>Returns:</p> Type Description <code>Float[Array, ' neurons']</code> <p>Spike condition array of shape (neurons,). Positive values indicate v > vthr.</p> Source code in <code>felice/neuron_models/fhn.py</code> <pre><code>def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 3\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when this function crosses zero (v >= vthr).\n\n Args:\n t: Current simulation time (unused but required by event detection).\n y: State array of shape (neurons, 3) containing [v, v_slow, i_app].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). 
Positive values indicate v > vthr.\n \"\"\"\n return y[:, 0] - self.vthr\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit","title":"<code>WereRabbit</code>","text":"<p> Bases: <code>Module</code></p> <p>WereRabbit Neuron Model</p> <p>The WereRabbit model implements a predator-prey dynamic with bistable switching behavior controlled by a \"moon phase\" parameter \\(z\\).</p> <p>The dynamics are governed by:</p> \\[ \\begin{align} z &= \\tanh(\\rho (u-v)) \\\\ \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{dt} &= -z + z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\] <p>where \\(z\\) represents the \"moon phase\" that switches the predator-prey roles.</p> <p>Attributes:</p> Name Type Description <code>alpha</code> <code>float</code> <p>Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)</p> <code>beta</code> <code>float</code> <p>Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)</p> <code>gamma</code> <code>float</code> <p>Coupling parameter \\(\\gamma\\) (default: 0.26)</p> <code>rho</code> <code>float</code> <p>Steepness of the tanh function \\(\\rho\\) (default: 5)</p> <code>sigma</code> <code>float</code> <p>Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)</p> <code>rtol</code> <code>float</code> <p>Relative tolerance for the spiking fixpoint calculation.</p> <code>atol</code> <code>float</code> <p>Absolute tolerance for the spiking fixpoint calculation.</p> <code>weight_u</code> <code>float</code> <p>Input weight for the predator.</p> <code>weight_v</code> <code>float</code> <p>Input weight for the prey.</p> Source code in <code>felice/neuron_models/wererabbit.py</code> <pre><code>class WereRabbit(eqx.Module):\n r\"\"\"\n WereRabbit Neuron Model\n\n The WereRabbit model implements a predator-prey dynamic with bistable \n switching behavior controlled by a \"moon phase\" parameter $z$.\n\n The dynamics are governed by:\n\n $$\n 
\\begin{align}\n z &= \\tanh(\\rho (u-v)) \\\\\n \\frac{du}{dt} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\\n \\frac{dv}{dt} &= -z + z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma\n \\end{align}\n $$\n\n where $z$ represents the \"moon phase\" that switches the predator-prey roles.\n\n Attributes:\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma$ (default: 0.26)\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n\n weight_u: Input weight for the predator.\n weight_v: Input weight for the prey.\n \"\"\"\n\n dtype: DTypeLike = eqx.field(static=True)\n rtol: float = eqx.field(static=True)\n atol: float = eqx.field(static=True)\n\n alpha: float = eqx.field(static=True) # I_n0 / I_bias ratio\n beta: float = eqx.field(static=True) # k / U_t (inverse thermal scale)\n gamma: float = eqx.field(static=True) # coupling coefficient\n rho: float = eqx.field(static=True) # tanh steepness\n sigma: float = eqx.field(static=True) # bias scaling (s * I_bias)\n\n def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n ):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma$ (default: 0.26)\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: 
Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n\n def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 2) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n\n def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n\n def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n ) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current 
simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n\n def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n ) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reaches a fixpoint.\n\n INFO:\n `has_spiked` is used so that the system does not detect a continuous\n spike once it reaches a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 2) containing [u, v].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate a spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit-functions","title":"Functions","text":""},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.__init__","title":"<code>__init__(*, atol: float = 0.001, rtol: float = 0.001, alpha: float = 0.0129, beta: float = 15.6, gamma: float = 0.26, rho: float = 5.0, sigma: float = 0.6, dtype: DTypeLike = jnp.float32)</code>","text":"<p>Initialize the WereRabbit neuron model.</p> <p>Parameters:</p> Name Type Description Default <code>rtol</code> <code>float</code> <p>Relative tolerance for the spiking fixpoint calculation.</p> <code>0.001</code> <code>atol</code> <code>float</code> <p>Absolute tolerance for the spiking fixpoint calculation.</p> <code>0.001</code> 
<code>alpha</code> <code>float</code> <p>Current scaling parameter \\(\\alpha = I_{n0}/I_{bias}\\) (default: 0.0129)</p> <code>0.0129</code> <code>beta</code> <code>float</code> <p>Exponential slope \\(\\beta = \\kappa/U_t\\) (default: 15.6)</p> <code>15.6</code> <code>gamma</code> <code>float</code> <p>Coupling parameter \\(\\gamma\\) (default: 0.26)</p> <code>0.26</code> <code>rho</code> <code>float</code> <p>Steepness of the tanh function \\(\\rho\\) (default: 5)</p> <code>5.0</code> <code>sigma</code> <code>float</code> <p>Fixpoint distance scaling \\(\\sigma\\) (default: 0.6)</p> <code>0.6</code> <code>dtype</code> <code>DTypeLike</code> <p>Data type for arrays (default: float32).</p> <code>float32</code> Source code in <code>felice/neuron_models/wererabbit.py</code> <pre><code>def __init__(\n self,\n *,\n atol: float = 1e-3,\n rtol: float = 1e-3,\n alpha: float = 0.0129,\n beta: float = 15.6,\n gamma: float = 0.26,\n rho: float = 5.0,\n sigma: float = 0.6,\n dtype: DTypeLike = jnp.float32,\n):\n r\"\"\"Initialize the WereRabbit neuron model.\n\n Args:\n rtol: Relative tolerance for the spiking fixpoint calculation.\n atol: Absolute tolerance for the spiking fixpoint calculation.\n alpha: Current scaling parameter $\\alpha = I_{n0}/I_{bias}$ (default: 0.0129)\n beta: Exponential slope $\\beta = \\kappa/U_t$ (default: 15.6)\n gamma: Coupling parameter $\\gamma$ (default: 0.26)\n rho: Steepness of the tanh function $\\rho$ (default: 5)\n sigma: Fixpoint distance scaling $\\sigma$ (default: 0.6)\n dtype: Data type for arrays (default: float32).\n \"\"\"\n self.dtype = dtype\n self.alpha = alpha\n self.beta = beta\n self.gamma = gamma\n self.rho = rho\n self.sigma = sigma\n\n self.rtol = rtol\n self.atol = atol\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.init_state","title":"<code>init_state(n_neurons: int) -> Float[Array, 'neurons 2']</code>","text":"<p>Initialize the neuron state variables.</p> <p>Parameters:</p> Name Type Description 
Default <code>n_neurons</code> <code>int</code> <p>Number of neurons to initialize.</p> required <p>Returns:</p> Type Description <code>Float[Array, 'neurons 2']</code> <p>Initial state array of shape (neurons, 2) containing [u, v], where u and v are the predator/prey membrane voltages.</p> Source code in <code>felice/neuron_models/wererabbit.py</code> <pre><code>def init_state(self, n_neurons: int) -> Float[Array, \"neurons 2\"]:\n \"\"\"Initialize the neuron state variables.\n\n Args:\n n_neurons: Number of neurons to initialize.\n\n Returns:\n Initial state array of shape (neurons, 2) containing [u, v],\n where u and v are the predator/prey membrane voltages.\n \"\"\"\n x1 = jnp.zeros((n_neurons,), dtype=self.dtype)\n x2 = jnp.zeros((n_neurons,), dtype=self.dtype)\n return jnp.stack([x1, x2], axis=1)\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.vector_field","title":"<code>vector_field(y: Float[Array, 'neurons 2']) -> Float[Array, 'neurons 2']</code>","text":"<p>Compute vector field of the neuron state variables.</p> <p>This implements the WereRabbit dynamics</p> <pre><code>- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n</code></pre> <p>Parameters:</p> Name Type Description Default <code>y</code> <code>Float[Array, 'neurons 2']</code> <p>State array of shape (neurons, 2) containing [u, v].</p> required <p>Returns:</p> Type Description <code>Float[Array, 'neurons 2']</code> <p>Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].</p> Source code in <code>felice/neuron_models/wererabbit.py</code> <pre><code>def vector_field(self, y: Float[Array, \"neurons 2\"]) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute vector field of the neuron 
state variables.\n\n This implements the WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n y: State array of shape (neurons, 2) containing [u, v].\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n u = y[:, 0]\n v = y[:, 1]\n\n z = jax.nn.tanh(self.rho * (u - v))\n du = (\n z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))\n - self.sigma\n )\n dv = (\n z\n * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))\n - self.sigma\n )\n\n dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)\n du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)\n\n return jnp.stack([du, dv], axis=1)\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.dynamics","title":"<code>dynamics(t: float, y: Float[Array, 'neurons 2'], args: Dict[str, Any]) -> Float[Array, 'neurons 2']</code>","text":"<p>Compute time derivatives of the neuron state variables.</p> <p>This implements the WereRabbit dynamics</p> <pre><code>- du/dt: Predator dynamics\n- dv/dt: WerePrey dynamics\n</code></pre> <p>Parameters:</p> Name Type Description Default <code>t</code> <code>float</code> <p>Current simulation time (unused but required by framework).</p> required <code>y</code> <code>Float[Array, 'neurons 2']</code> <p>State array of shape (neurons, 2) containing [u, v].</p> required <code>args</code> <code>Dict[str, Any]</code> <p>Additional arguments (unused but required by framework).</p> required <p>Returns:</p> Type Description <code>Float[Array, 'neurons 2']</code> <p>Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].</p> Source code in <code>felice/neuron_models/wererabbit.py</code> <pre><code>def dynamics(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n args: Dict[str, Any],\n) -> Float[Array, \"neurons 2\"]:\n \"\"\"Compute time derivatives of the neuron state variables.\n\n This implements the 
WereRabbit dynamics\n\n - du/dt: Predator dynamics\n - dv/dt: WerePrey dynamics\n\n Args:\n t: Current simulation time (unused but required by framework).\n y: State array of shape (neurons, 2) containing [u, v].\n args: Additional arguments (unused but required by framework).\n\n Returns:\n Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].\n \"\"\"\n dxdt = self.vector_field(y)\n\n return dxdt\n</code></pre>"},{"location":"api/neuron_models/#felice.neuron_models.WereRabbit.spike_condition","title":"<code>spike_condition(t: float, y: Float[Array, 'neurons 2'], **kwargs: Dict[str, Any]) -> Float[Array, ' neurons']</code>","text":"<p>Compute spike condition for event detection.</p> <p>A spike is triggered when the system reaches a fixpoint.</p> INFO <p><code>has_spiked</code> is used so that the system does not detect a continuous spike once it reaches a fixpoint.</p> <p>Parameters:</p> Name Type Description Default <code>t</code> <code>float</code> <p>Current simulation time (unused but required by the framework).</p> required <code>y</code> <code>Float[Array, 'neurons 2']</code> <p>State array of shape (neurons, 2) containing [u, v].</p> required <code>**kwargs</code> <code>Dict[str, Any]</code> <p>Additional keyword arguments (unused).</p> <code>{}</code> <p>Returns:</p> Type Description <code>Float[Array, ' neurons']</code> <p>Spike condition array of shape (neurons,). 
Positive values indicate a spike.</p> Source code in <code>felice/neuron_models/wererabbit.py</code> <pre><code>def spike_condition(\n self,\n t: float,\n y: Float[Array, \"neurons 2\"],\n **kwargs: Dict[str, Any],\n) -> Float[Array, \" neurons\"]:\n \"\"\"Compute spike condition for event detection.\n\n A spike is triggered when the system reaches a fixpoint.\n\n INFO:\n `has_spiked` is used so that the system does not detect a continuous\n spike once it reaches a fixpoint.\n\n Args:\n t: Current simulation time (unused but required by the framework).\n y: State array of shape (neurons, 2) containing [u, v].\n **kwargs: Additional keyword arguments (unused).\n\n Returns:\n Spike condition array of shape (neurons,). Positive values indicate a spike.\n \"\"\"\n _atol = self.atol\n _rtol = self.rtol\n _norm = optx.rms_norm\n\n vf = self.dynamics(t, y, {})\n\n @jax.vmap\n def calculate_norm(vf, y):\n return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])\n\n base_cond = calculate_norm(vf, y)\n\n return base_cond\n</code></pre>"},{"location":"api/solver/","title":"Solver","text":""},{"location":"api/solver/#felice.solver","title":"<code>felice.solver</code>","text":""},{"location":"api/solver/#felice.solver-classes","title":"Classes","text":""},{"location":"api/solver/#felice.solver.ClipSolver","title":"<code>ClipSolver</code>","text":"<p> Bases: <code>Module</code></p> Source code in <code>felice/solver.py</code> <pre><code>class ClipSolver(eqx.Module):\n solver: AbstractSolver\n\n def __getattr__(self, name):\n return getattr(self.solver, name)\n\n def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and 
controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. (Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n</code></pre>"},{"location":"api/solver/#felice.solver.ClipSolver-functions","title":"Functions","text":""},{"location":"api/solver/#felice.solver.ClipSolver.step","title":"<code>step(terms: PyTree[AbstractTerm], t0: RealScalarLike, t1: RealScalarLike, y0: Y, args: Args, solver_state: _SolverState, made_jump: BoolScalarLike) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]</code>","text":"<p>Make a single step of the solver.</p> <p>Each step is made over the specified interval \\([t_0, 
t_1]\\).</p> <p>Arguments:</p> <ul> <li><code>terms</code>: The PyTree of terms representing the vector fields and controls.</li> <li><code>t0</code>: The start of the interval that the step is made over.</li> <li><code>t1</code>: The end of the interval that the step is made over.</li> <li><code>y0</code>: The current value of the solution at <code>t0</code>.</li> <li><code>args</code>: Any extra arguments passed to the vector field.</li> <li><code>solver_state</code>: Any evolving state for the solver itself, at <code>t0</code>.</li> <li><code>made_jump</code>: Whether there was a discontinuity in the vector field at <code>t0</code>. Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there are no jumps and for efficiency re-use information between steps; this indicates that a jump has just occurred and this assumption is not true.</li> </ul> <p>Returns:</p> <p>A tuple of several objects:</p> <ul> <li>The value of the solution at <code>t1</code>.</li> <li>A local error estimate made during the step. (Used by adaptive step size controllers to change the step size.) May be <code>None</code> if no estimate was made.</li> <li>Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. 
(Used with <code>SaveAt(ts=...)</code> or <code>SaveAt(dense=...)</code>.)</li> <li>The value of the solver state at <code>t1</code>.</li> <li>An integer (corresponding to <code>diffrax.RESULTS</code>) indicating whether the step happened successfully, or if (unusually) it failed for some reason.</li> </ul> Source code in <code>felice/solver.py</code> <pre><code>def step(\n self,\n terms: PyTree[AbstractTerm],\n t0: RealScalarLike,\n t1: RealScalarLike,\n y0: Y,\n args: Args,\n solver_state: _SolverState,\n made_jump: BoolScalarLike,\n) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:\n \"\"\"Make a single step of the solver.\n\n Each step is made over the specified interval $[t_0, t_1]$.\n\n **Arguments:**\n\n - `terms`: The PyTree of terms representing the vector fields and controls.\n - `t0`: The start of the interval that the step is made over.\n - `t1`: The end of the interval that the step is made over.\n - `y0`: The current value of the solution at `t0`.\n - `args`: Any extra arguments passed to the vector field.\n - `solver_state`: Any evolving state for the solver itself, at `t0`.\n - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.\n Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there\n are no jumps and for efficiency re-use information between steps; this\n indicates that a jump has just occurred and this assumption is not true.\n\n **Returns:**\n\n A tuple of several objects:\n\n - The value of the solution at `t1`.\n - A local error estimate made during the step. (Used by adaptive step size\n controllers to change the step size.) May be `None` if no estimate was\n made.\n - Some dictionary of information that is passed to the solver's interpolation\n routine to calculate dense output. 
(Used with `SaveAt(ts=...)` or\n `SaveAt(dense=...)`.)\n - The value of the solver state at `t1`.\n - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step\n happened successfully, or if (unusually) it failed for some reason.\n \"\"\"\n y1, y_error, dense_info, solver_state, result = self.solver.step(\n terms, t0, t1, y0, args, solver_state, made_jump\n )\n y1_clipped = jax.tree_util.tree_map(jax.nn.relu, y1)\n return y1_clipped, y_error, dense_info, solver_state, result\n</code></pre>"},{"location":"neuron_models/","title":"Neuron Models","text":"<p>Felice implements several non-linear neuron models for spiking neural networks.</p>"},{"location":"neuron_models/#available-models","title":"Available Models","text":"Model Type Key Features WereRabbit Dual-state oscillatory Bistable dynamics, predator-prey FitzHugh-Nagumo ... ... Snowball Exponential Integrate-and-Fire neuron model ..."},{"location":"neuron_models/fhn/","title":"FitzHugh-Nagumo","text":""},{"location":"neuron_models/fhn/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\\\ \\frac{dv_{slow}}{dt} &= \\frac{v - v_{slow}}{\\tau_{slow}} \\\\ \\frac{dI_{app}}{dt} &= -\\frac{I_{app}}{\\tau_{syn}} \\end{align} \\] <p>where the currents are: - \\(I_{passive} = g_{max}(v - E_{rev})\\) - \\(I_{fast} = a_{fast} \\tanh(v - v_{off,fast})\\) - \\(I_{slow} = a_{slow} \\tanh(v_{slow} - v_{off,slow})\\)</p>"},{"location":"neuron_models/fhn/#examples","title":"Examples","text":"<p>See the following interactive notebook for a practical example:</p> <ul> <li>Basic Usage Example - Introduction to the FitzHugh-Nagumo model</li> </ul>"},{"location":"neuron_models/fhn/fhn/","title":"Example","text":"In\u00a0[\u00a0]: Copied! 
<pre>import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\nfrom eventpropjax.evnn import FFEvNN\n\nfrom felice.neuron_models import FHN\n</pre> import diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from eventpropjax.evnn import FFEvNN from felice.neuron_models import FHN In\u00a0[2]: Copied! <pre>key = jrand.key(0)\nmax_time = 200\n\nsnn = FFEvNN(\n layers=[1],\n in_size=1,\n neuron_model=FHN,\n solver=dfx.Dopri5(),\n max_solver_time=max_time,\n key=key,\n max_event_steps=1000000,\n solver_stepsize=0.1,\n init_weights=2.0,\n)\n</pre> key = jrand.key(0) max_time = 200 snn = FFEvNN( layers=[1], in_size=1, neuron_model=FHN, solver=dfx.Dopri5(), max_solver_time=max_time, key=key, max_event_steps=1000000, solver_stepsize=0.1, init_weights=2.0, ) In\u00a0[3]: Copied! <pre>v_range = jnp.arange(-3.1, 3, 0.1)\nVI_inst = jax.vmap(snn.neuron_model.IV_inst)(v_range)\nVI_fast = jax.vmap(snn.neuron_model.IV_fast)(v_range)\nVI_slow = jax.vmap(snn.neuron_model.IV_slow)(v_range)\n\nwith mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True)\n ax[0].plot(v_range, VI_inst)\n ax[1].plot(v_range, VI_fast)\n ax[2].plot(v_range, VI_slow)\n plt.show()\n</pre> v_range = jnp.arange(-3.1, 3, 0.1) VI_inst = jax.vmap(snn.neuron_model.IV_inst)(v_range) VI_fast = jax.vmap(snn.neuron_model.IV_fast)(v_range) VI_slow = jax.vmap(snn.neuron_model.IV_slow)(v_range) with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 3, figsize=(6.9, 2.3), dpi=200.0, sharey=True) ax[0].plot(v_range, VI_inst) ax[1].plot(v_range, VI_fast) ax[2].plot(v_range, VI_slow) plt.show() In\u00a0[4]: Copied! 
<pre>in_spikes = jnp.asarray([[0.00]])\ncomp_times = jnp.linspace(0.0, max_time, 500)\nstate = snn.state_at_t(in_spikes, comp_times)\nspikes = snn.spikes_until_t(in_spikes, max_time)\n</pre> in_spikes = jnp.asarray([[0.00]]) comp_times = jnp.linspace(0.0, max_time, 500) state = snn.state_at_t(in_spikes, comp_times) spikes = snn.spikes_until_t(in_spikes, max_time) In\u00a0[5]: Copied! <pre>with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[0, :, 0])\n ax[0].plot(comp_times, state[0, :, 1], \"--\")\n # ax[0].plot(comp_times, state[0, :, 2], \"-.\")\n [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)]\n ax[0].set_xlabel(\"Time (ms)\")\n ax[0].legend([\"v\", \"vslow\", \"Spike\"])\n\n ax[1].plot(state[0, :, 0], state[0, :, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\")\n ax[1].set_xlabel(\"v\")\n ax[1].set_ylabel(\"vslow\")\n ax[1].legend()\n plt.show()\n</pre> with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[0, :, 0]) ax[0].plot(comp_times, state[0, :, 1], \"--\") # ax[0].plot(comp_times, state[0, :, 2], \"-.\") [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)] ax[0].set_xlabel(\"Time (ms)\") ax[0].legend([\"v\", \"vslow\", \"Spike\"]) ax[1].plot(state[0, :, 0], state[0, :, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\") ax[1].set_xlabel(\"v\") ax[1].set_ylabel(\"vslow\") ax[1].legend() plt.show() In\u00a0[\u00a0]: Copied! 
<pre>\n</pre>"},{"location":"neuron_models/snowball/","title":"Snowball","text":""},{"location":"neuron_models/snowball/#circuit-description","title":"Circuit description","text":"<p>The exponential integrate-and-fire neuron circuit is taken from [1]; part (a) of Fig. 2 in [1] implements the neuron. The neuron receives its input current through the input DPI filter [2]. This current is integrated on the node Vmem by the membrane capacitance. In the absence of input spikes the membrane potential leaks, with the leak set by the bias Vleak. The Vmem node is connected to a cascoded source follower formed by P14-15 and N5-6. The bias Vthr sets the threshold voltage against which the membrane potential is compared. When the membrane potential approaches the threshold voltage, the positive feedback block is activated; it increases the membrane potential exponentially and causes the neuron to spike. When the neuron spikes, the membrane potential is reset to ground, and the refractory bias keeps the neuron from spiking during the refractory period, similar to a biological neuron. The circuit used in this experiment implements neither the adaptation nor the pulse extender of [1]. The supply voltage Vdd used in the simulation is 1 V. The neuron receives 5 nA input pulses with a pulse width of 100 \u03bcs.</p> <p>Input current mirror: W/L = 0.2. All other transistors: W/L = 4/3.</p>"},{"location":"neuron_models/snowball/#circuit-simulation","title":"Circuit Simulation","text":"<p> Fig. 1: Dynamics of the exponential integrate-and-fire neuron. The light blue signal shows the input spikes, the yellow signal the membrane potential, and the dark blue signal the output spikes of the neuron.</p>"},{"location":"neuron_models/snowball/#references","title":"References","text":"<ol> <li>Rubino, Arianna, Melika Payvand, and Giacomo Indiveri. 
\"Ultra-low power silicon neuron circuit for extreme-edge neuromorphic intelligence.\" 2019 26th IEEE International Conference on Electronics, Circuits and Systems (ICECS). IEEE, 2019.</li> <li>Bartolozzi, Chiara, Srinjoy Mitra, and Giacomo Indiveri. \"An ultra low power current-mode filter for neuromorphic systems and biomedical signal processing.\" 2006 IEEE Biomedical Circuits and Systems Conference. IEEE, 2006.</li> </ol>"},{"location":"neuron_models/wererabbit/","title":"WereRabbit","text":"<p>The WereRabbit neuron model is a coupled two-state oscillator that follows predator-prey dynamics, with a switch along the diagonal of the phase plane. The variable z in equation 1c represents the \u201cmoon phase\u201d: whenever the state crosses the diagonal (u = v), z changes sign and the rabbit (prey) becomes the predator.</p>"},{"location":"neuron_models/wererabbit/#circuit-equation","title":"Circuit equation","text":"\\[ \\begin{align} C\\frac{du}{dt} &= z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - I_a \\\\ C\\frac{dv}{dt} &= -z I_{bias} + I_{n0} e^{\\kappa u / U_t} [z + 26e^{-2} (0.5 - v) z] - I_a \\\\ z &= \\tanh(\\rho (u-v)) \\\\ I_a &= \\sigma I_{bias} \\end{align} \\] <table> <thead> <tr> <th>Parameter</th> <th>Symbol</th> <th>Definition</th> <th>Value</th> </tr> </thead> <tbody> <tr> <td>Capacitance</td> <td>\\(C\\)</td> <td>Circuit capacitance</td> <td>\\(0.1\\,pF\\)</td> </tr> <tr> <td>Bias current</td> <td>\\(I_{bias}\\)</td> <td>DC bias current for the fixed-point location</td> <td>\\(100\\,pA\\)</td> </tr> <tr> <td>Leakage current</td> <td>\\(I_{n0}\\)</td> <td>Transistor leakage current</td> <td>\\(0.129\\,pA\\)</td> </tr> <tr> <td>Subthreshold slope</td> <td>\\(\\kappa\\)</td> <td>Transistor subthreshold slope factor</td> <td>\\(0.39\\)</td> </tr> <tr> <td>Thermal voltage</td> <td>\\(U_t\\)</td> <td>Thermal voltage at room temperature</td> <td>\\(25\\,mV\\)</td> </tr> <tr> <td>Bias scale</td> <td>\\(\\sigma\\)</td> <td>Scaling factor for the distance between fixed points</td> <td>\\(0.6\\)</td> </tr> <tr> <td>Steepness</td> <td>\\(\\rho\\)</td> <td>Tanh steepness for the moon phase</td> <td>\\(5\\)</td> </tr> </tbody> </table>"},{"location":"neuron_models/wererabbit/#abstraction","title":"Abstraction","text":"<p>To simplify the analysis of the model for simulation purposes, we can introduce a dimensionless time variable \\(\\tau=tI_{bias}/C\\), which transforms the time derivative in the equations into 
\\(\\frac{d}{dt}=\\frac{I_{bias}}{C}\\frac{d}{d\\tau}\\). Substituting this time transformation into the circuit equation for \\(u\\), with \\(I_a = \\sigma I_{bias}\\), gives:</p> \\[ \\begin{equation} C\\frac{I_{bias}}{C}\\frac{du}{d\\tau} = z I_{bias} - I_{n0} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma I_{bias} \\end{equation} \\] <p>Dividing both sides by \\(I_{bias}\\):</p> \\[ \\begin{equation} \\frac{du}{d\\tau} = z - \\frac{I_{n0}}{I_{bias}} e^{\\kappa v / U_t} [z + 26e^{-2} (0.5 - u) z] - \\sigma \\end{equation} \\] <p>Applying the same steps to the equation for \\(v\\) and writing \\(\\alpha = I_{n0}/I_{bias}\\), \\(\\beta = \\kappa/U_t\\) and \\(\\gamma = 26e^{-2}\\), we obtain the following set of equations:</p> \\[ \\begin{align} z &= \\tanh(\\rho (u-v)) \\\\ \\frac{du}{d\\tau} &= z - z \\alpha e^{\\beta v} [1 + \\gamma (0.5 - u)] - \\sigma \\\\ \\frac{dv}{d\\tau} &= -z + z \\alpha e^{\\beta u} [1 + \\gamma (0.5 - v)] - \\sigma \\end{align} \\] <table> <thead> <tr> <th>Parameter</th> <th>Definition</th> <th>Value</th> </tr> </thead> <tbody> <tr> <td>\\(\\tau\\)</td> <td>\\(tI_{bias}/C\\)</td> <td>--</td> </tr> <tr> <td>\\(\\alpha\\)</td> <td>\\(I_{n0}/I_{bias}\\)</td> <td>\\(0.0129\\)</td> </tr> <tr> <td>\\(\\beta\\)</td> <td>\\(\\kappa/U_t\\)</td> <td>15.6</td> </tr> <tr> <td>\\(\\gamma\\)</td> <td>--</td> <td>\\(26e^{-2}\\)</td> </tr> <tr> <td>\\(\\rho\\)</td> <td>Tanh steepness for the moon phase</td> <td>5</td> </tr> <tr> <td>\\(\\sigma\\)</td> <td>Scaling factor for the distance between fixed points</td> <td>0.6</td> </tr> </tbody> </table>"},{"location":"neuron_models/wererabbit/#examples","title":"Examples","text":"<p>See the following interactive notebook for a practical example:</p> <ul> <li>Basic Usage Example - Introduction to the WereRabbit model</li> </ul>"},{"location":"neuron_models/wererabbit/wererabbit/","title":"Basic example","text":"In\u00a0[1]: Copied! <pre>import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrand\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\nfrom eventpropjax.evnn import FFEvNN\n\nfrom felice.neuron_models import WereRabbit\n\njax.config.update(\"jax_enable_x64\", True)\n</pre> import diffrax as dfx import jax import jax.numpy as jnp import jax.random as jrand import matplotlib as mpl import matplotlib.pyplot as plt from eventpropjax.evnn import FFEvNN from felice.neuron_models import WereRabbit jax.config.update(\"jax_enable_x64\", True) In\u00a0[2]: Copied! 
<pre>key = jrand.key(0)\nmax_time = 100\n\ninit_weights = jnp.asarray([[0.0], [0.2], [-0.1]])\nsnn = FFEvNN(\n layers=[1],\n in_size=2,\n neuron_model=WereRabbit,\n solver=dfx.Tsit5(),\n # stepsize_controller=PIDController(\n # rtol=1e-6, atol=1e-3, pcoeff=0.1, icoeff=0.3, dcoeff=0.0\n # ),\n max_solver_time=max_time,\n key=key,\n max_event_steps=10000,\n solver_stepsize=0.001,\n init_weights=init_weights,\n dtype=jnp.float64,\n)\n</pre> key = jrand.key(0) max_time = 100 init_weights = jnp.asarray([[0.0], [0.2], [-0.1]]) snn = FFEvNN( layers=[1], in_size=2, neuron_model=WereRabbit, solver=dfx.Tsit5(), # stepsize_controller=PIDController( # rtol=1e-6, atol=1e-3, pcoeff=0.1, icoeff=0.3, dcoeff=0.0 # ), max_solver_time=max_time, key=key, max_event_steps=10000, solver_stepsize=0.001, init_weights=init_weights, dtype=jnp.float64, ) In\u00a0[3]: Copied! <pre>in_spikes = jnp.asarray([[10, 30], [20, 21]])\ncomp_times = jnp.linspace(0.0, max_time, 2000)\nstate = snn.state_at_t(in_spikes, comp_times)\nspikes = snn.spikes_until_t(in_spikes, max_time)\n</pre> in_spikes = jnp.asarray([[10, 30], [20, 21]]) comp_times = jnp.linspace(0.0, max_time, 2000) state = snn.state_at_t(in_spikes, comp_times) spikes = snn.spikes_until_t(in_spikes, max_time) In\u00a0[4]: Copied! 
<pre>def compute_nullclines(snn, u_range, v_range, resolution=200):\n \"\"\"\n Compute nullclines\n du/dt = 0 (u-nullcline)\n dv/dt = 0 (v-nullcline)\n \"\"\"\n u_vals = jnp.linspace(u_range[0], u_range[1], resolution)\n v_vals = jnp.linspace(v_range[0], v_range[1], resolution)\n U, V = jnp.meshgrid(u_vals, v_vals)\n\n UV = jnp.stack(\n [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1\n )\n dS = snn.neuron_model.vector_field(UV)\n dU = dS[:, 0].reshape(U.shape)\n dV = dS[:, 1].reshape(V.shape)\n\n return U, V, dU, dV\n\n\ndef plot_vf(ax, snn, u_range, v_range):\n import numpy as np\n\n u_sparse = jnp.linspace(u_range[0], u_range[1], 20)\n v_sparse = jnp.linspace(v_range[0], v_range[1], 20)\n\n Us, Vs = jnp.meshgrid(u_sparse, v_sparse)\n\n U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200)\n\n UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1)\n dS = snn.neuron_model.vector_field(UVs)\n dUs = dS[:, 0].reshape(Us.shape)\n dVs = dS[:, 1].reshape(Vs.shape)\n\n # Normalize for visualization\n magnitude = np.sqrt(dUs**2 + dVs**2)\n magnitude[magnitude == 0] = 1\n dUs_norm = dUs / magnitude\n dVs_norm = dVs / magnitude\n\n # Nullclines\n ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\")\n ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\")\n\n # Vector field\n ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6)\n\n ax.set_xlabel(\"u (Prey)\")\n ax.set_ylabel(\"v (Predator)\")\n ax.set_title(\"Wererabbit: Phase Portrait\")\n ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\")\n ax.set_xlim(u_range)\n ax.set_ylim(v_range)\n ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3)\n</pre> def compute_nullclines(snn, u_range, v_range, resolution=200): \"\"\" Compute nullclines du/dt = 0 (u-nullcline) dv/dt = 0 (v-nullcline) 
\"\"\" u_vals = jnp.linspace(u_range[0], u_range[1], resolution) v_vals = jnp.linspace(v_range[0], v_range[1], resolution) U, V = jnp.meshgrid(u_vals, v_vals) UV = jnp.stack( [U.reshape(-1), V.reshape(-1), jnp.ones((resolution * resolution,))], axis=1 ) dS = snn.neuron_model.vector_field(UV) dU = dS[:, 0].reshape(U.shape) dV = dS[:, 1].reshape(V.shape) return U, V, dU, dV def plot_vf(ax, snn, u_range, v_range): import numpy as np u_sparse = jnp.linspace(u_range[0], u_range[1], 20) v_sparse = jnp.linspace(v_range[0], v_range[1], 20) Us, Vs = jnp.meshgrid(u_sparse, v_sparse) U, V, dU, dV = compute_nullclines(snn, u_range, v_range, 200) UVs = jnp.stack([Us.reshape(-1), Vs.reshape(-1), jnp.ones((20 * 20,))], axis=1) dS = snn.neuron_model.vector_field(UVs) dUs = dS[:, 0].reshape(Us.shape) dVs = dS[:, 1].reshape(Vs.shape) # Normalize for visualization magnitude = np.sqrt(dUs**2 + dVs**2) magnitude[magnitude == 0] = 1 dUs_norm = dUs / magnitude dVs_norm = dVs / magnitude # Nullclines ax.contour(U, V, dU, levels=[0], colors=\"blue\", linewidths=1, linestyles=\"-\") ax.contour(U, V, dV, levels=[0], colors=\"red\", linewidths=1, linestyles=\"-\") # Vector field ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap=\"viridis\", alpha=0.6) ax.set_xlabel(\"u (Prey)\") ax.set_ylabel(\"v (Predator)\") ax.set_title(\"Wererabbit: Phase Portrait\") ax.legend([\"u-nullcline (du/dt=0)\", \"v-nullcline (dv/dt=0)\"], loc=\"upper right\") ax.set_xlim(u_range) ax.set_ylim(v_range) ax.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.3) ax.axvline(x=0, color=\"gray\", linestyle=\"--\", alpha=0.3) In\u00a0[5]: Copied! 
<pre>with mpl.style.context(\"boilerplot.ieeetran\"):\n fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200)\n ax[0].plot(comp_times, state[0, :, 0], label=\"x1\")\n ax[0].plot(comp_times, state[0, :, 1], label=\"x2\")\n [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)]\n ax[0].legend([\"x1\", \"x2\", \"Spike\"])\n\n plot_vf(ax[1], snn, [-0.2, 0.5], [-0.2, 0.5])\n ax[1].plot(state[0, :, 0], state[0, :, 1])\n ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\")\n ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\")\n ax[1].legend()\n plt.show()\n</pre> with mpl.style.context(\"boilerplot.ieeetran\"): fig, ax = plt.subplots(1, 2, figsize=(6.9, 2.6), dpi=200) ax[0].plot(comp_times, state[0, :, 0], label=\"x1\") ax[0].plot(comp_times, state[0, :, 1], label=\"x2\") [ax[0].axvline(s, alpha=0.2, color=\"g\", linestyle=\"--\") for s in jnp.unique(spikes)] ax[0].legend([\"x1\", \"x2\", \"Spike\"]) plot_vf(ax[1], snn, [-0.2, 0.5], [-0.2, 0.5]) ax[1].plot(state[0, :, 0], state[0, :, 1]) ax[1].plot(state[0, 0, 0], state[0, 0, 1], \".\", label=\"start\") ax[1].plot(state[0, -1, 0], state[0, -1, 1], \".\", label=\"end\") ax[1].legend() plt.show() In\u00a0[6]: Copied! 
<pre>import equinox as eqx\n\n\ndef loss_fn(model, spike_times, comp_times):\n out_states = model.state_at_t(spike_times, comp_times)\n logits = out_states[:, :, 0]\n\n return jnp.sum(logits)\n\n\nin_spikes = jnp.asarray([[0.01], [0.157]])\ncomp_times = jnp.linspace(0.0, max_time, 10)\n\nloss, gradients = eqx.filter_value_and_grad(loss_fn)(snn, in_spikes, comp_times)\nprint(\"Loss \", loss)\nprint(\"Gradients \", gradients.neuron_model.weight_u)\n</pre> import equinox as eqx def loss_fn(model, spike_times, comp_times): out_states = model.state_at_t(spike_times, comp_times) logits = out_states[:, :, 0] return jnp.sum(logits) in_spikes = jnp.asarray([[0.01], [0.157]]) comp_times = jnp.linspace(0.0, max_time, 10) loss, gradients = eqx.filter_value_and_grad(loss_fn)(snn, in_spikes, comp_times) print(\"Loss \", loss) print(\"Gradients \", gradients.neuron_model.weight_u) <pre>Loss 2.78585239490824\nGradients [[ 0. ]\n [-1.81404788]\n [-1.42144198]]\n</pre> In\u00a0[\u00a0]: Copied! <pre>\n</pre>"}]} |