Initial commit

Co-authored-by: Aradhana Dube <a.dube@rug.nl>
Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl>
Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl>
Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
This commit is contained in:
2026-02-26 18:30:32 +01:00
commit 9fabbdefc0
75 changed files with 447515 additions and 0 deletions

0
src/felice/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,172 @@
from io import BytesIO
from pathlib import Path
from typing import Optional, Sequence
from urllib.request import urlopen
from zipfile import ZipFile
import jax.numpy as jnp
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Braille class labels in dataset order. A sample's integer label is its
# index in this list (so "Space" == 0, "A" == 1, ...) unless the loader is
# given a `custom_letters` subset, in which case indices come from that
# subset instead.
letters = [
    "Space",
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
]
def load(
    thr: int = 1,
    taxels: Optional[Sequence[int]] = None,
    custom_letters: Optional[Sequence[str]] = None,
):
    """Load the event-based Braille letter dataset (Zenodo record 7050094).

    Downloads and extracts the pickle on first call (network I/O), converts
    the per-taxel On/Off event lists into a dense inf-padded array, and
    returns a stratified 70/20/10-ish train/test/validation split.

    Args:
        thr: Event-threshold variant of the dataset file to load.
        taxels: Optional subset of taxel indices to keep; ``None`` keeps all.
        custom_letters: Optional subset of labels to keep. When given, only
            matching samples are kept and labels are re-indexed into this
            sequence instead of the module-level ``letters`` list.

    Returns:
        ``(trainset, testset, valset, selected_chans, nb_outputs)`` where each
        set is an ``(x, y)`` tuple of JAX arrays.
    """
    file_name = Path(f"./braille/data_braille_letters_th_{thr}.pkl")
    # Fetch the archive from Zenodo on first use and extract only the
    # requested threshold variant next to the working directory.
    if not file_name.exists():
        resp = urlopen(
            "https://zenodo.org/records/7050094/files/reading_braille_data.zip"
        )
        with ZipFile(BytesIO(resp.read())) as zObject:
            zObject.extract(f"data_braille_letters_th_{thr}.pkl", path="./braille")
    data_dict = pd.read_pickle(file_name)
    # Extract data
    data_list: list[np.ndarray] = []
    label_list: list[int] = []
    nchan = len(data_dict["events"][1]) # number of channels per sensor
    # First pass: find the longest event list so every sample can be padded
    # to a common length.
    max_events = 0
    for i, sample in enumerate(data_dict["events"]):
        for taxel in range(len(sample)):
            for event_type in range(len(sample[taxel])):
                events = sample[taxel][event_type]
                max_events = len(events) if len(events) > max_events else max_events
    # Second pass: build one (nchan, max_events, 2) array per sample,
    # padded with inf where no event occurred, then flatten the
    # (taxel, On/Off) axes into a single channel axis.
    for i, sample in enumerate(data_dict["events"]):
        events_array = np.full([nchan, max_events, 2], np.inf)
        for taxel in range(len(sample)):
            # loop over On and Off channels
            for event_type in range(len(sample[taxel])):
                events = sample[taxel][event_type]
                if events:
                    events_array[taxel, : len(events), event_type] = events # ms
        if taxels is not None:
            events_array = np.reshape(
                np.transpose(events_array, (0, 2, 1))[taxels, :, :],
                (-1, events_array.shape[1]),
            )
            selected_chans = 2 * len(taxels)
        else:
            events_array = np.reshape(
                np.transpose(events_array, (0, 2, 1)), (-1, events_array.shape[1])
            )
            selected_chans = 2 * nchan
        lbl = data_dict["letter"][i]
        if custom_letters is not None:
            if lbl in custom_letters:
                data_list.append(events_array)
                label_list.append(custom_letters.index(lbl))
        else:
            data_list.append(events_array)
            label_list.append(letters.index(lbl))
    # NOTE(review): event times were annotated as "ms" above yet are scaled
    # by 100 here — confirm the intended time unit. Also, np.stack raises if
    # custom_letters filtered out every sample.
    data = np.stack(data_list) * 100 # To ms
    labels = np.stack(label_list)
    nb_outputs = len(np.unique(labels))
    # Stratified split: 70% train, then the remaining 30% is split again
    # into ~20% test / ~10% validation of the full set.
    x_train, x_test, y_train, y_test = train_test_split(
        data, labels, test_size=0.30, shuffle=True, stratify=labels
    )
    x_test, x_validation, y_test, y_validation = train_test_split(
        x_test, y_test, test_size=0.33, shuffle=True, stratify=y_test
    )
    trainset = (jnp.asarray(x_train), jnp.asarray(y_train))
    testset = (jnp.asarray(x_test), jnp.asarray(y_test))
    valset = (jnp.asarray(x_validation), jnp.asarray(y_validation))
    return trainset, testset, valset, selected_chans, nb_outputs
def load_raw(
    taxels: Optional[Sequence[int]] = None,
    custom_letters: Optional[Sequence[str]] = None,
):
    """Load the raw (analog taxel traces) Braille dataset from Zenodo.

    Unlike :func:`load`, samples here are dense per-timestep taxel readings
    rather than event lists, so no padding or event extraction is needed.

    Args:
        taxels: Optional subset of taxel (channel) indices to keep.
        custom_letters: Optional subset of labels to keep; labels are then
            re-indexed into this sequence instead of the module-level
            ``letters`` list.

    Returns:
        ``(trainset, testset, valset, selected_chans, nb_outputs)`` where each
        set is an ``(x, y)`` tuple of JAX arrays.
    """
    file_name = Path("./braille/data_braille_letters_digits.pkl")
    # Download the archive on first use (network I/O).
    if not file_name.exists():
        resp = urlopen(
            "https://zenodo.org/records/7050094/files/reading_braille_data.zip"
        )
        with ZipFile(BytesIO(resp.read())) as zObject:
            zObject.extract("data_braille_letters_digits.pkl", path="./braille")
    data_dict = pd.read_pickle(file_name)
    # Extract data
    data_list: list[np.ndarray] = []
    label_list: list[int] = []
    nchan = data_dict["taxel_data"][0].shape[1] # number of channels per sensor
    for i, sample in enumerate(data_dict["taxel_data"]):
        if taxels is not None:
            sample = sample[:, taxels]
            selected_chans = len(taxels)
        else:
            selected_chans = nchan
        lbl = data_dict["letter"][i]
        if custom_letters is not None:
            if lbl in custom_letters:
                data_list.append(sample)
                label_list.append(custom_letters.index(lbl))
        else:
            data_list.append(sample)
            label_list.append(letters.index(lbl))
    # NOTE(review): the "To ms" comment appears to be a leftover from load();
    # no unit conversion happens here. np.stack also requires all samples to
    # share the same length — presumably guaranteed by the dataset; verify.
    data = np.stack(data_list) # To ms
    labels = np.stack(label_list)
    nb_outputs = len(np.unique(labels))
    # Stratified split: 70% train, ~20% test, ~10% validation.
    x_train, x_test, y_train, y_test = train_test_split(
        data, labels, test_size=0.30, shuffle=True, stratify=labels
    )
    x_test, x_validation, y_test, y_validation = train_test_split(
        x_test, y_test, test_size=0.33, shuffle=True, stratify=y_test
    )
    trainset = (jnp.asarray(x_train), jnp.asarray(y_train))
    testset = (jnp.asarray(x_test), jnp.asarray(y_test))
    valset = (jnp.asarray(x_validation), jnp.asarray(y_validation))
    return trainset, testset, valset, selected_chans, nb_outputs

View File

@@ -0,0 +1,125 @@
from typing import Tuple
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
class ReasoningDataset:
    """Synthetic token-sequence reasoning tasks for sequence models.

    Task types:
    1. Simple comparison: IF A > B THEN x ELSE y
    2. Accumulated conditions: SET A, ADD B, IF SUM > threshold THEN x ELSE y

    Numbers are encoded into the same vocabulary as the control tokens by
    shifting them with ``NUM_OFFSET`` (token ``NUM_OFFSET + v`` means the
    number ``v``). Targets are defined only at the QUERY position; ``mask``
    marks which positions contribute to the loss.
    """

    # Token vocabulary
    VOCAB = {
        "PAD": 0,
        "SET_A": 1,
        "SET_B": 2,
        "SET_C": 3,
        "IF_A>B": 4,
        "IF_A<B": 5,
        "IF_SUM>": 6,
        "THEN": 7,
        "ELSE": 8,
        "QUERY": 9,
        "ADD": 10,
    }
    NUM_OFFSET: int = 11
    VOCAB_SIZE: int = 31
    NUM_OUTPUT: int = 16

    def generate_simple_comparison(
        self, key: "PRNGKeyArray"
    ) -> "Tuple[Array, Array, Array]":
        """Generate one simple-comparison example.

        Sequence: [SET_A, a, SET_B, b, IF_A>B, THEN, x, ELSE, y, QUERY]
        Target at the QUERY position: x if a > b else y.

        Returns:
            (input_seq, target, mask), each of shape (10,).
            (Fix: the return annotation previously declared only two arrays.)
        """
        keys = jax.random.split(key, 4)
        # NOTE(review): jax.random.randint's maxval is exclusive, so these
        # draws span 0..NUM_OUTPUT-2 — confirm whether NUM_OUTPUT-1 was
        # meant to be reachable.
        a = jax.random.randint(keys[0], (), 0, self.NUM_OUTPUT - 1)
        b = jax.random.randint(keys[1], (), 0, self.NUM_OUTPUT - 1)
        x = jax.random.randint(keys[2], (), 0, self.NUM_OUTPUT - 1)
        y = jax.random.randint(keys[3], (), 0, self.NUM_OUTPUT - 1)
        # TODO: Generate other sequences
        input_seq = jnp.array(
            [
                self.VOCAB["SET_A"],
                self.NUM_OFFSET + a,
                self.VOCAB["SET_B"],
                self.NUM_OFFSET + b,
                self.VOCAB["IF_A>B"],
                self.VOCAB["THEN"],
                self.NUM_OFFSET + x,
                self.VOCAB["ELSE"],
                self.NUM_OFFSET + y,
                self.VOCAB["QUERY"],
            ]
        )
        result = jnp.where(a > b, x, y)
        target = jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, result])
        mask = jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
        return input_seq, target, mask

    def generate_accumulation_condition(
        self, key: "PRNGKeyArray"
    ) -> "Tuple[Array, Array, Array]":
        """Generate one accumulation example.

        Sequence: [SET_A, a, ADD, b, ADD, c, IF_SUM>, threshold, THEN, x,
        ELSE, y, QUERY]. Solving it requires accumulating a + b + c and
        comparing the sum to the threshold.

        Returns:
            (input_seq, target, mask), each of shape (13,).
        """
        keys = jax.random.split(key, 6)
        a = jax.random.randint(keys[0], (), 0, 10)
        b = jax.random.randint(keys[1], (), 0, 10)
        c = jax.random.randint(keys[2], (), 0, 10)
        threshold = jax.random.randint(keys[3], (), 5, 25)
        x = jax.random.randint(keys[4], (), 0, 15)
        y = jax.random.randint(keys[5], (), 0, 15)
        # TODO: Generate other sequences
        input_seq = jnp.array(
            [
                self.VOCAB["SET_A"],
                self.NUM_OFFSET + a,
                self.VOCAB["ADD"],
                self.NUM_OFFSET + b,
                self.VOCAB["ADD"],
                self.NUM_OFFSET + c,
                self.VOCAB["IF_SUM>"],
                self.NUM_OFFSET + threshold,
                self.VOCAB["THEN"],
                self.NUM_OFFSET + x,
                self.VOCAB["ELSE"],
                self.NUM_OFFSET + y,
                self.VOCAB["QUERY"],
            ]
        )
        total = a + b + c
        result = jnp.where(total > threshold, x, y)
        target = jnp.zeros(13, dtype=jnp.int32).at[-1].set(result)
        mask = jnp.zeros(13, dtype=jnp.int32).at[-1].set(1)
        return input_seq, target, mask

    def generate_batch(
        self, key: "PRNGKeyArray", batch_size: int, task_type: str = "simple"
    ) -> "Tuple[Array, Array, Array]":
        """Generate a batch of examples, stacked along a leading batch axis.

        Args:
            key: PRNG key; split into one sub-key per example.
            batch_size: Number of examples to generate.
            task_type: "simple" for comparisons, anything else for
                accumulation tasks.
        """
        keys = jax.random.split(key, batch_size)
        if task_type == "simple":
            gen_fn = self.generate_simple_comparison
        else:  # accumulation
            gen_fn = self.generate_accumulation_condition
        inputs, targets, masks = [], [], []
        for k in keys:
            inp, tgt, msk = gen_fn(k)
            inputs.append(inp)
            targets.append(tgt)
            masks.append(msk)
        return jnp.stack(inputs), jnp.stack(targets), jnp.stack(masks)

View File

@@ -0,0 +1,45 @@
from typing import Generic, Protocol, TypeVar
import equinox as eqx
import jax
import jax.random as jrandom
from jaxtyping import Array, Float, PRNGKeyArray
from .implicit import Boomerang, Implicit
from .ssm import Mamba
# Public API of this subpackage.
__all__ = ["Mamba", "Implicit", "Boomerang", "SequenceClassifier"]
# Type variable standing in for the wrapped sequence-model class.
T = TypeVar("T")
class HasSequential(Protocol):
    """Structural type for models exposing a token-by-token `sequential` pass."""
    def sequential(self, x): ...
class SequenceClassifier(eqx.Module, Generic[T]):
    """Token embedding followed by a user-supplied sequence model.

    ``model_class`` is instantiated as
    ``model_class(d_model=d_model, key=..., **model_kwargs)``, so any model in
    this package (Mamba, Implicit, Boomerang, ...) can be plugged in.
    """

    embedding: eqx.nn.Embedding
    model: eqx.Module

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        model_class: T,
        key: PRNGKeyArray,
        **model_kwargs,
    ):
        """Build the embedding and the wrapped model.

        Args:
            vocab_size: Number of tokens in the embedding table.
            d_model: Embedding / model width.
            model_class: Class to instantiate as the sequence model.
            key: PRNG key, split between embedding and model init.
                (Fix: was ``key=PRNGKeyArray`` — the *type object* as a
                default value; it is now a required, annotated argument.)
            **model_kwargs: Extra keyword arguments forwarded to
                ``model_class``.
        """
        keys = jrandom.split(key, 2)
        self.embedding = eqx.nn.Embedding(vocab_size, d_model, key=keys[0])
        self.model = model_class(d_model=d_model, key=keys[1], **model_kwargs)

    def __call__(self, input_ids: Float[Array, " seq"]) -> Float[Array, "seq d_model"]:
        """Embed the token ids and run the wrapped model's parallel pass."""
        x = jax.vmap(self.embedding)(input_ids)
        y = self.model(x)
        return y

    def sequential(
        self: "SequenceClassifier[HasSequential]", input_ids: Float[Array, " seq"]
    ) -> Float[Array, "seq d_model"]:
        """Embed the token ids and run the wrapped model's sequential pass."""
        x = jax.vmap(self.embedding)(input_ids)
        return self.model.sequential(x)

View File

@@ -0,0 +1,4 @@
from .base import Implicit
from .boomerang import Boomerang
__all__ = ["Implicit", "Boomerang"]

View File

@@ -0,0 +1,210 @@
from typing import Callable, Optional
import diffrax as dfx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optimistix as optx
from diffrax._custom_types import RealScalarLike
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
from ..utils import binary_op
def _depth_step(_, z_prev, args):
    """Depth-direction vector field: one refinement of z for the whole sequence.

    Used as a diffrax ODETerm; the first argument is the (unused) depth/time
    variable. `z_prev` has shape (seq, d_inner); `args` is `(x, model)`.
    """
    x, model = args
    # ── Eq (6): h_t^{(s)} = Λ(z_t^{(s-1)}, x_t) * h_{t-1}^{(s)} + u(z_t^{(s-1)}, x_t)
    lambda_vals = jax.vmap(model.compute_lambda)(z_prev, x)
    u_vals = jax.vmap(model.compute_u)(z_prev, x)
    # Diagonal linear recurrence over the sequence, evaluated in parallel.
    _, h = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)
    # ── Eq (7): z_t^{(s)} = f_θ(z_t^{(s-1)}, h_{t-1}^{(s)}, x_t)
    # z depends on h_{t-1}^{(s)}, i.e. the hidden state of the
    # PREVIOUS token h_{t-1}^{(s)}, NOT the just-computed h_t^{(s)}.
    h0 = jnp.zeros((1, model.d_state))
    # Shift right so position t sees h_{t-1}; the first token sees h = 0.
    h = jnp.concatenate([h0, h[:-1]], axis=0)
    dz = jax.vmap(model.f_theta)(z_prev, h, x)
    return dz
# Callable mapping a PyTree of arrays to a scalar, used as a convergence norm.
Normalizer = Callable[[PyTree[Array]], RealScalarLike]
class Implicit(eqx.Module):
    """Implicit ("infinite-depth") sequence model solved as an ODE in depth.

    The latent ``z`` (one vector per token) is integrated along a virtual
    depth axis with diffrax until a steady state, while the hidden state
    ``h`` follows a diagonal linear recurrence across the sequence computed
    with an associative scan (see ``_depth_step``).
    """

    d_model: int = eqx.field(static=True)
    d_state: int = eqx.field(static=True)
    d_inner: int = eqx.field(static=True)
    dt: float = eqx.field(static=True)
    max_time: float = eqx.field(static=True)
    max_iters: int | None = eqx.field(static=True)
    rtol: float = eqx.field(static=True)
    atol: float = eqx.field(static=True)
    norm: Normalizer = eqx.field(static=True)
    adjoint: dfx.AbstractAdjoint
    solver: dfx.AbstractSolver
    f_net: eqx.nn.Linear
    lambda_net: eqx.nn.Linear  # Λ: maps (z, x) → decay factor (diagonal)
    u_net: eqx.nn.Linear  # u: maps (z, x) → input state
    out_net: eqx.nn.Linear  # Output: maps (z, h) → y

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        dt: float = 1.0,
        max_iters: int | None = 100,
        max_time: float = 100,
        solver: Optional[dfx.AbstractSolver] = None,
        adjoint: Optional[dfx.AbstractAdjoint] = None,
        rtol: float = 1e-4,
        atol: float = 1e-3,
        norm: Normalizer = optx.rms_norm,
        *,
        key: PRNGKeyArray,
    ):
        """Initialize solver configuration and the four linear sub-networks.

        Args:
            d_model: Input/output feature width.
            d_state: Hidden-state (h) width.
            d_inner: Latent (z) width.
            dt: Initial depth step size for the ODE solver.
            max_iters: Max solver steps (``None`` for unbounded).
            max_time: Depth horizon to integrate to.
            solver: diffrax solver; defaults to Tsit5.
            adjoint: diffrax adjoint; defaults to ImplicitAdjoint.
            rtol/atol/norm: Steady-state detection tolerances and norm.
            key: PRNG key for parameter initialization.
        """
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_inner
        self.dt = dt
        self.max_time = max_time
        self.max_iters = max_iters
        self.solver = solver if solver is not None else dfx.Tsit5()
        self.adjoint = adjoint if adjoint is not None else dfx.ImplicitAdjoint()
        self.rtol = rtol
        self.atol = atol
        self.norm = norm
        keys = jrandom.split(key, 4)
        self.f_net = eqx.nn.Linear(
            d_inner + d_state + d_model,
            d_inner,
            key=keys[0],
        )
        self.lambda_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[1])
        self.u_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[2])
        self.out_net = eqx.nn.Linear(d_inner, d_model, key=keys[3])

    def compute_lambda(
        self, z: Float[Array, " d_inner"], x: Float[Array, " d_model"]
    ) -> Float[Array, " d_state"]:
        """Λ(z, x): per-dimension decay/retention factor for h."""
        zx = jnp.concatenate([z, x], axis=-1)
        # Sigmoid to keep in (0, 1) for stability
        return jax.nn.sigmoid(self.lambda_net(zx))

    def compute_u(
        self, z: Float[Array, " d_inner"], x: Float[Array, " d_model"]
    ) -> Float[Array, " d_state"]:
        """u(z, x): input contribution to h."""
        zx = jnp.concatenate([z, x], axis=-1)
        return self.u_net(zx)

    def f_theta(
        self,
        z: Float[Array, " d_inner"],
        h: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_inner"]:
        """f_θ(z, h, x) → dz - the implicit function.

        Returns a residual (target minus current z) so that the ODE's steady
        state is the fixed point z* = silu(f_net([z*, h, x])).
        """
        zhx = jnp.concatenate([z, h, x])
        dz = jax.nn.silu(self.f_net(zhx)) - z
        return dz

    def get_z(
        self,
        x: Float[Array, "seq d_model"],
        t0: float = 0,
        t1: float | None = None,
        num_points: int = 100,
    ) -> tuple[Array, Array]:
        """Integrate the depth ODE and return its trajectory for inspection.

        Returns:
            ``(ts, zs)`` — the save times and the latent states at those
            times. (Fix: the previous annotation claimed a single array, and
            the solve hard-coded ``t0=0.0`` while ``saveat`` used the ``t0``
            argument; both now use the ``t0`` parameter consistently.)
        """
        seq_len, _ = x.shape
        t1 = t1 if t1 is not None else self.max_time
        z0 = jnp.zeros((seq_len, self.d_inner))
        terms = dfx.ODETerm(_depth_step)
        sol = dfx.diffeqsolve(
            terms,
            self.solver,
            t0=t0,
            t1=t1,
            dt0=self.dt,
            y0=z0,
            args=(x, self),
            max_steps=self.max_iters,
            saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, num_points)),
        )
        return sol.ts, sol.ys

    def sequential(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Token-by-token pass: solve z to steady state per token, then step h."""

        def scan_fn(h, x_t):
            def depth_step(_, z_prev, args):
                h, x, model = args
                z_s = model.f_theta(z_prev, h, x)
                return z_s

            z0 = jnp.zeros((self.d_inner,))
            # Stop integration early once z has reached a steady state.
            # (Fix: tolerances/norm were hard-coded, duplicating the field
            # defaults; they now honor the configured rtol/atol/norm.)
            cond_fn = dfx.steady_state_event(
                rtol=self.rtol, atol=self.atol, norm=self.norm
            )
            event = dfx.Event(cond_fn)
            terms = dfx.ODETerm(depth_step)
            sol = dfx.diffeqsolve(
                terms,
                self.solver,
                t0=0.0,
                t1=self.max_time,
                dt0=self.dt,
                y0=z0,
                args=(h, x_t, self),
                max_steps=self.max_iters,
                event=event,
                adjoint=self.adjoint,
            )
            z_star = sol.ys[-1]
            # Compute new hidden state
            lambda_val = self.compute_lambda(z_star, x_t)  # (d_state,)
            u_val = self.compute_u(z_star, x_t)  # (d_state,)
            h_new = lambda_val * h + u_val
            y = self.out_net(z_star)
            return h_new, y

        h_init = jnp.zeros(self.d_state)
        _, y = jax.lax.scan(scan_fn, h_init, x)
        return y

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Parallel pass: integrate the whole sequence's z to steady state."""
        seq_len, _ = x.shape
        z0 = jnp.zeros((seq_len, self.d_inner))
        # (Fix: honor the configured rtol/atol/norm instead of hard-coding
        # values identical to the field defaults.)
        cond_fn = dfx.steady_state_event(
            rtol=self.rtol, atol=self.atol, norm=self.norm
        )
        event = dfx.Event(cond_fn)
        terms = dfx.ODETerm(_depth_step)
        sol = dfx.diffeqsolve(
            terms,
            self.solver,
            t0=0.0,
            t1=self.max_time,
            dt0=self.dt,
            y0=z0,
            args=(x, self),
            max_steps=self.max_iters,
            event=event,
            adjoint=self.adjoint,
        )
        z_star = sol.ys[-1]
        y_star = jax.vmap(self.out_net)(z_star)
        return y_star

View File

@@ -0,0 +1,346 @@
from typing import Optional
import diffrax as dfx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optimistix as optx
from jaxtyping import Array, Float, PRNGKeyArray
from ..utils import binary_op
from .base import Implicit, Normalizer
class ImplicitBoomerang(eqx.Module):
    """Fixed-point SSM whose latent dynamics follow a "boomerang" circuit model.

    The latent ``z`` is split into two halves ``(u, v)`` and evolved by a
    damped Euler step of coupled exponential circuit equations; the hidden
    state ``h`` follows a diagonal linear recurrence over the sequence.
    """

    d_model: int = eqx.field(static=True)
    d_state: int = eqx.field(static=True)
    d_inner: int = eqx.field(static=True)
    max_iters: int = eqx.field(static=True)
    tol: float = eqx.field(static=True)
    dt: float = eqx.field(static=True)
    with_thr: bool = eqx.field(static=True)
    debug: bool = eqx.field(static=True)
    f_net: eqx.nn.Linear
    lambda_net: eqx.nn.Linear  # Λ: maps (z, x) → decay factor (diagonal)
    u_net: eqx.nn.Linear  # u: maps (z, x) → input state
    out_net: eqx.nn.Linear  # Output: maps (z, h) → y

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        max_iters: int = 20,
        tol: float = 1e-5,
        dt: float = 1e-3,
        with_thr: bool = True,
        debug: bool = False,
        *,
        key: PRNGKeyArray,
    ):
        """Initialize the circuit gates and the SSM sub-networks.

        Args:
            d_model: Input/output feature width.
            d_state: Hidden-state (h) width.
            d_inner: Latent (z) width; must be even (z is split in half).
            max_iters: Maximum fixed-point iterations.
            tol: Convergence tolerance.
                NOTE(review): the convergence checks below use a hard-coded
                relative tolerance of 0.001, not this field — confirm which
                is intended.
            dt: Euler step size inside f_theta.
            with_thr: Use a tanh threshold coupling instead of constant 1.
            debug: When True, return the per-iteration z trace as well.
            key: PRNG key for parameter initialization.
        """
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_inner
        self.max_iters = max_iters
        self.tol = tol
        self.dt = dt
        self.with_thr = with_thr
        self.debug = debug
        keys = jrandom.split(key, 6)
        self.f_net = eqx.nn.Linear(
            d_state + d_model,
            d_inner,
            key=keys[1],
        )
        self.lambda_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[2])
        self.u_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[3])
        self.out_net = eqx.nn.Linear(d_inner + d_state, d_model, key=keys[4])

    def compute_lambda(
        self, z: Float[Array, "seq d_inner"], x: Float[Array, "seq d_model"]
    ) -> Float[Array, "seq d_state"]:
        """Λ(z, x) for a whole sequence (vmapped internally, unlike Implicit)."""
        zx = jnp.concatenate([z, x], axis=-1)
        # Sigmoid to keep in (0, 1) for stability
        return jax.nn.sigmoid(jax.vmap(self.lambda_net)(zx))

    def compute_u(
        self, z: Float[Array, "seq d_inner"], x: Float[Array, "seq d_model"]
    ) -> Float[Array, "seq d_state"]:
        """u(z, x) for a whole sequence (vmapped internally)."""
        zx = jnp.concatenate([z, x], axis=-1)
        return jax.vmap(self.u_net)(zx)

    def f_theta(
        self,
        z: Float[Array, " d_inner"],
        h: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_inner"]:
        """One damped Euler update of z under the circuit dynamics.

        The gates (alpha, sigma) are produced from (h, x) only; z enters
        through the circuit equations themselves.
        """
        hx = jnp.concatenate([h, x])
        # Two gate vectors of size d_inner // 2 each.
        alpha_sigma = jnp.split(self.f_net(hx), 2)
        rho = 30.0
        alpha = jax.nn.sigmoid(alpha_sigma[0])
        beta = 15.6
        gamma = 0.26
        sigma = jax.nn.sigmoid(alpha_sigma[1])
        u, v = jnp.split(z, 2)

        def true_fn(u, v):
            return jax.nn.tanh(rho * (v - u))

        def false_fn(u, v):
            return jnp.ones_like(u)

        # with_thr is static, so this cond resolves at trace time.
        thresh = jax.lax.cond(self.with_thr, true_fn, false_fn, u, v)
        du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * thresh
        dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * thresh
        dz = jnp.concat([du, dv])
        z = z + self.dt * dz
        return z

    def get_z(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Return the z trace over all fixed-point iterations (no early stop)."""
        seq_len, _ = x.shape

        def body_fn(state, _):
            z = state
            lambda_vals = self.compute_lambda(z, x)
            u_vals = self.compute_u(z, x)
            _, h = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)
            # Shift right so position t sees h_{t-1}.
            h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)
            z_new = jax.vmap(self.f_theta)(z, h_prev, x)
            return z_new, z_new

        # Initialize
        z_init = jnp.zeros((seq_len, self.d_inner))
        # Run fixed-point iteration
        _, z = jax.lax.scan(
            body_fn,
            z_init,
            None,
            length=self.max_iters,
        )
        return z

    def sequential(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Token-by-token pass: fixed-point solve z per token, then step h.

        Fix: the non-debug branch previously referenced the debug-only name
        ``z_star_debug`` in its return, raising NameError whenever
        ``debug=False``; both branches now define the trace variable.
        """

        def scan_fn(h, x_t):
            z_init = jnp.zeros((self.d_inner,))
            z_prev_init = jnp.full((self.d_inner,), jnp.inf)

            def cond_fn(state):
                i, z, z_prev = state
                converged = (
                    jnp.linalg.norm(z - z_prev) / jnp.linalg.norm(z_prev) < 0.001
                )
                return (i < self.max_iters) & ~converged

            def body_fn(state):
                i, z, _ = state
                z_new = self.f_theta(z, h, x_t)
                return (i + 1, z_new, z)

            # Run fixed-point iteration
            if self.debug:

                def unrolled_fn(state, _):
                    new_state = body_fn(state)
                    return new_state, new_state[1]

                _, z_trace = jax.lax.scan(
                    unrolled_fn, (0, z_init, z_prev_init), None, length=self.max_iters
                )
                z_star = z_trace[-1]
            else:
                _, z_star, _ = eqx.internal.while_loop(
                    cond_fn,
                    body_fn,
                    (0, z_init, z_prev_init),
                    max_steps=self.max_iters,
                    kind="bounded",
                )
                # No per-iteration trace in this branch; reuse the solution so
                # the return structure is well-defined.
                z_trace = z_star
            # Compute new hidden state (compute_* vmap internally, hence the
            # singleton batch dimension).
            lambda_val = self.compute_lambda(
                z_star[jnp.newaxis, :], x_t[jnp.newaxis, :]
            )[0]  # (d_state,)
            u_val = self.compute_u(z_star[jnp.newaxis, :], x_t[jnp.newaxis, :])[
                0
            ]  # (d_state,)
            h_new = lambda_val * h + u_val
            zh = jnp.concatenate([z_star, h_new])
            y = self.out_net(zh)
            return h_new, (y, z_trace)

        h_init = jnp.zeros(self.d_state)
        _, (y, z_star) = jax.lax.scan(scan_fn, h_init, x)
        if self.debug:
            return y, z_star
        else:
            return y

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Parallel pass: fixed-point solve z for the whole sequence at once."""
        seq_len, _ = x.shape

        def cond_fn(state):
            i, z, z_prev = state
            converged = jnp.linalg.norm(z - z_prev) / jnp.linalg.norm(z_prev) < 0.001
            return (i < self.max_iters) & ~converged

        def body_fn(state):
            i, z, _ = state
            lambda_vals = self.compute_lambda(z, x)
            u_vals = self.compute_u(z, x)
            _, h = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)
            h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)
            z_new = jax.vmap(self.f_theta)(z, h_prev, x)
            return (i + 1, z_new, z)

        # Initialize
        z_init = jnp.zeros((seq_len, self.d_inner))
        z_prev_init = jnp.full((seq_len, self.d_inner), jnp.inf)
        # Run fixed-point iteration
        if self.debug:

            def unrolled_fn(state, _):
                new_state = body_fn(state)
                return new_state, new_state[1]

            _, z_trace = jax.lax.scan(
                unrolled_fn, (0, z_init, z_prev_init), None, length=self.max_iters
            )
            z_star = z_trace[-1]
        else:
            _, z_star, _ = eqx.internal.while_loop(
                cond_fn,
                body_fn,
                (0, z_init, z_prev_init),
                max_steps=self.max_iters,
                kind="bounded",
            )
            z_trace = z_star
        lambda_vals = self.compute_lambda(z_star, x)
        u_vals = self.compute_u(z_star, x)
        _, h_star = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)
        zh = jnp.concatenate([z_star, h_star], axis=-1)
        y_star = jax.vmap(self.out_net)(zh)
        if self.debug:
            return y_star, z_trace
        else:
            return y_star
class Boomerang(Implicit):
    """Implicit model whose depth dynamics follow the boomerang circuit ODE.

    Reuses all of :class:`Implicit`'s solver machinery but replaces ``f_net``
    with a gate network over ``(h, x)`` and overrides ``f_theta`` with the
    circuit equations.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        dt: float = 1.0,
        max_iters: int | None = 100,
        max_time: float = 100,
        solver: Optional[dfx.AbstractSolver] = None,
        adjoint: Optional[dfx.AbstractAdjoint] = None,
        rtol: float = 1e-4,
        atol: float = 1e-3,
        norm: Normalizer = optx.rms_norm,
        *,
        key: PRNGKeyArray,
    ):
        keys = jrandom.split(key, 2)
        super().__init__(
            d_model=d_model,
            d_state=d_state,
            d_inner=d_inner,
            dt=dt,
            max_iters=max_iters,
            max_time=max_time,
            solver=solver,
            adjoint=adjoint,
            rtol=rtol,
            atol=atol,
            norm=norm,
            key=keys[0],
        )
        # The circuit gates depend only on (h, x), so the parent's
        # (z, h, x)-shaped f_net is replaced with a smaller one.
        self.f_net = eqx.nn.Linear(
            d_state + d_model,
            d_inner,
            key=keys[1],
        )

    def f_theta(
        self,
        z: Float[Array, " d_inner"],
        h: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_inner"]:
        """f_θ(z, h, x) → z - the implicit function."""
        # Device constants.
        kk = 0.68  # fixed by tech
        Ut = 0.025  # temperature dependent
        I_r0 = 0.9
        # Gates computed from the context (h, x) only.
        gate_a, gate_b = jnp.split(self.f_net(jnp.concatenate([h, x])), 2)
        alpha = 0.000129 + (0.0129 - 0.000129) * jax.nn.sigmoid(
            gate_a
        )  # The circuit will get directly the current
        Ia = jax.nn.tanh(gate_b) * 0.6
        x3 = 1  # (np.tanh(20 * (x2 - x1))) #smoother transition
        # z is split into the two coupled node voltages.
        x1, x2 = jnp.split(z, 2)
        dx1 = 2.3 * (
            1 - (alpha * jnp.exp(kk * x2 / Ut)) * (1 - I_r0 * (0.3 - x1)) + Ia * x3
        )
        dx2 = 2.3 * (
            -1 + (alpha * jnp.exp(kk * x1 / Ut)) * (1 + I_r0 * (0.3 - x2)) + Ia * x3
        )
        return jnp.concat([dx1, dx2])

View File

@@ -0,0 +1,420 @@
from typing import Tuple
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax import flatten_util, lax
from jaxtyping import Array, Float, PRNGKeyArray
@jax.jit
def binary_op(
    left: Tuple[Array, Array],
    right: Tuple[Array, Array],
) -> Tuple[Array, Array]:
    """Associative combinator for linear-recurrence (SSM) scans.

    (a1, b1) ∘ (a2, b2) = (a1·a2, a2·b1 + b2)

    Composing elements (λ_t, u_t) under this operator yields the recurrence
    h_t = λ_t · h_{t-1} + u_t when used with an associative scan.
    """
    decay_l, inp_l = left
    decay_r, inp_r = right
    combined_decay = decay_l * decay_r
    combined_input = decay_r * inp_l + inp_r
    return (combined_decay, combined_input)
class ImplicitSNN(eqx.Module):
    """Implicit spiking-neuron SSM over token sequences.

    At each position a latent z (split into two coupled halves) is driven to
    a fixed point of circuit-style dynamics, while a diagonal linear hidden
    state h carries context across the sequence via an associative scan.

    NOTE(review): unlike the sibling modules, the int/float fields below are
    not declared ``eqx.field(static=True)`` — confirm this is intentional,
    since they become pytree leaves.
    """
    d_model: int
    d_state: int
    d_latent: int
    max_iters: int
    tol: float
    embedding: eqx.nn.Embedding
    f_net: eqx.Module  # f_θ: maps (z, h, x) → z
    f_net2: eqx.Module  # f_θ: maps (z, h, x) → z
    lambda_net: eqx.Module  # Λ: maps (z, x) → decay factor (diagonal)
    u_net: eqx.Module  # u: maps (z, x) → input contribution
    out_net: eqx.nn.Linear  # Output: maps (z, h) → y
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        d_state: int = 16,
        d_latent: int = 32,
        max_iters: int = 20,
        tol: float = 1e-5,
        *,
        key: PRNGKeyArray,
    ):
        """Build the embedding and the five sub-networks.

        Args:
            vocab_size: Token vocabulary size for the embedding.
            d_model: Embedding / output width.
            d_state: Hidden-state (h) width.
            d_latent: Latent (z) width; must be even (split into two halves).
            max_iters: Maximum fixed-point iterations.
            tol: Fixed-point convergence tolerance (absolute norm).
            key: PRNG key for parameter initialization.
        """
        self.d_model = d_model
        self.d_state = d_state
        self.d_latent = d_latent
        self.max_iters = max_iters
        self.tol = tol
        keys = jrandom.split(key, 6)
        self.embedding = eqx.nn.Embedding(vocab_size, d_model, key=keys[0])
        # f_θ(z, h, x) → z
        # Input: z (d_latent) + h (d_state) + x (d_model)
        self.f_net = eqx.nn.Linear(
            d_state + d_model,
            d_latent // 2,
            key=keys[1],
        )
        self.f_net2 = eqx.nn.Linear(
            d_state + d_model,
            d_latent // 2,
            key=keys[5],
        )
        # Λ(z, x) → (d_state,) decay factors
        self.lambda_net = eqx.nn.Linear(d_latent + d_model, d_state, key=keys[2])
        # u(z, x) → (d_state,) input contribution
        self.u_net = eqx.nn.Linear(d_latent + d_model, d_state, key=keys[3])
        # Output projection
        self.out_net = eqx.nn.Linear(d_latent + d_state, d_model, key=keys[4])
    def compute_lambda(
        self, z: Float[Array, "seq d_latent"], x: Float[Array, "seq d_model"]
    ) -> Float[Array, "seq d_state"]:
        """Λ(z, x) - the decay/retention factor (vmapped over the sequence)."""
        zx = jnp.concatenate([z, x], axis=-1)
        # Sigmoid to keep in (0, 1) for stability
        return jax.nn.sigmoid(jax.vmap(self.lambda_net)(zx))
    def compute_u(
        self, z: Float[Array, "seq d_latent"], x: Float[Array, "seq d_model"]
    ) -> Float[Array, "seq d_state"]:
        """u(z, x) - the input contribution (vmapped over the sequence)."""
        zx = jnp.concatenate([z, x], axis=-1)
        return jax.vmap(self.u_net)(zx)
    # TODO: Change for diffrax
    def f_theta(
        self,
        z: Float[Array, " d_latent"],
        h: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_latent"]:
        """f_θ(z, h, x) → z - the implicit function.

        Returns z after one damped Euler step (step size 0.001) of coupled
        circuit dynamics; gates (alpha, sigma) are computed from (h, x) only.
        """
        hx = jnp.concatenate([h, x])
        rho = 30.0
        alpha = jax.nn.sigmoid(self.f_net(hx))
        beta = 15.6
        gamma = 0.26
        sigma = jax.nn.sigmoid(self.f_net2(hx))
        # z is split into two coupled halves (u, v).
        u, v = jnp.split(z, 2)
        # Smooth threshold coupling between the two halves.
        thresh = jax.nn.tanh(rho * (v - u))
        du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * thresh
        dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * thresh
        dz = jnp.concat([du, dv])
        z = z + 0.001 * dz
        return z
    # ==================== Parallel mode (training) ====================
    def parallel_scan_h(
        self,
        lambda_vals: Float[Array, "seq d_state"],
        u_vals: Float[Array, "seq d_state"],
    ) -> Float[Array, "seq d_state"]:
        """
        Compute h sequence via parallel scan.
        h_t = λ_t · h_{t-1} + u_t
        This is the standard SSM recurrence, parallelizable via associative scan.
        """
        # Elements: (λ_t, u_t)
        # After scan: (cumulative λ, h_t)
        _, h = lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)
        return h
    def debug(self, x: Float[Array, " seq"]) -> Float[Array, "iters seq d_latent"]:
        """Like forward_parallel, but returns the full z trace over all
        max_iters iterations (no early stopping) for inspection."""
        x = jax.vmap(self.embedding)(x)  # (seq, d_model)
        seq_len = x.shape[0]
        def body_fn(state, _):
            z, _ = state
            # Compute λ, u
            lambda_vals = self.compute_lambda(z, x)
            u_vals = self.compute_u(z, x)
            # Parallel scan for h
            h = self.parallel_scan_h(lambda_vals, u_vals)
            # Shift h
            h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)
            # Update z
            z_new = jax.vmap(self.f_theta)(z, h_prev, x)
            return (z_new, z), z_new
        # Initialize
        z_init = jnp.zeros((seq_len, self.d_latent))
        z_prev_init = jnp.full((seq_len, self.d_latent), jnp.inf)
        _, z_star = jax.lax.scan(
            body_fn, (z_init, z_prev_init), None, length=self.max_iters
        )
        return z_star
    def forward_parallel(
        self,
        x: Float[Array, " seq"],
    ) -> Float[Array, "seq d_model"]:
        """
        Parallel forward pass for training.
        Algorithm:
        1. Initialize z for all positions
        2. Iterate until convergence:
           a. Compute λ, u from current z (parallel over positions)
           b. Compute h via parallel scan
           c. Update z = f_θ(z, h_shifted, x) (parallel over positions)
        3. Compute output from final z, h
        Note: h_shifted means h_{t-1} for position t, so we shift h right.
        """
        x = jax.vmap(self.embedding)(x)  # (seq, d_model)
        seq_len = x.shape[0]
        def cond_fn(state):
            i, z, z_prev = state
            converged = jnp.linalg.norm(z - z_prev) < self.tol
            return (i < self.max_iters) & ~converged
        def body_fn(state):
            i, z, _ = state
            # Compute λ, u
            lambda_vals = self.compute_lambda(z, x)
            u_vals = self.compute_u(z, x)
            # Parallel scan for h
            h = self.parallel_scan_h(lambda_vals, u_vals)
            # Shift h
            h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)
            # Update z
            z_new = jax.vmap(self.f_theta)(z, h_prev, x)
            return (i + 1, z_new, z)
        # Initialize
        z_init = jnp.zeros((seq_len, self.d_latent))
        z_prev_init = jnp.full((seq_len, self.d_latent), jnp.inf)
        # Run fixed-point iteration
        _, z_star, _ = eqx.internal.while_loop(
            cond_fn,
            body_fn,
            (0, z_init, z_prev_init),
            max_steps=self.max_iters,
            kind="bounded",
        )
        # Final h computation with converged z
        lambda_vals = self.compute_lambda(z_star, x)
        u_vals = self.compute_u(z_star, x)
        h_star = self.parallel_scan_h(lambda_vals, u_vals)
        # Output
        zh = jnp.concatenate([z_star, h_star], axis=-1)
        y_star = jax.vmap(self.out_net)(zh)
        return y_star
    # ==================== Sequential mode (inference) ====================
    def find_fixed_point(
        self,
        h_prev: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_latent"]:
        """
        Find z* such that z* = f_θ(z*, h_prev, x)
        using fixed-point iteration.
        """
        def cond_fn(state):
            i, z, z_prev = state
            converged = jnp.linalg.norm(z - z_prev) < self.tol
            return (i < self.max_iters) & ~converged
        def body_fn(state):
            i, z, _ = state
            z_new = self.f_theta(z, h_prev, x)
            return (i + 1, z_new, z)
        # Initialize z
        z_init = jnp.zeros(self.d_latent)
        z_prev_init = jnp.full(self.d_latent, jnp.inf)
        _, z_star, _ = eqx.internal.while_loop(
            cond_fn,
            body_fn,
            (0, z_init, z_prev_init),
            max_steps=self.max_iters,
            kind="bounded",
        )
        return z_star
    def step(
        self,
        h_prev: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Tuple[Float[Array, " d_state"], Float[Array, " d_model"]]:
        """
        Single step of the implicit SSM.
        1. Find z* via fixed-point iteration
        2. Compute h* = Λ(z*, x) · h_prev + u(z*, x)
        3. Compute output y
        """
        # Find fixed point z*
        z_star = self.find_fixed_point(h_prev, x)
        # Compute new hidden state (compute_* vmap internally, hence the
        # singleton batch dimension).
        lambda_val = self.compute_lambda(z_star[jnp.newaxis, :], x[jnp.newaxis, :])[
            0
        ]  # (d_state,)
        u_val = self.compute_u(z_star[jnp.newaxis, :], x[jnp.newaxis, :])[
            0
        ]  # (d_state,)
        h_new = lambda_val * h_prev + u_val
        # Compute output
        zh = jnp.concatenate([z_star, h_new])
        y = self.out_net(zh)
        return h_new, y
    def forward_sequential(
        self,
        x: Float[Array, " seq"],
    ) -> Float[Array, "seq d_model"]:
        """
        Sequential forward pass for inference.
        Processes one token at a time, maintaining state.
        """
        x = jax.vmap(self.embedding)(x)
        def scan_fn(h, x_t):
            h_new, y_t = self.step(h, x_t)
            return h_new, y_t
        h_init = jnp.zeros(self.d_state)
        _, y = lax.scan(scan_fn, h_init, x)
        return y
    def __call__(
        self,
        x: Float[Array, " seq"],
        mode: str = "parallel",
    ) -> Float[Array, "seq d_model"]:
        """
        Forward pass over sequence. `x` holds integer token ids (embedded
        internally); `mode` selects the parallel or sequential pass.
        Note: This is inherently sequential due to the implicit nature.
        """
        if mode == "parallel":
            return self.forward_parallel(x)
        else:
            return self.forward_sequential(x)
# Smoke test + benchmark: compares the parallel and sequential passes of
# ImplicitSNN on the same random token sequence, checks gradients flow
# through both, and times the JIT-compiled versions.
if __name__ == "__main__":
    import time
    key = jrandom.key(42)
    keys = jrandom.split(key, 2)
    d_model, d_state, d_latent = 64, 16, 32
    seq_len = 1080
    model = ImplicitSNN(
        vocab_size=16,
        d_model=d_model,
        d_state=d_state,
        d_latent=d_latent,
        max_iters=10,
        key=keys[0],
    )
    # Test input
    x = jrandom.randint(keys[1], (seq_len,), 0, 15)
    # Test parallel mode
    print("Testing parallel mode...")
    y_parallel = model(x, mode="parallel")
    print(f"  Input: {x.shape}, Output: {y_parallel.shape}")
    # Test sequential mode
    print("\nTesting sequential mode...")
    y_sequential = model(x, mode="sequential")
    print(f"  Input: {x.shape}, Output: {y_sequential.shape}")
    # Compare outputs (should be close but not exact due to different convergence paths)
    diff = jnp.linalg.norm(y_parallel - y_sequential) / jnp.linalg.norm(y_sequential)
    print(f"\nRelative difference: {diff:.6f}")
    # Test gradients in parallel mode
    def loss_fn(model, x, mode):
        return jnp.mean(model(x, mode=mode) ** 2)
    print("\nTesting gradients (parallel mode)...")
    grads_par = eqx.filter_grad(loss_fn)(model, x, "parallel")
    # Flatten the gradient pytree to a single vector for comparison.
    grads_par = flatten_util.ravel_pytree(grads_par)[0]
    print("\nTesting gradients (sequential mode)...")
    grads_seq = eqx.filter_grad(loss_fn)(model, x, "sequential")
    grads_seq = flatten_util.ravel_pytree(grads_seq)[0]
    grad_diff = jnp.linalg.norm(grads_par - grads_seq) / jnp.linalg.norm(grads_seq)
    print("  Gradients computed successfully!")
    print(f"\nRelative difference: {grad_diff:.6f}")
    # Benchmark
    print("\nBenchmarking...")
    # JIT compile
    forward_parallel_jit = eqx.filter_jit(lambda m, x: m(x, mode="parallel"))
    forward_sequential_jit = eqx.filter_jit(lambda m, x: m(x, mode="sequential"))
    # Warmup
    _ = forward_parallel_jit(model, x)
    _ = forward_sequential_jit(model, x)
    # Time parallel
    start = time.time()
    for _ in range(100):
        _ = forward_parallel_jit(model, x).block_until_ready()
    parallel_time = (time.time() - start) / 100
    # Time sequential
    start = time.time()
    for _ in range(100):
        _ = forward_sequential_jit(model, x).block_until_ready()
    sequential_time = (time.time() - start) / 100
    print(f"  Parallel:   {parallel_time * 1000:.3f} ms")
    print(f"  Sequential: {sequential_time * 1000:.3f} ms")
    print(f"  Speedup:    {sequential_time / parallel_time:.2f}x")

137
src/felice/networks/ssm.py Normal file
View File

@@ -0,0 +1,137 @@
import math
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
from einops import repeat
from jaxtyping import Array, Float, PRNGKeyArray
from .utils import binary_op
class SelectiveSSM(eqx.Module):
    """Selective state-space (S6) layer as used in Mamba.

    Each timestep gets its own discretization step ``dt`` and input/output
    projections ``B``/``C`` computed from the input itself, which makes the
    underlying diagonal linear state-space model input-selective.  The
    recurrence ``h_t = dA_t * h_{t-1} + dB_t * x_t`` is evaluated in parallel
    with an associative scan.
    """

    d_model: int = eqx.field(static=True)  # number of input/output channels
    d_state: int = eqx.field(static=True)  # SSM hidden-state size per channel
    dt_rank: int = eqx.field(static=True)  # low-rank dimension of the dt projection
    x_proj: eqx.nn.Linear  # input -> concatenated (dt_low_rank, B, C)
    dt_proj: eqx.nn.Linear  # dt_low_rank -> per-channel dt (pre-softplus)
    A_log: Float[Array, "d_inner d_state"]  # log(-A), S4D-real parameterization
    D: Float[Array, " d_inner"]  # skip-connection gain

    def __init__(
        self,
        d_model: int,
        dt_rank: int,
        d_state: int = 16,
        *,
        # Fixed: was `key=PRNGKeyArray`, which made the *type object* the
        # default value; `key` is a required keyword-only PRNG key.
        key: PRNGKeyArray,
    ):
        """Initialize the selective SSM.

        Args:
            d_model: Number of channels the layer operates on.
            dt_rank: Rank of the low-dimensional dt projection.
            d_state: Hidden-state size of the SSM per channel (default: 16).
            key: JAX PRNG key used to initialize the linear projections.
        """
        self.d_model = d_model
        self.d_state = d_state
        self.dt_rank = dt_rank
        # Split into 5 keys but consume only indices 3 and 4; kept as-is so
        # the initial weights match the original split pattern.
        keys = jrandom.split(key, 5)
        self.x_proj = eqx.nn.Linear(
            self.d_model, self.dt_rank + self.d_state * 2, use_bias=False, key=keys[3]
        )
        self.dt_proj = eqx.nn.Linear(
            self.dt_rank, self.d_model, use_bias=True, key=keys[4]
        )
        # S4D-Real initialization: A = -diag(1, ..., d_state) for every channel.
        A = repeat(
            jnp.arange(1, d_state + 1, dtype=jnp.float32),
            "n -> d n",
            d=self.d_model,
        )
        self.A_log = jnp.log(A)
        self.D = jnp.ones(self.d_model)

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Apply the selective SSM along the sequence axis.

        Args:
            x: Input sequence of shape (seq, d_model).

        Returns:
            Output sequence of shape (seq, d_model).
        """
        # Per-timestep parameters: dt (low rank), then B and C (d_state each).
        x_dbl = jax.vmap(self.x_proj)(x)
        dt, B, C = jnp.split(
            x_dbl, [self.dt_rank, self.dt_rank + self.d_state], axis=-1
        )
        dt = jax.vmap(self.dt_proj)(dt)
        dt = jax.nn.softplus(dt)  # keep the step size strictly positive
        A = -jnp.exp(self.A_log)  # (d_inner, d_state), strictly negative (stable)
        # Discretize the diagonal system per timestep.
        dA = jnp.exp(jnp.einsum("ld,dn->ldn", dt, A))
        dB_x = jnp.einsum("ld,ln,ld->ldn", dt, B, x)
        # Parallel evaluation of h_t = dA_t * h_{t-1} + dB_t * x_t.
        _, h = jax.lax.associative_scan(binary_op, (dA, dB_x), axis=0)
        y = jnp.einsum("ldn,ln->ld", h, C)
        y = y + x * self.D  # learned skip connection
        return y
class Mamba(eqx.Module):
    """Single Mamba block: gated depthwise causal conv + selective SSM.

    The input is projected to twice the inner width; one half goes through a
    depthwise 1D convolution, SiLU, and the selective SSM, while the other
    half acts as a SiLU gate.  The gated result is projected back to
    ``d_model``.
    """

    d_model: int = eqx.field(static=True)  # external feature dimension
    d_conv: int = eqx.field(static=True)  # depthwise convolution kernel size
    d_inner: int = eqx.field(static=True)  # internal (expanded) feature dimension
    in_proj: eqx.nn.Linear  # d_model -> 2 * d_inner (conv path + gate)
    out_proj: eqx.nn.Linear  # d_inner -> d_model
    conv1d: eqx.nn.Conv1d  # depthwise convolution over the sequence axis
    ssm: SelectiveSSM  # input-selective state-space layer

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        d_conv: int = 4,
        *,
        # Fixed: was `key=PRNGKeyArray`, which made the *type object* the
        # default value; `key` is a required keyword-only PRNG key.
        key: PRNGKeyArray,
    ):
        """Initialize the Mamba block.

        Args:
            d_model: External feature dimension.
            d_state: SSM hidden-state size (default: 16).
            d_inner: Expanded internal width (default: 32).
            d_conv: Kernel size of the depthwise convolution (default: 4).
            key: JAX PRNG key used to initialize the sub-modules.
        """
        self.d_model = d_model
        self.d_conv = d_conv
        self.d_inner = d_inner
        keys = jrandom.split(key, 4)
        self.in_proj = eqx.nn.Linear(self.d_model, self.d_inner * 2, key=keys[0])
        # groups == channels -> depthwise.  padding = d_conv - 1 lengthens the
        # output; __call__ trims it back to seq_len, keeping the conv causal.
        self.conv1d = eqx.nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            key=keys[1],
        )
        self.ssm = SelectiveSSM(
            d_model=d_inner,
            dt_rank=math.ceil(self.d_model / 16),
            d_state=d_state,
            key=keys[2],
        )
        self.out_proj = eqx.nn.Linear(self.d_inner, self.d_model, key=keys[3])

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Apply the Mamba block along the sequence axis.

        Args:
            x: Input sequence of shape (seq, d_model).

        Returns:
            Output sequence of shape (seq, d_model).
        """
        seq_len, _ = x.shape
        # Project the input into the convolution path and the gate (residual).
        xz = jax.vmap(self.in_proj)(x)
        x, z = jnp.split(xz, 2, axis=-1)
        # 1D convolution: Conv1d expects (channels, seq); trimming to seq_len
        # discards the extra padded outputs.
        x = x.T
        x = self.conv1d(x)[:, :seq_len]
        x = x.T
        x = jax.nn.silu(x)
        y = self.ssm(x)
        y = y * jax.nn.silu(z)  # SiLU gating with the residual half
        logits = jax.vmap(self.out_proj)(y)
        return logits

View File

@@ -0,0 +1,13 @@
from typing import Tuple
import jax
from jaxtyping import Array
@jax.jit
def binary_op(
    left: Tuple[Array, Array], right: Tuple[Array, Array]
) -> Tuple[Array, Array]:
    """Associative combinator for the linear recurrence h_t = a_t * h_{t-1} + b_t.

    Composing the affine map h -> a1*h + b1 with h -> a2*h + b2 gives
    h -> (a1*a2)*h + (a2*b1 + b2); feeding this to `jax.lax.associative_scan`
    evaluates the whole recurrence in parallel.
    """
    coeff_left, shift_left = left
    coeff_right, shift_right = right
    composed_coeff = coeff_left * coeff_right
    composed_shift = coeff_right * shift_left + shift_right
    return (composed_coeff, composed_shift)

View File

@@ -0,0 +1,5 @@
from .boomerang import Boomerang
from .fhn import FHNRS
from .wererabbit import WereRabbit
__all__ = ["WereRabbit", "FHNRS", "Boomerang"]

View File

@@ -0,0 +1,171 @@
from typing import Any, Dict, Tuple
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
from jaxtyping import Array, DTypeLike, Float
class Boomerang(eqx.Module):
    """Boomerang neuron model.

    Two coupled state variables (u, v) follow predator/prey-style dynamics;
    a spike is detected when the state settles onto a fixpoint of the vector
    field (see `spike_condition`).  The resting state (u0, v0) is found once
    at construction time by Newton root-finding on the vector field.
    """

    rtol: float = eqx.field(static=True)  # relative tolerance for fixpoint detection
    atol: float = eqx.field(static=True)  # absolute tolerance for fixpoint detection
    u0: float = eqx.field(static=True)  # resting value of u (root of the vector field)
    v0: float = eqx.field(static=True)  # resting value of v (root of the vector field)
    alpha: float = eqx.field(static=True)  # I_n0 / I_bias ratio
    beta: float = eqx.field(static=True)  # k / U_t (inverse thermal scale)
    gamma: float = eqx.field(static=True)  # coupling coefficient
    rho: float = eqx.field(static=True)  # tanh steepness
    sigma: float = eqx.field(static=True)  # bias scaling (s * I_bias)
    dtype: DTypeLike = eqx.field(static=True)

    def __init__(
        self,
        *,
        atol: float = 1e-6,
        rtol: float = 1e-4,
        alpha: float = 0.0129,
        beta: float = 15.6,
        gamma: float = 0.26,
        rho: float = 30.0,
        sigma: float = 0.6,
        dtype: DTypeLike = jnp.float32,
    ):
        r"""Initialize the Boomerang neuron model.

        Args:
            atol: Absolute tolerance for the spiking fixpoint calculation.
            rtol: Relative tolerance for the spiking fixpoint calculation.
            alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
            beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
            gamma: Coupling parameter $\gamma$ (default: 0.26)
            rho: Steepness of the tanh function $\rho$ (default: 30.0)
            sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
            dtype: Data type for arrays (default: float32).
        """
        self.dtype = dtype
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.rho = rho
        self.sigma = sigma
        self.rtol = rtol
        self.atol = atol

        # Solve vector_field(u, v) == 0 once to find the resting state; the
        # result is stored as plain Python floats (static fields).
        def fn(y, _):
            return self.vector_field(y[0], y[1])

        solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)
        y0 = (jnp.array(0.3), jnp.array(0.3))
        u0, v0 = optx.root_find(fn, solver, y0).value
        self.u0 = u0.item()
        self.v0 = v0.item()

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 2) containing [u, v],
            where u and v are the predator/prey membrane voltages, each set
            to the resting values found at construction time.
        """
        u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)
        v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)
        x = jnp.stack([u, v], axis=1)
        return x

    def vector_field(
        self, u: Float[Array, "..."], v: Float[Array, "..."]
    ) -> Tuple[Float[Array, "..."], Float[Array, "..."]]:
        """Evaluate (du/dt, dv/dt) for the given u and v.

        Args:
            u: Predator membrane voltage(s).
            v: Prey membrane voltage(s).

        Returns:
            Tuple (du, dv) with the same shape as the inputs.
        """
        alpha = self.alpha
        beta = self.beta
        gamma = self.gamma
        sigma = self.sigma
        rho = self.rho
        # Soft switching term driven by the difference of the two voltages.
        z = jax.nn.tanh(rho * (v - u))
        du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z
        dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z
        return du, dv

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute time derivatives of the neuron state variables.

        This implements the Boomerang dynamics
        - du/dt: predator dynamics
        - dv/dt: prey dynamics

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 2) containing [u, v].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        u = y[:, 0]
        v = y[:, 1]
        du, dv = self.vector_field(u, v)
        dxdt = jnp.stack([du, dv], axis=1)
        return dxdt

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        A spike is triggered when the system reaches a fixpoint, i.e. when
        the norm of the vector field drops below the mixed absolute/relative
        tolerance (the expression below crosses zero).

        Args:
            t: Current simulation time (unused but required by the framework).
            y: State array of shape (neurons, 2) containing [u, v].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Spike condition array; positive values indicate a spike.
            NOTE(review): each per-neuron condition is repeated twice via
            `.repeat(2)`, giving shape (neurons * 2,) rather than the
            annotated (neurons,) — presumably to match a flattened state
            layout expected by the event handler; confirm against the caller.
        """
        _atol = self.atol
        _rtol = self.rtol
        _norm = optx.rms_norm
        vf = self.dynamics(t, y, {})

        @jax.vmap
        def calculate_norm(vf, y):
            # atol + rtol*||y|| - ||f(y)||: negative while moving, >= 0 at rest.
            return _atol + _rtol * _norm(y) - _norm(vf)

        base_cond = calculate_norm(vf, y).repeat(2)
        return base_cond

View File

@@ -0,0 +1,235 @@
from typing import Any, Dict, Union
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, DTypeLike, Float
class FHNRS(eqx.Module):
    r"""FitzHugh-Nagumo neuron model

    Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by
    Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast
    and slow currents to produce oscillatory spiking behavior.

    The dynamics are governed by:

    $$
    \begin{align}
    C\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\
    \frac{dv_{slow}}{dt} &= \frac{v - v_{slow}}{\tau_{slow}} \\
    \frac{dI_{app}}{dt} &= -\frac{I_{app}}{\tau_{syn}}
    \end{align}
    $$

    where the currents are:
    - $I_{passive} = g_{max}(v - E_{rev})$
    - $I_{fast} = a_{fast} \tanh(v - v_{off,fast})$
    - $I_{slow} = a_{slow} \tanh(v_{slow} - v_{off,slow})$

    References:
        - Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.

    Attributes:
        gmax_pasive: Maximal conductance of the passive current.
        Erev_pasive: Reversal potential for the passive current.
        a_fast: Amplitude parameter for the fast current dynamics.
        voff_fast: Voltage offset for the fast current activation.
        tau_fast: Time constant for the fast current (typically zero for instantaneous).
        a_slow: Amplitude parameter for the slow current dynamics.
        voff_slow: Voltage offset for the slow current activation.
        tau_slow: Time constant for the slow recovery variable.
        vthr: Voltage threshold for spike generation.
        C: Membrane capacitance.
        tsyn: Synaptic time constant for input current decay.
    """

    # Pasive parameters
    gmax_pasive: float = eqx.field(static=True)
    Erev_pasive: float = eqx.field(static=True)
    # Fast current
    a_fast: float = eqx.field(static=True)
    voff_fast: float = eqx.field(static=True)
    # NOTE(review): tau_fast is stored but never used in `dynamics` — the fast
    # current is treated as instantaneous there.
    tau_fast: float = eqx.field(static=True)
    # Slow current
    a_slow: float = eqx.field(static=True)
    voff_slow: float = eqx.field(static=True)
    tau_slow: float = eqx.field(static=True)
    # Neuron threshold
    vthr: float = eqx.field(static=True)
    C: float = eqx.field(static=True, default=1.0)
    # Input synaptic time constant
    tsyn: float = eqx.field(static=True)
    dtype: DTypeLike = eqx.field(static=True)

    def __init__(
        self,
        *,
        tsyn: Union[int, float, jnp.ndarray] = 1.0,
        C: Union[int, float, jnp.ndarray] = 1.0,
        gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,
        Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,
        a_fast: Union[int, float, jnp.ndarray] = -2.0,
        voff_fast: Union[int, float, jnp.ndarray] = 0.0,
        tau_fast: Union[int, float, jnp.ndarray] = 0.0,
        a_slow: Union[int, float, jnp.ndarray] = 2.0,
        voff_slow: Union[int, float, jnp.ndarray] = 0.0,
        tau_slow: Union[int, float, jnp.ndarray] = 50.0,
        vthr: Union[int, float, jnp.ndarray] = 2.0,
        dtype: DTypeLike = jnp.float32,
    ):
        """Initialize the FitzHugh-Nagumo neuron model.

        Args:
            tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.
            C: Membrane capacitance. Can be scalar or per-neuron array.
            gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.
            Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.
            a_fast: Amplitude of fast current. Can be scalar or per-neuron array.
            voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.
            tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.
            a_slow: Amplitude of slow current. Can be scalar or per-neuron array.
            voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.
            tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.
            vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.
            dtype: Data type for arrays (default: float32).
        """
        self.dtype = dtype
        self.tsyn = tsyn
        self.C = C
        self.gmax_pasive = gmax_pasive
        self.Erev_pasive = Erev_pasive
        self.a_fast = a_fast
        self.voff_fast = voff_fast
        self.tau_fast = tau_fast
        self.a_slow = a_slow
        self.voff_slow = voff_slow
        self.tau_slow = tau_slow
        self.vthr = vthr

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 3"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],
            where v is membrane voltage, v_slow is the slow recovery variable,
            and i_app is the applied synaptic current.  All start at zero.
        """
        return jnp.zeros((n_neurons, 3), dtype=self.dtype)

    def IV_inst(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Compute instantaneous I-V relationship with fast and slow currents at rest.

        Args:
            v: Membrane voltage.
            Vrest: Resting voltage for both fast and slow currents (default: 0).

        Returns:
            Total current at voltage v with both fast and slow currents evaluated at Vrest.
        """
        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)
        return I_pasive + I_fast + I_slow

    def IV_fast(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Compute I-V relationship with fast current at voltage v and slow current at rest.

        Args:
            v: Membrane voltage for passive and fast currents.
            Vrest: Resting voltage for slow current (default: 0).

        Returns:
            Total current with fast dynamics responding to v and slow current at Vrest.
        """
        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)
        return I_pasive + I_fast + I_slow

    def IV_slow(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Compute steady-state I-V relationship with all currents at voltage v.

        Args:
            v: Membrane voltage for all currents.
            Vrest: Unused parameter for API consistency (default: 0).

        Returns:
            Total steady-state current with all currents responding to v.
        """
        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)
        return I_pasive + I_fast + I_slow

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 3"]:
        """Compute time derivatives of the neuron state variables.

        This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:
        - dv/dt: Fast membrane voltage dynamics
        - dv_slow/dt: Slow recovery variable dynamics
        - di_app/dt: Synaptic current decay

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].
        """
        v = y[:, 0]
        v_slow = y[:, 1]
        i_app = y[:, 2]
        # The fast current is evaluated at v (instantaneous); the slow current
        # tracks v through the first-order filter v_slow.
        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)
        i_sum = I_pasive + I_fast + I_slow
        dv_dt = (i_app - i_sum) / self.C
        dvslow_dt = (v - v_slow) / self.tau_slow
        di_dt = -i_app / self.tsyn
        return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        A spike is triggered when this function crosses zero (v >= vthr).

        Args:
            t: Current simulation time (unused but required by event detection).
            y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Spike condition array of shape (neurons,). Positive values indicate v > vthr.
        """
        return y[:, 0] - self.vthr

View File

@@ -0,0 +1,194 @@
from typing import Any, Dict
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
from jaxtyping import Array, DTypeLike, Float
class WereRabbit(eqx.Module):
    r"""
    WereRabbit Neuron Model

    The WereRabbit model implements a predator-prey dynamic with bistable
    switching behavior controlled by a "moon phase" parameter $z$.

    The dynamics are governed by:

    $$
    \begin{align}
    z &= \tanh(\rho (u-v)) \\
    \frac{du}{dt} &= z - z \alpha e^{\beta v} [1 + \gamma (0.5 - u)] - \sigma \\
    \frac{dv}{dt} &= -z - z \alpha e^{\beta u} [1 + \gamma (0.5 - v)] - \sigma
    \end{align}
    $$

    where $z$ represents the "moon phase" that switches the predator-prey roles.

    Attributes:
        alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
        beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
        gamma: Coupling parameter $\gamma$ (default: 0.26)
        rho: Steepness of the tanh function $\rho$ (default: 5)
        sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
        rtol: Relative tolerance for the spiking fixpoint calculation.
        atol: Absolute tolerance for the spiking fixpoint calculation.
    """

    dtype: DTypeLike = eqx.field(static=True)
    rtol: float = eqx.field(static=True)
    atol: float = eqx.field(static=True)
    alpha: float = eqx.field(static=True)  # I_n0 / I_bias ratio
    beta: float = eqx.field(static=True)  # k / U_t (inverse thermal scale)
    gamma: float = eqx.field(static=True)  # coupling coefficient
    rho: float = eqx.field(static=True)  # tanh steepness
    sigma: float = eqx.field(static=True)  # bias scaling (s * I_bias)

    def __init__(
        self,
        *,
        atol: float = 1e-3,
        rtol: float = 1e-3,
        alpha: float = 0.0129,
        beta: float = 15.6,
        gamma: float = 0.26,
        rho: float = 5.0,
        sigma: float = 0.6,
        dtype: DTypeLike = jnp.float32,
    ):
        r"""Initialize the WereRabbit neuron model.

        Args:
            rtol: Relative tolerance for the spiking fixpoint calculation.
            atol: Absolute tolerance for the spiking fixpoint calculation.
            alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
            beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
            gamma: Coupling parameter $\gamma$ (default: 0.26)
            rho: Steepness of the tanh function $\rho$ (default: 5)
            sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
            dtype: Data type for arrays (default: float32).
        """
        self.dtype = dtype
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.rho = rho
        self.sigma = sigma
        self.rtol = rtol
        self.atol = atol

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 2) containing [u, v],
            where u and v are the predator/prey membrane voltages, both
            initialized to zero.
        """
        x1 = jnp.zeros((n_neurons,), dtype=self.dtype)
        x2 = jnp.zeros((n_neurons,), dtype=self.dtype)
        return jnp.stack([x1, x2], axis=1)

    def vector_field(self, y: Float[Array, "neurons 2"]) -> Float[Array, "neurons 2"]:
        """Compute vector field of the neuron state variables.

        This implements the WereRabbit dynamics
        - du/dt: predator dynamics
        - dv/dt: prey dynamics

        Args:
            y: State array of shape (neurons, 2) containing [u, v].

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        u = y[:, 0]
        v = y[:, 1]
        # "Moon phase": soft sign of (u - v) that swaps predator/prey roles.
        z = jax.nn.tanh(self.rho * (u - v))
        du = (
            z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))
            - self.sigma
        )
        dv = (
            z
            * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))
            - self.sigma
        )
        # NOTE(review): jnp.allclose reduces over the WHOLE z array, so this
        # branch flips every neuron at once only when all z are ~0; if a
        # per-neuron test was intended, jnp.isclose would be needed — confirm.
        dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)
        du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)
        return jnp.stack([du, dv], axis=1)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute time derivatives of the neuron state variables.

        This implements the WereRabbit dynamics
        - du/dt: predator dynamics
        - dv/dt: prey dynamics

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 2) containing [u, v].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        dxdt = self.vector_field(y)
        return dxdt

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        A spike is triggered when the system reaches a fixpoint, i.e. when
        the norm of the vector field drops below the mixed absolute/relative
        tolerance (the expression below crosses zero).

        Args:
            t: Current simulation time (unused but required by the framework).
            y: State array of shape (neurons, 2) containing [u, v].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Spike condition array of shape (neurons,). Positive values indicate spike.
        """
        _atol = self.atol
        _rtol = self.rtol
        _norm = optx.rms_norm
        vf = self.dynamics(t, y, {})

        @jax.vmap
        def calculate_norm(vf, y):
            # NOTE(review): under vmap, y and vf are single-neuron (2,) vectors,
            # so y[:-1]/vf[:-1] keep only the u component.  This looks like a
            # leftover from an older 3-state [u, v, has_spiked] layout — confirm
            # whether _norm(y) / _norm(vf) over the full state was intended.
            return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])

        base_cond = calculate_norm(vf, y)
        return base_cond

64
src/felice/solver.py Normal file
View File

@@ -0,0 +1,64 @@
from typing import Optional
import equinox as eqx
import jax
from diffrax import AbstractSolver, AbstractTerm
from diffrax._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, Y
from diffrax._solution import RESULTS
from diffrax._solver.base import _SolverState
from jaxtyping import PyTree
class ClipSolver(eqx.Module):
    """Wrapper around a diffrax solver that clamps the state to be
    non-negative after every step.

    The wrapped solver performs the actual integration; this class only
    post-processes the solution PyTree with `jax.nn.relu`.  Every attribute
    or method not defined here is forwarded to the wrapped solver.
    """

    solver: AbstractSolver

    def __getattr__(self, name):
        # Forward everything we do not override (order, interpolation_cls,
        # init/func, ...) to the wrapped solver.
        return getattr(self.solver, name)

    def step(
        self,
        terms: PyTree[AbstractTerm],
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
        solver_state: _SolverState,
        made_jump: BoolScalarLike,
    ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
        """Advance the wrapped solver over $[t_0, t_1]$ and clip the result.

        **Arguments:**

        - `terms`: The PyTree of terms representing the vector fields and controls.
        - `t0`: The start of the interval that the step is made over.
        - `t1`: The end of the interval that the step is made over.
        - `y0`: The current value of the solution at `t0`.
        - `args`: Any extra arguments passed to the vector field.
        - `solver_state`: Any evolving state for the solver itself, at `t0`.
        - `made_jump`: Whether there was a discontinuity in the vector field at `t0`.

        **Returns:**

        The same 5-tuple as `diffrax.AbstractSolver.step` — solution at `t1`,
        optional local error estimate, dense-output info, new solver state,
        and a `diffrax.RESULTS` code — except that every leaf of the solution
        at `t1` has been passed through `jax.nn.relu` (clipped at zero).
        """
        stepped = self.solver.step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
        y_next, error_estimate, dense_info, new_state, result = stepped
        # Clamp every state component to [0, inf) before handing it back.
        clipped = jax.tree_util.tree_map(jax.nn.relu, y_next)
        return clipped, error_estimate, dense_info, new_state, result