# Felice

This project provides JAX implementations of several spiking neuron models.
## Overview

The framework is built on top of EventPropJax and leverages JAX's automatic differentiation for efficient simulation and training of spiking neural networks (SNNs) using exact, event-based gradients.
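As a minimal, standalone illustration of the automatic differentiation the framework relies on (this toy function is not part of the Felice API), `jax.grad` turns a scalar function into its derivative:

```python
import jax
import jax.numpy as jnp

# Toy exponential membrane decay: v(t) = exp(-t / tau)
def decay(t, tau=20e-3):
    return jnp.exp(-t / tau)

# jax.grad returns a function computing d(decay)/dt
ddecay_dt = jax.grad(decay)

# At t = 10 ms the analytic derivative is -exp(-0.5) / tau
g = ddecay_dt(10e-3)
```

EventPropJax builds on the same mechanism, differentiating through spike event times rather than a fixed time grid.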
## Key Features
- Delay learning
- Non-linear neuron models
- WereRabbit Neuron Model: a dual-state oscillatory neuron model with bistable dynamics
- FHN (FitzHugh–Nagumo) Neuron Model
- Snowball Neuron Model
## 📦 Installation

Felice uses uv for dependency management. To install:
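A typical uv workflow looks like the following (this sketch assumes a standard `pyproject.toml` at the repository root; the exact command for this project may differ):

```shell
# Clone the repository, then resolve and install all dependencies
# into a project-local virtual environment.
uv sync
```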
### CUDA Support (Optional)

For GPU acceleration with CUDA 13:
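One plausible sketch, assuming the project exposes a CUDA optional-dependency group (the extra name `cuda` is an assumption, not confirmed by this README; check `pyproject.toml` for the actual name):

```shell
# Hypothetical: install the project with its CUDA extra enabled.
uv sync --extra cuda
```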
## Quick Start

Here's a simple example using the WereRabbit neuron model:
```python
import diffrax as dfx
import jax.numpy as jnp
import jax.random as jrand

from eventpropjax.evnn import FFEvNN
from felice.neuron_models import WereRabbit

# Initialize random key and simulation horizon
key = jrand.key(0)
max_time = 300e-3

# Create a feedforward event-driven neural network
snn = FFEvNN(
    layers=[1],
    in_size=2,
    neuron_model=WereRabbit,
    solver=dfx.Tsit5(),
    max_solver_time=max_time,
    key=key,
    max_event_steps=1_000_000,
    solver_stepsize=1e-6,
    rtol=10.0,
    atol=0.0,
    Ibias=30e-12,
)

# Simulate with input spikes (one spike time per input channel)
in_spikes = jnp.asarray([[0.0], [0.157]])
spikes = snn.spikes_until_t(in_spikes, max_time)
```
See the `examples` directory for more detailed usage examples.