mirror of
https://github.com/bics-rug/felice-models.git
synced 2026-04-24 16:58:41 +02:00
Initial commit
Co-authored-by: Aradhana Dube <a.dube@rug.nl> Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl> Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl> Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
This commit is contained in:
0
src/felice/__init__.py
Normal file
0
src/felice/__init__.py
Normal file
0
src/felice/datasets/__init__.py
Normal file
0
src/felice/datasets/__init__.py
Normal file
172
src/felice/datasets/braille.py
Normal file
172
src/felice/datasets/braille.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
from urllib.request import urlopen
|
||||
from zipfile import ZipFile
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
# Label vocabulary used to index Braille classes:
# index 0 is the space character, indices 1-26 are the letters A-Z.
letters = ["Space"] + [chr(code) for code in range(ord("A"), ord("Z") + 1)]
|
||||
|
||||
|
||||
def load(
    thr: int = 1,
    taxels: Optional[Sequence[int]] = None,
    custom_letters: Optional[Sequence[str]] = None,
):
    """Load the event-based Braille reading dataset and split it.

    Downloads the Zenodo archive on first use, converts the per-sample event
    lists into a dense, inf-padded array, and returns stratified
    train/test/validation splits as JAX arrays.

    Parameters
    ----------
    thr:
        Event-detection threshold; selects which pickled file to load
        (``data_braille_letters_th_{thr}.pkl``).
    taxels:
        Optional subset of taxel (sensor) indices to keep. Each taxel
        contributes an On and an Off channel, so the returned channel count is
        ``2 * len(taxels)`` (or ``2 * nchan`` when None).
    custom_letters:
        Optional restricted label set. When given, only samples whose letter
        is in this sequence are kept, and labels are indices into it;
        otherwise labels index the module-level ``letters`` list.

    Returns
    -------
    (trainset, testset, valset, selected_chans, nb_outputs) where each set is
    a ``(x, y)`` tuple of JAX arrays.
    """
    file_name = Path(f"./braille/data_braille_letters_th_{thr}.pkl")
    if not file_name.exists():
        # First run: fetch the whole archive and extract only the needed pickle.
        resp = urlopen(
            "https://zenodo.org/records/7050094/files/reading_braille_data.zip"
        )
        with ZipFile(BytesIO(resp.read())) as zObject:
            zObject.extract(f"data_braille_letters_th_{thr}.pkl", path="./braille")

    data_dict = pd.read_pickle(file_name)

    # Extract data
    data_list: list[np.ndarray] = []
    label_list: list[int] = []

    # NOTE(review): channel count is read from sample index 1; assumes every
    # sample has the same number of taxels — confirm against the dataset.
    nchan = len(data_dict["events"][1])  # number of channels per sensor

    # First pass: find the longest event list so all samples can be padded
    # to a common length.
    max_events = 0
    for i, sample in enumerate(data_dict["events"]):
        for taxel in range(len(sample)):
            for event_type in range(len(sample[taxel])):
                events = sample[taxel][event_type]
                max_events = len(events) if len(events) > max_events else max_events

    # Second pass: densify each sample into (nchan, max_events, 2), padding
    # missing events with +inf (inf marks "no event").
    for i, sample in enumerate(data_dict["events"]):
        events_array = np.full([nchan, max_events, 2], np.inf)
        for taxel in range(len(sample)):
            # loop over On and Off channels
            for event_type in range(len(sample[taxel])):
                events = sample[taxel][event_type]
                if events:
                    events_array[taxel, : len(events), event_type] = events  # ms

        # Flatten (taxel, On/Off) into a single channel axis:
        # (nchan, max_events, 2) -> (nchan, 2, max_events) -> (2*nchan, max_events)
        if taxels is not None:
            events_array = np.reshape(
                np.transpose(events_array, (0, 2, 1))[taxels, :, :],
                (-1, events_array.shape[1]),
            )
            selected_chans = 2 * len(taxels)
        else:
            events_array = np.reshape(
                np.transpose(events_array, (0, 2, 1)), (-1, events_array.shape[1])
            )
            selected_chans = 2 * nchan

        lbl = data_dict["letter"][i]
        if custom_letters is not None:
            # Keep only samples from the restricted label set; labels are
            # re-indexed into custom_letters.
            if lbl in custom_letters:
                data_list.append(events_array)
                label_list.append(custom_letters.index(lbl))
        else:
            data_list.append(events_array)
            label_list.append(letters.index(lbl))

    # NOTE(review): the *100 suggests raw timestamps are in units of 10 ms;
    # confirm against the dataset documentation.
    data = np.stack(data_list) * 100  # To ms
    labels = np.stack(label_list)
    nb_outputs = len(np.unique(labels))

    # 70/30 stratified train/test split, then test is split again 67/33 into
    # test/validation (~70/20/10 overall).
    x_train, x_test, y_train, y_test = train_test_split(
        data, labels, test_size=0.30, shuffle=True, stratify=labels
    )

    x_test, x_validation, y_test, y_validation = train_test_split(
        x_test, y_test, test_size=0.33, shuffle=True, stratify=y_test
    )

    trainset = (jnp.asarray(x_train), jnp.asarray(y_train))
    testset = (jnp.asarray(x_test), jnp.asarray(y_test))
    valset = (jnp.asarray(x_validation), jnp.asarray(y_validation))

    return trainset, testset, valset, selected_chans, nb_outputs
|
||||
|
||||
|
||||
def load_raw(
    taxels: Optional[Sequence[int]] = None,
    custom_letters: Optional[Sequence[str]] = None,
):
    """Load the raw (analog taxel traces) Braille dataset and split it.

    Unlike :func:`load`, this uses the un-thresholded per-taxel signal
    (``taxel_data``) rather than extracted events, so each taxel contributes
    one channel (no On/Off doubling).

    Parameters
    ----------
    taxels:
        Optional subset of taxel indices (columns) to keep.
    custom_letters:
        Optional restricted label set; see :func:`load`.

    Returns
    -------
    (trainset, testset, valset, selected_chans, nb_outputs) where each set is
    a ``(x, y)`` tuple of JAX arrays.
    """
    file_name = Path("./braille/data_braille_letters_digits.pkl")
    if not file_name.exists():
        # First run: fetch the whole archive and extract only the needed pickle.
        resp = urlopen(
            "https://zenodo.org/records/7050094/files/reading_braille_data.zip"
        )
        with ZipFile(BytesIO(resp.read())) as zObject:
            zObject.extract("data_braille_letters_digits.pkl", path="./braille")

    data_dict = pd.read_pickle(file_name)

    # Extract data
    data_list: list[np.ndarray] = []
    label_list: list[int] = []

    # NOTE(review): assumes every sample is (time, nchan) with the same
    # channel count as sample 0 — confirm against the dataset.
    nchan = data_dict["taxel_data"][0].shape[1]  # number of channels per sensor

    for i, sample in enumerate(data_dict["taxel_data"]):
        # Optionally restrict to a subset of taxel columns.
        if taxels is not None:
            sample = sample[:, taxels]
            selected_chans = len(taxels)
        else:
            selected_chans = nchan

        lbl = data_dict["letter"][i]
        if custom_letters is not None:
            # Keep only samples from the restricted label set; labels are
            # re-indexed into custom_letters.
            if lbl in custom_letters:
                data_list.append(sample)
                label_list.append(custom_letters.index(lbl))
        else:
            data_list.append(sample)
            label_list.append(letters.index(lbl))

    # NOTE(review): stale comment — unlike load(), no unit conversion is
    # applied here; the data are stacked as-is.
    data = np.stack(data_list)  # To ms
    labels = np.stack(label_list)
    nb_outputs = len(np.unique(labels))

    # 70/30 stratified train/test split, then test is split again 67/33 into
    # test/validation (~70/20/10 overall).
    x_train, x_test, y_train, y_test = train_test_split(
        data, labels, test_size=0.30, shuffle=True, stratify=labels
    )

    x_test, x_validation, y_test, y_validation = train_test_split(
        x_test, y_test, test_size=0.33, shuffle=True, stratify=y_test
    )

    trainset = (jnp.asarray(x_train), jnp.asarray(y_train))
    testset = (jnp.asarray(x_test), jnp.asarray(y_test))
    valset = (jnp.asarray(x_validation), jnp.asarray(y_validation))

    return trainset, testset, valset, selected_chans, nb_outputs
|
||||
125
src/felice/datasets/reasoning.py
Normal file
125
src/felice/datasets/reasoning.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jaxtyping import Array, PRNGKeyArray
|
||||
|
||||
|
||||
class ReasoningDataset:
    """Synthetic symbolic-reasoning sequence tasks.

    Task types:
    1. Simple comparison: IF A > B THEN x ELSE y
    2. Accumulated conditions: SET A, ADD B, IF SUM > threshold THEN x ELSE y

    Each generator returns ``(input_seq, target, mask)``: the token sequence,
    per-position targets (zero everywhere except the QUERY position), and a
    0/1 mask selecting the QUERY position for the loss.
    """

    # Token vocabulary: ids 0..10 are control tokens; numeric operands are
    # encoded as NUM_OFFSET + value.
    VOCAB = {
        "PAD": 0,
        "SET_A": 1,
        "SET_B": 2,
        "SET_C": 3,
        "IF_A>B": 4,
        "IF_A<B": 5,
        "IF_SUM>": 6,
        "THEN": 7,
        "ELSE": 8,
        "QUERY": 9,
        "ADD": 10,
    }
    NUM_OFFSET: int = 11  # first token id reserved for numeric values
    VOCAB_SIZE: int = 31  # 11 control tokens + 20 numeric tokens
    NUM_OUTPUT: int = 16  # number of output classes

    def generate_simple_comparison(self, key: "PRNGKeyArray") -> "Tuple[Array, Array, Array]":
        """Generate one simple-comparison example.

        Sequence: [SET_A, a, SET_B, b, IF_A>B, THEN, x, ELSE, y, QUERY]
        Target at QUERY position: x if a > b else y

        Returns (input_seq, target, mask), each of length 10.
        """
        keys = jax.random.split(key, 4)

        # NOTE(review): jax.random.randint's upper bound is exclusive, so
        # these draw from [0, NUM_OUTPUT - 2]; confirm whether NUM_OUTPUT - 1
        # was meant to be reachable.
        a = jax.random.randint(keys[0], (), 0, self.NUM_OUTPUT - 1)
        b = jax.random.randint(keys[1], (), 0, self.NUM_OUTPUT - 1)
        x = jax.random.randint(keys[2], (), 0, self.NUM_OUTPUT - 1)
        y = jax.random.randint(keys[3], (), 0, self.NUM_OUTPUT - 1)

        # TODO: Generate other sequences
        input_seq = jnp.array(
            [
                self.VOCAB["SET_A"],
                self.NUM_OFFSET + a,
                self.VOCAB["SET_B"],
                self.NUM_OFFSET + b,
                self.VOCAB["IF_A>B"],
                self.VOCAB["THEN"],
                self.NUM_OFFSET + x,
                self.VOCAB["ELSE"],
                self.NUM_OFFSET + y,
                self.VOCAB["QUERY"],
            ]
        )

        result = jnp.where(a > b, x, y)
        # Target/mask are zero except at the final (QUERY) position.
        target = jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, result])
        mask = jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])

        return input_seq, target, mask

    def generate_accumulation_condition(self, key: "PRNGKeyArray") -> "Tuple[Array, Array, Array]":
        """Generate one accumulation example.

        Sequence: [SET_A, a, ADD, b, ADD, c, IF_SUM>, threshold, THEN, x, ELSE, y, QUERY]
        Requires accumulating a + b + c and comparing to threshold.

        Returns (input_seq, target, mask), each of length 13.
        """
        keys = jax.random.split(key, 6)

        a = jax.random.randint(keys[0], (), 0, 10)
        b = jax.random.randint(keys[1], (), 0, 10)
        c = jax.random.randint(keys[2], (), 0, 10)
        threshold = jax.random.randint(keys[3], (), 5, 25)
        x = jax.random.randint(keys[4], (), 0, 15)
        y = jax.random.randint(keys[5], (), 0, 15)

        # TODO: Generate other sequences
        input_seq = jnp.array(
            [
                self.VOCAB["SET_A"],
                self.NUM_OFFSET + a,
                self.VOCAB["ADD"],
                self.NUM_OFFSET + b,
                self.VOCAB["ADD"],
                self.NUM_OFFSET + c,
                self.VOCAB["IF_SUM>"],
                self.NUM_OFFSET + threshold,
                self.VOCAB["THEN"],
                self.NUM_OFFSET + x,
                self.VOCAB["ELSE"],
                self.NUM_OFFSET + y,
                self.VOCAB["QUERY"],
            ]
        )

        total = a + b + c
        result = jnp.where(total > threshold, x, y)

        # Target/mask are zero except at the final (QUERY) position.
        target = jnp.zeros(13, dtype=jnp.int32).at[-1].set(result)
        mask = jnp.zeros(13, dtype=jnp.int32).at[-1].set(1)

        return input_seq, target, mask

    def generate_batch(
        self, key: "PRNGKeyArray", batch_size: int, task_type: str = "simple"
    ) -> "Tuple[Array, Array, Array]":
        """Generate a batch of examples.

        Parameters
        ----------
        key:
            PRNG key; split into one subkey per example.
        batch_size:
            Number of examples to generate.
        task_type:
            "simple" for comparison tasks; anything else selects the
            accumulation task.

        Returns (inputs, targets, masks), each stacked along a new batch axis.
        """
        keys = jax.random.split(key, batch_size)

        if task_type == "simple":
            gen_fn = self.generate_simple_comparison
        else:  # accumulation
            gen_fn = self.generate_accumulation_condition

        inputs, targets, masks = [], [], []
        for k in keys:
            inp, tgt, msk = gen_fn(k)
            inputs.append(inp)
            targets.append(tgt)
            masks.append(msk)

        return jnp.stack(inputs), jnp.stack(targets), jnp.stack(masks)
|
||||
45
src/felice/networks/__init__.py
Normal file
45
src/felice/networks/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Generic, Protocol, TypeVar
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.random as jrandom
|
||||
from jaxtyping import Array, Float, PRNGKeyArray
|
||||
|
||||
from .implicit import Boomerang, Implicit
|
||||
from .ssm import Mamba
|
||||
|
||||
__all__ = ["Mamba", "Implicit", "Boomerang", "SequenceClassifier"]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HasSequential(Protocol):
    """Structural type for inner models that expose a step-by-step
    (recurrent) forward pass in addition to their parallel `__call__`."""

    def sequential(self, x): ...
|
||||
|
||||
|
||||
class SequenceClassifier(eqx.Module, Generic[T]):
    """Token-sequence model: an embedding layer followed by an inner
    sequence model (e.g. Mamba, Implicit, Boomerang).

    The inner model maps (seq, d_model) -> (seq, d_model); the classifier
    simply embeds token ids and forwards them.
    """

    embedding: eqx.nn.Embedding  # token id -> d_model vector
    model: eqx.Module  # inner sequence model

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        model_class: T,
        # BUGFIX: was `key=PRNGKeyArray` — the type object was used as the
        # *default value* instead of an annotation, so omitting `key` crashed
        # inside jrandom.split. Now a required, annotated parameter.
        key: PRNGKeyArray,
        **model_kwargs: dict,
    ):
        """Build the embedding and the inner model.

        Parameters
        ----------
        vocab_size:
            Number of distinct token ids.
        d_model:
            Embedding width, also passed to the inner model as ``d_model``.
        model_class:
            Callable (class) constructing the inner model; receives
            ``d_model``, ``key`` and any extra ``model_kwargs``.
        key:
            PRNG key, split between the embedding and the inner model.
        """
        keys = jrandom.split(key, 2)
        self.embedding = eqx.nn.Embedding(vocab_size, d_model, key=keys[0])
        self.model = model_class(d_model=d_model, key=keys[1], **model_kwargs)

    def __call__(self, input_ids: Float[Array, " seq"]) -> Float[Array, "seq d_model"]:
        """Embed token ids and run the inner model's parallel forward pass."""
        x = jax.vmap(self.embedding)(input_ids)
        y = self.model(x)
        return y

    def sequential(
        self: "SequenceClassifier[HasSequential]", input_ids: Float[Array, " seq"]
    ) -> Float[Array, "seq d_model"]:
        """Embed token ids and run the inner model's recurrent forward pass.

        Only valid when the inner model implements `HasSequential`.
        """
        x = jax.vmap(self.embedding)(input_ids)
        return self.model.sequential(x)
|
||||
4
src/felice/networks/implicit/__init__.py
Normal file
4
src/felice/networks/implicit/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import Implicit
|
||||
from .boomerang import Boomerang
|
||||
|
||||
__all__ = ["Implicit", "Boomerang"]
|
||||
210
src/felice/networks/implicit/base.py
Normal file
210
src/felice/networks/implicit/base.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import diffrax as dfx
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrandom
|
||||
import optimistix as optx
|
||||
from diffrax._custom_types import RealScalarLike
|
||||
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
|
||||
|
||||
from ..utils import binary_op
|
||||
|
||||
|
||||
def _depth_step(_, z_prev, args):
    """One depth-iteration of the implicit model, over the whole sequence.

    Used as a diffrax ODETerm vector field: given the previous depth iterate
    ``z_prev`` of shape (seq, d_inner), returns dz for the next iterate.
    The first (time) argument is unused.

    ``args`` is ``(x, model)``: the input sequence (seq, d_model) and the
    `Implicit` module providing compute_lambda / compute_u / f_theta.
    """
    x, model = args

    # ── Eq (6): h_t^{(s)} = Λ(z_t^{(s-1)}, x_t) * h_{t-1}^{(s)} + u(z_t^{(s-1)}, x_t)
    # Linear recurrence over the sequence, evaluated in parallel with an
    # associative scan.
    lambda_vals = jax.vmap(model.compute_lambda)(z_prev, x)
    u_vals = jax.vmap(model.compute_u)(z_prev, x)
    _, h = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)

    # ── Eq (7): z_t^{(s)} = f_θ(z_t^{(s-1)}, h_{t-1}^{(s)}, x_t)
    # z depends on h_{t-1}^{(s)}, i.e. the hidden state of the
    # PREVIOUS token h_{t-1}^{(s)}, NOT the just-computed h_t^{(s)}.
    # Shift h right by one token, padding with zeros at t=0.
    h0 = jnp.zeros((1, model.d_state))
    h = jnp.concatenate([h0, h[:-1]], axis=0)

    dz = jax.vmap(model.f_theta)(z_prev, h, x)

    return dz
|
||||
|
||||
|
||||
# Signature of a norm used for fixed-point / steady-state convergence checks:
# maps a PyTree of arrays to a real scalar (e.g. optimistix's rms_norm).
Normalizer = Callable[[PyTree[Array]], RealScalarLike]
|
||||
|
||||
|
||||
class Implicit(eqx.Module):
    """Implicit (deep-equilibrium style) sequence model.

    The latent z is defined as the steady state of an ODE in "depth":
    dz/ds = f_θ(z, h, x) where h follows the linear recurrence
    h_t = Λ(z_t, x_t) * h_{t-1} + u(z_t, x_t). The parallel `__call__`
    iterates depth over the whole sequence at once; `sequential` runs
    token-by-token with an inner solve per token.
    """

    # Hyperparameters (static so they are not traced as parameters).
    d_model: int = eqx.field(static=True)  # input/output width
    d_state: int = eqx.field(static=True)  # hidden recurrence width
    d_inner: int = eqx.field(static=True)  # latent z width
    dt: float = eqx.field(static=True)  # initial ODE step size
    max_time: float = eqx.field(static=True)  # integration horizon t1
    max_iters: int | None = eqx.field(static=True)  # solver max_steps
    rtol: float = eqx.field(static=True)
    atol: float = eqx.field(static=True)
    norm: Normalizer = eqx.field(static=True)

    adjoint: dfx.AbstractAdjoint
    solver: dfx.AbstractSolver

    f_net: eqx.nn.Linear
    lambda_net: eqx.nn.Linear  # Λ: maps (z, x) → decay factor (diagonal)
    u_net: eqx.nn.Linear  # u: maps (z, x) → input state
    out_net: eqx.nn.Linear  # Output: maps (z, h) → y

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        dt: float = 1.0,
        max_iters: int | None = 100,
        max_time: float = 100,
        solver: Optional[dfx.AbstractSolver] = None,
        adjoint: Optional[dfx.AbstractAdjoint] = None,
        rtol: float = 1e-4,
        atol: float = 1e-3,
        norm: Normalizer = optx.rms_norm,
        *,
        key: PRNGKeyArray,
    ):
        """Initialize hyperparameters and the four linear maps.

        Defaults: Tsit5 solver with an implicit adjoint.
        """
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_inner
        self.dt = dt
        self.max_time = max_time
        self.max_iters = max_iters
        self.solver = solver if solver is not None else dfx.Tsit5()
        self.adjoint = adjoint if adjoint is not None else dfx.ImplicitAdjoint()
        self.rtol = rtol
        self.atol = atol
        # NOTE(review): rtol/atol/norm are stored but the steady-state events
        # below hardcode rtol=1e-4, atol=1e-3, optx.rms_norm — confirm intent.
        self.norm = norm

        keys = jrandom.split(key, 4)

        # f_θ takes the concatenation (z, h, x).
        self.f_net = eqx.nn.Linear(
            d_inner + d_state + d_model,
            d_inner,
            key=keys[0],
        )

        self.lambda_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[1])
        self.u_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[2])
        self.out_net = eqx.nn.Linear(d_inner, d_model, key=keys[3])

    def compute_lambda(
        self, z: Float[Array, " d_inner"], x: Float[Array, " d_model"]
    ) -> Float[Array, " d_state"]:
        """Λ(z, x): per-channel decay/retention factor for the recurrence."""
        zx = jnp.concatenate([z, x], axis=-1)
        # Sigmoid to keep in (0, 1) for stability
        return jax.nn.sigmoid(self.lambda_net(zx))

    def compute_u(
        self, z: Float[Array, " d_inner"], x: Float[Array, " d_model"]
    ) -> Float[Array, " d_state"]:
        """u(z, x): input contribution to the hidden-state recurrence."""
        zx = jnp.concatenate([z, x], axis=-1)
        return self.u_net(zx)

    def f_theta(
        self,
        z: Float[Array, " d_inner"],
        h: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_inner"]:
        """f_θ(z, h, x) → dz: vector field whose zero is the equilibrium z."""
        zhx = jnp.concatenate([z, h, x])
        # silu(W·[z,h,x]) - z has its fixed point where z = silu(W·[z,h,x]).
        dz = jax.nn.silu(self.f_net(zhx)) - z
        return dz

    def get_z(
        self,
        x: Float[Array, "seq d_model"],
        t0: float = 0,
        t1: float | None = None,
        num_points: int = 100,
    ) -> Float[Array, "seq d_model"]:
        """Integrate the depth ODE and return the saved trajectory.

        NOTE(review): despite the annotation, this returns the pair
        ``(sol.ts, sol.ys)``, not a single array. Also, the ``t0`` parameter
        only shifts the SaveAt grid — the solver itself always starts at
        t0=0.0. Confirm whether both are intended.
        """
        seq_len, _ = x.shape
        t1 = t1 if t1 is not None else self.max_time

        z0 = jnp.zeros((seq_len, self.d_inner))
        terms = dfx.ODETerm(_depth_step)
        sol = dfx.diffeqsolve(
            terms,
            self.solver,
            t0=0.0,
            t1=t1,
            dt0=self.dt,
            y0=z0,
            args=(x, self),
            max_steps=self.max_iters,
            saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, num_points)),
        )

        return sol.ts, sol.ys

    def sequential(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Recurrent forward pass: one equilibrium solve per token.

        For each token, z is relaxed to steady state (given the carried
        hidden state h), then h is updated once with Λ and u.
        """

        def scan_fn(h, x_t):

            def depth_step(_, z_prev, args):
                # Per-token vector field; h and x_t are fixed during the solve.
                h, x, model = args
                z_s = model.f_theta(z_prev, h, x)

                return z_s

            z0 = jnp.zeros((self.d_inner,))
            # Stop early once dz is (approximately) zero.
            # NOTE(review): tolerances are hardcoded rather than self.rtol/atol.
            cond_fn = dfx.steady_state_event(rtol=1e-4, atol=1e-3, norm=optx.rms_norm)
            event = dfx.Event(cond_fn)
            terms = dfx.ODETerm(depth_step)
            sol = dfx.diffeqsolve(
                terms,
                self.solver,
                t0=0.0,
                t1=self.max_time,
                dt0=self.dt,
                y0=z0,
                args=(h, x_t, self),
                max_steps=self.max_iters,
                event=event,
                adjoint=self.adjoint,
            )
            # Last saved state = (approximate) equilibrium.
            z_star = sol.ys[-1]

            # Compute new hidden state
            lambda_val = self.compute_lambda(z_star, x_t)  # (d_state,)
            u_val = self.compute_u(z_star, x_t)  # (d_state,)

            h_new = lambda_val * h + u_val

            y = self.out_net(z_star)

            return h_new, y

        h_init = jnp.zeros(self.d_state)
        _, y = jax.lax.scan(scan_fn, h_init, x)

        return y

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Parallel forward pass: one equilibrium solve over the whole sequence.

        Relaxes the full (seq, d_inner) latent to steady state via the shared
        `_depth_step` field, then projects to the output space.
        """
        seq_len, _ = x.shape

        z0 = jnp.zeros((seq_len, self.d_inner))
        # NOTE(review): tolerances are hardcoded rather than self.rtol/atol.
        cond_fn = dfx.steady_state_event(rtol=1e-4, atol=1e-3, norm=optx.rms_norm)
        event = dfx.Event(cond_fn)
        terms = dfx.ODETerm(_depth_step)
        sol = dfx.diffeqsolve(
            terms,
            self.solver,
            t0=0.0,
            t1=self.max_time,
            dt0=self.dt,
            y0=z0,
            args=(x, self),
            max_steps=self.max_iters,
            event=event,
            adjoint=self.adjoint,
        )
        z_star = sol.ys[-1]
        y_star = jax.vmap(self.out_net)(z_star)

        return y_star
|
||||
346
src/felice/networks/implicit/boomerang.py
Normal file
346
src/felice/networks/implicit/boomerang.py
Normal file
@@ -0,0 +1,346 @@
|
||||
from typing import Optional
|
||||
|
||||
import diffrax as dfx
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrandom
|
||||
import optimistix as optx
|
||||
from jaxtyping import Array, Float, PRNGKeyArray
|
||||
|
||||
from ..utils import binary_op
|
||||
from .base import Implicit, Normalizer
|
||||
|
||||
|
||||
class ImplicitBoomerang(eqx.Module):
    """Implicit sequence model with hand-rolled fixed-point iteration and
    Boomerang-style two-compartment (u, v) circuit dynamics in f_theta.

    Unlike `Implicit` (diffrax-based), equilibria are found by explicit
    iteration: `__call__` iterates depth over the whole sequence in parallel,
    `sequential` solves one fixed point per token.
    """

    # Hyperparameters (static so they are not traced as parameters).
    d_model: int = eqx.field(static=True)  # input/output width
    d_state: int = eqx.field(static=True)  # hidden recurrence width
    d_inner: int = eqx.field(static=True)  # latent z width (split into u, v halves)
    max_iters: int = eqx.field(static=True)  # fixed-point iteration cap
    tol: float = eqx.field(static=True)
    dt: float = eqx.field(static=True)  # Euler step for the circuit dynamics
    with_thr: bool = eqx.field(static=True)  # use tanh threshold coupling vs constant 1
    debug: bool = eqx.field(static=True)  # also return the iterate trajectory

    f_net: eqx.nn.Linear
    lambda_net: eqx.nn.Linear  # Λ: maps (z, x) → decay factor (diagonal)
    u_net: eqx.nn.Linear  # u: maps (z, x) → input state
    out_net: eqx.nn.Linear  # Output: maps (z, h) → y

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        max_iters: int = 20,
        tol: float = 1e-5,
        dt: float = 1e-3,
        with_thr: bool = True,
        debug: bool = False,
        *,
        key: PRNGKeyArray,
    ):
        """Initialize hyperparameters and the four linear maps.

        NOTE(review): `tol` is stored but the convergence checks below use a
        hardcoded relative tolerance of 0.001 — confirm intent.
        """
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_inner
        self.max_iters = max_iters
        self.tol = tol
        self.dt = dt
        self.with_thr = with_thr
        self.debug = debug

        # NOTE(review): 6 subkeys are created but only keys[1..4] are used.
        keys = jrandom.split(key, 6)

        # f_theta conditions on (h, x) only — hence d_state + d_model input.
        self.f_net = eqx.nn.Linear(
            d_state + d_model,
            d_inner,
            key=keys[1],
        )

        self.lambda_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[2])
        self.u_net = eqx.nn.Linear(d_inner + d_model, d_state, key=keys[3])
        self.out_net = eqx.nn.Linear(d_inner + d_state, d_model, key=keys[4])

    def compute_lambda(
        self, z: Float[Array, " d_inner"], x: Float[Array, " d_model"]
    ) -> Float[Array, " d_state"]:
        """Λ(z, x) for a batch of positions.

        NOTE(review): unlike `Implicit.compute_lambda`, this vmaps internally,
        so z and x must carry a leading sequence axis despite the annotations.
        """
        zx = jnp.concatenate([z, x], axis=-1)
        # Sigmoid to keep in (0, 1) for stability
        return jax.nn.sigmoid(jax.vmap(self.lambda_net)(zx))

    def compute_u(
        self, z: Float[Array, " d_inner"], x: Float[Array, " d_model"]
    ) -> Float[Array, " d_state"]:
        """u(z, x) for a batch of positions (vmapped internally; see above)."""
        zx = jnp.concatenate([z, x], axis=-1)
        return jax.vmap(self.u_net)(zx)

    def f_theta(
        self,
        z: Float[Array, " d_inner"],
        h: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_inner"]:
        """One Euler step of the Boomerang circuit dynamics: z ← z + dt·dz.

        z is split into two coupled compartments (u, v); f_net(h, x) provides
        the learned modulations alpha and sigma. Note this returns the UPDATED
        z (a fixed-point map), not the derivative as in `Implicit.f_theta`.
        """
        hx = jnp.concatenate([h, x])

        # Two learned modulation signals, each of width d_inner / 2.
        alpha_sigma = jnp.split(self.f_net(hx), 2)

        # Fixed circuit constants; presumably device-derived — TODO confirm.
        rho = 30.0
        alpha = jax.nn.sigmoid(alpha_sigma[0])
        beta = 15.6
        gamma = 0.26
        sigma = jax.nn.sigmoid(alpha_sigma[1])
        u, v = jnp.split(z, 2)

        def true_fn(u, v):
            # Smooth threshold coupling between the two compartments.
            return jax.nn.tanh(rho * (v - u))

        def false_fn(u, v):
            # Threshold disabled: constant unit coupling.
            return jnp.ones_like(u)

        thresh = jax.lax.cond(self.with_thr, true_fn, false_fn, u, v)
        du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * thresh
        dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * thresh
        dz = jnp.concat([du, dv])
        z = z + self.dt * dz

        # Alternative physically-parameterized dynamics kept for reference:
        # kk = 0.68  # fixed by tech
        # Ut = 0.025  # temperature dependent
        # I_r0 = 0.9
        # x1, x2 = jnp.split(z, 2)

        # alpha = 0.000129 + (0.0129 - 0.000129) * jax.nn.sigmoid(
        #     alpha_sigma[0]
        # )  # The circuit will get directly the current
        # Ia = jax.nn.tanh(alpha_sigma[1]) * 0.6

        # x3 = 1  # (np.tanh(20 * (x2 - x1))) #smoother transition
        # dx1 = 2.3 * (
        #     1 - (alpha * jnp.exp(kk * x2 / Ut)) * (1 - I_r0 * (0.3 - x1)) + Ia * x3
        # )
        # dx2 = 2.3 * (
        #     -1 + (alpha * jnp.exp(kk * x1 / Ut)) * (1 + I_r0 * (0.3 - x2)) + Ia * x3
        # )
        # dz = jnp.concat([dx1, dx2])
        # z = z + self.dt * dz

        return z

    def get_z(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Run max_iters depth iterations and return ALL iterates.

        Returns the stacked trajectory of z over iterations (no convergence
        check) — useful for inspecting relaxation behaviour.
        """
        seq_len, _ = x.shape

        def body_fn(state, _):
            z = state

            # One depth iteration: recompute the hidden recurrence from the
            # current z, shift it by one token, then apply the circuit map.
            lambda_vals = self.compute_lambda(z, x)
            u_vals = self.compute_u(z, x)

            _, h = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)
            h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)

            z_new = jax.vmap(self.f_theta)(z, h_prev, x)

            return z_new, z_new

        # Initialize
        z_init = jnp.zeros((seq_len, self.d_inner))

        # Run fixed-point iteration
        _, z = jax.lax.scan(
            body_fn,
            z_init,
            None,
            length=self.max_iters,
        )

        return z

    def sequential(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Recurrent forward pass: one fixed-point solve per token.

        NOTE(review): when debug=False, the `return h_new, (y, z_star_debug)`
        below references `z_star_debug`, which is only defined in the debug
        branch — this path will raise NameError. Confirm and fix.
        """

        def scan_fn(h, x_t):
            z_init = jnp.zeros((self.d_inner,))
            # inf sentinel guarantees the first convergence test fails.
            z_prev_init = jnp.full((self.d_inner,), jnp.inf)

            def cond_fn(state):
                i, z, z_prev = state
                # Relative-change convergence test (hardcoded 1e-3).
                converged = (
                    jnp.linalg.norm(z - z_prev) / jnp.linalg.norm(z_prev) < 0.001
                )
                return (i < self.max_iters) & ~converged

            def body_fn(state):
                i, z, _ = state
                z_new = self.f_theta(z, h, x_t)
                return (i + 1, z_new, z)

            # Run fixed-point iteration
            if self.debug:
                # Debug: always run max_iters and keep the full trajectory.

                def scan_fn(state, _):
                    new_state = body_fn(state)
                    return new_state, new_state[1]

                _, z_star_debug = jax.lax.scan(
                    scan_fn, (0, z_init, z_prev_init), None, length=self.max_iters
                )
                z_star = z_star_debug[-1]
            else:
                # Bounded while-loop with early exit on convergence.
                _, z_star, _ = eqx.internal.while_loop(
                    cond_fn,
                    body_fn,
                    (0, z_init, z_prev_init),
                    max_steps=self.max_iters,
                    kind="bounded",
                )

            # Compute new hidden state
            # (compute_lambda/compute_u vmap internally, hence the added axis)
            lambda_val = self.compute_lambda(
                z_star[jnp.newaxis, :], x_t[jnp.newaxis, :]
            )[0]  # (d_state,)
            u_val = self.compute_u(z_star[jnp.newaxis, :], x_t[jnp.newaxis, :])[
                0
            ]  # (d_state,)

            h_new = lambda_val * h + u_val

            zh = jnp.concatenate([z_star, h_new])
            y = self.out_net(zh)

            return h_new, (y, z_star_debug)

        h_init = jnp.zeros(self.d_state)
        _, (y, z_star) = jax.lax.scan(scan_fn, h_init, x)

        if self.debug:
            return y, z_star
        else:
            return y

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Parallel forward pass: fixed-point iteration over the whole sequence.

        Each iteration recomputes the (shifted) hidden recurrence from the
        current z and applies the circuit map at every position at once.
        When debug=True, returns (y, trajectory of z iterates).
        """
        seq_len, _ = x.shape

        def cond_fn(state):
            i, z, z_prev = state
            # Relative-change convergence test (hardcoded 1e-3).
            converged = jnp.linalg.norm(z - z_prev) / jnp.linalg.norm(z_prev) < 0.001
            return (i < self.max_iters) & ~converged

        def body_fn(state):
            i, z, _ = state

            lambda_vals = self.compute_lambda(z, x)
            u_vals = self.compute_u(z, x)

            _, h = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)
            # Shift h right by one token (zeros at t=0): z_t conditions on
            # the PREVIOUS token's hidden state.
            h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)

            z_new = jax.vmap(self.f_theta)(z, h_prev, x)

            return (i + 1, z_new, z)

        # Initialize
        z_init = jnp.zeros((seq_len, self.d_inner))
        z_prev_init = jnp.full((seq_len, self.d_inner), jnp.inf)

        # Run fixed-point iteration
        if self.debug:
            # Debug: always run max_iters and keep the full trajectory.

            def scan_fn(state, _):
                new_state = body_fn(state)
                return new_state, new_state[1]

            _, z_star_debug = jax.lax.scan(
                scan_fn, (0, z_init, z_prev_init), None, length=self.max_iters
            )
            z_star = z_star_debug[-1]
        else:
            # Bounded while-loop with early exit on convergence.
            _, z_star, _ = eqx.internal.while_loop(
                cond_fn,
                body_fn,
                (0, z_init, z_prev_init),
                max_steps=self.max_iters,
                kind="bounded",
            )

        # Final hidden recurrence at the equilibrium z; note the UNSHIFTED
        # h_star is used for the output (h_t, not h_{t-1}).
        lambda_vals = self.compute_lambda(z_star, x)
        u_vals = self.compute_u(z_star, x)
        _, h_star = jax.lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)

        zh = jnp.concatenate([z_star, h_star], axis=-1)
        y_star = jax.vmap(self.out_net)(zh)

        if self.debug:
            return y_star, z_star_debug
        else:
            return y_star
|
||||
|
||||
|
||||
class Boomerang(Implicit):
    """Diffrax-based implicit model whose vector field follows the
    physically-parameterized Boomerang circuit dynamics.

    Inherits the equilibrium-solve machinery (`__call__`, `sequential`,
    `get_z`) from `Implicit` and overrides `f_net` / `f_theta` with a
    two-compartment (x1, x2) circuit model conditioned on (h, x) only.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        dt: float = 1.0,
        max_iters: int | None = 100,
        max_time: float = 100,
        solver: Optional[dfx.AbstractSolver] = None,
        adjoint: Optional[dfx.AbstractAdjoint] = None,
        rtol: float = 1e-4,
        atol: float = 1e-3,
        norm: Normalizer = optx.rms_norm,
        *,
        key: PRNGKeyArray,
    ):
        """Forward all hyperparameters to `Implicit`, then replace f_net.

        jrandom.split(key) yields 2 subkeys: keys[0] for the parent init,
        keys[1] for the replacement f_net.
        """
        keys = jrandom.split(key)
        super(Boomerang, self).__init__(
            d_model=d_model,
            d_state=d_state,
            d_inner=d_inner,
            dt=dt,
            max_iters=max_iters,
            max_time=max_time,
            solver=solver,
            adjoint=adjoint,
            rtol=rtol,
            atol=atol,
            norm=norm,
            key=keys[0],
        )

        # Override the parent's f_net: this f_theta conditions only on
        # (h, x), not (z, h, x), so the input width differs from the base.
        self.f_net = eqx.nn.Linear(
            d_state + d_model,
            d_inner,
            key=keys[1],
        )

    def f_theta(
        self,
        z: Float[Array, " d_inner"],
        h: Float[Array, " d_state"],
        x: Float[Array, " d_model"],
    ) -> Float[Array, " d_inner"]:
        """Boomerang circuit vector field: returns dz (not an updated z),
        matching the ODETerm contract of the `Implicit` base class.

        z is split into two coupled compartments (x1, x2); f_net(h, x)
        provides the learned current modulations alpha and Ia.
        """
        hx = jnp.concatenate([h, x])

        # Two learned modulation signals, each of width d_inner / 2.
        alpha_sigma = jnp.split(self.f_net(hx), 2)

        kk = 0.68  # fixed by tech
        Ut = 0.025  # temperature dependent
        I_r0 = 0.9
        x1, x2 = jnp.split(z, 2)

        # alpha constrained to (0.000129, 0.0129) via sigmoid rescaling.
        alpha = 0.000129 + (0.0129 - 0.000129) * jax.nn.sigmoid(
            alpha_sigma[0]
        )  # The circuit will get directly the current
        Ia = jax.nn.tanh(alpha_sigma[1]) * 0.6

        x3 = 1  # (np.tanh(20 * (x2 - x1))) #smoother transition
        dx1 = 2.3 * (
            1 - (alpha * jnp.exp(kk * x2 / Ut)) * (1 - I_r0 * (0.3 - x1)) + Ia * x3
        )
        dx2 = 2.3 * (
            -1 + (alpha * jnp.exp(kk * x1 / Ut)) * (1 + I_r0 * (0.3 - x2)) + Ia * x3
        )
        dz = jnp.concat([dx1, dx2])

        return dz
|
||||
420
src/felice/networks/implicit_snn.py
Normal file
420
src/felice/networks/implicit_snn.py
Normal file
@@ -0,0 +1,420 @@
|
||||
from typing import Tuple
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrandom
|
||||
from jax import flatten_util, lax
|
||||
from jaxtyping import Array, Float, PRNGKeyArray
|
||||
|
||||
|
||||
@jax.jit
def binary_op(
    left: "Tuple[Array, Array]",
    right: "Tuple[Array, Array]",
) -> "Tuple[Array, Array]":
    """Associative combiner for the linear-recurrence parallel scan.

    Composes two (decay, contribution) pairs:
    (a1, b1) ∘ (a2, b2) = (a1·a2, a2·b1 + b2)
    """
    decay_l, contrib_l = left
    decay_r, contrib_r = right
    combined_decay = decay_l * decay_r
    combined_contrib = decay_r * contrib_l + contrib_r
    return (combined_decay, combined_contrib)
|
||||
|
||||
|
||||
class ImplicitSNN(eqx.Module):
|
||||
d_model: int
|
||||
d_state: int
|
||||
d_latent: int
|
||||
max_iters: int
|
||||
tol: float
|
||||
|
||||
embedding: eqx.nn.Embedding
|
||||
f_net: eqx.Module # f_θ: maps (z, h, x) → z
|
||||
f_net2: eqx.Module # f_θ: maps (z, h, x) → z
|
||||
lambda_net: eqx.Module # Λ: maps (z, x) → decay factor (diagonal)
|
||||
u_net: eqx.Module # u: maps (z, x) → input contribution
|
||||
out_net: eqx.nn.Linear # Output: maps (z, h) → y
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
d_model: int,
|
||||
d_state: int = 16,
|
||||
d_latent: int = 32,
|
||||
max_iters: int = 20,
|
||||
tol: float = 1e-5,
|
||||
*,
|
||||
key: PRNGKeyArray,
|
||||
):
|
||||
self.d_model = d_model
|
||||
self.d_state = d_state
|
||||
self.d_latent = d_latent
|
||||
self.max_iters = max_iters
|
||||
self.tol = tol
|
||||
|
||||
keys = jrandom.split(key, 6)
|
||||
|
||||
self.embedding = eqx.nn.Embedding(vocab_size, d_model, key=keys[0])
|
||||
|
||||
# f_θ(z, h, x) → z
|
||||
# Input: z (d_latent) + h (d_state) + x (d_model)
|
||||
self.f_net = eqx.nn.Linear(
|
||||
d_state + d_model,
|
||||
d_latent // 2,
|
||||
key=keys[1],
|
||||
)
|
||||
|
||||
self.f_net2 = eqx.nn.Linear(
|
||||
d_state + d_model,
|
||||
d_latent // 2,
|
||||
key=keys[5],
|
||||
)
|
||||
|
||||
# Λ(z, x) → (d_state,) decay factors
|
||||
self.lambda_net = eqx.nn.Linear(d_latent + d_model, d_state, key=keys[2])
|
||||
|
||||
# u(z, x) → (d_state,) input contribution
|
||||
self.u_net = eqx.nn.Linear(d_latent + d_model, d_state, key=keys[3])
|
||||
|
||||
# Output projection
|
||||
self.out_net = eqx.nn.Linear(d_latent + d_state, d_model, key=keys[4])
|
||||
|
||||
def compute_lambda(
|
||||
self, z: Float[Array, " d_latent"], x: Float[Array, " d_model"]
|
||||
) -> Float[Array, " d_state"]:
|
||||
"""Λ(z, x) - the decay/retention factor."""
|
||||
zx = jnp.concatenate([z, x], axis=-1)
|
||||
# Sigmoid to keep in (0, 1) for stability
|
||||
return jax.nn.sigmoid(jax.vmap(self.lambda_net)(zx))
|
||||
|
||||
def compute_u(
|
||||
self, z: Float[Array, " d_latent"], x: Float[Array, " d_model"]
|
||||
) -> Float[Array, " d_state"]:
|
||||
"""u(z, x) - the input contribution."""
|
||||
zx = jnp.concatenate([z, x], axis=-1)
|
||||
return jax.vmap(self.u_net)(zx)
|
||||
|
||||
# TODO: Change for diffrax
def f_theta(
    self,
    z: Float[Array, " d_latent"],
    h: Float[Array, " d_state"],
    x: Float[Array, " d_model"],
) -> Float[Array, " d_latent"]:
    """f_θ(z, h, x) → z — one damped Euler step of the implicit function.

    Splits z into two halves (u, v) and advances them through a
    predator-prey-style vector field, then returns z + 0.001 * dz.
    Note the learned gates (alpha, sigma) depend only on (h, x), not on z.
    """
    hx = jnp.concatenate([h, x])

    # beta/gamma/rho mirror the Boomerang neuron defaults (15.6 / 0.26 / 30);
    # alpha and sigma are learned, sigmoid-gated into (0, 1).
    rho = 30.0
    alpha = jax.nn.sigmoid(self.f_net(hx))
    beta = 15.6
    gamma = 0.26
    sigma = jax.nn.sigmoid(self.f_net2(hx))
    u, v = jnp.split(z, 2)

    # Smooth threshold/switching term between the two halves.
    thresh = jax.nn.tanh(rho * (v - u))
    du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * thresh
    dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * thresh
    dz = jnp.concat([du, dv])
    # Hard-coded Euler step size; keeps the fixed-point map contractive.
    z = z + 0.001 * dz

    return z
||||
# ==================== Parallel mode (training) ====================

def parallel_scan_h(
    self,
    lambda_vals: Float[Array, "seq d_state"],
    u_vals: Float[Array, "seq d_state"],
) -> Float[Array, "seq d_state"]:
    """
    Compute the h sequence via parallel scan.

    h_t = λ_t · h_{t-1} + u_t

    This is the standard SSM recurrence, parallelizable via associative
    scan with the module-level ``binary_op`` combine
    ((a1, b1), (a2, b2)) → (a1·a2, a2·b1 + b2).

    Args:
        lambda_vals: Per-step decay factors λ_t.
        u_vals: Per-step input contributions u_t.

    Returns:
        The full hidden-state trajectory h_1..h_seq.
    """

    # Elements: (λ_t, u_t)
    # After scan: (cumulative λ, h_t); the cumulative product is discarded.
    _, h = lax.associative_scan(binary_op, (lambda_vals, u_vals), axis=0)

    return h
||||
def debug(self, x):
    """Run exactly ``max_iters`` fixed-point sweeps and return the full
    iterate history.

    Unlike ``forward_parallel`` there is no early-stop convergence check;
    the returned array stacks every z iterate, shape
    (max_iters, seq, d_latent), useful for inspecting convergence.
    """
    x = jax.vmap(self.embedding)(x)  # (seq, d_model)
    seq_len = x.shape[0]

    def body_fn(state, _):
        z, _ = state

        # Compute λ, u from the current z estimate (parallel over positions)
        lambda_vals = self.compute_lambda(z, x)
        u_vals = self.compute_u(z, x)

        # Parallel scan for h
        h = self.parallel_scan_h(lambda_vals, u_vals)

        # Shift h right so position t sees h_{t-1} (h_0 = 0)
        h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)

        # Update z
        z_new = jax.vmap(self.f_theta)(z, h_prev, x)

        return (z_new, z), z_new

    # Initialize: carry also tracks the previous iterate (unused here,
    # kept for parity with forward_parallel's convergence test).
    z_init = jnp.zeros((seq_len, self.d_latent))
    z_prev_init = jnp.full((seq_len, self.d_latent), jnp.inf)

    _, z_star = jax.lax.scan(
        body_fn, (z_init, z_prev_init), None, length=self.max_iters
    )

    return z_star
||||
def forward_parallel(
    self,
    x: Float[Array, " seq"],
) -> Float[Array, "seq d_model"]:
    """
    Parallel forward pass for training.

    Algorithm:
    1. Initialize z for all positions
    2. Iterate until convergence:
       a. Compute λ, u from current z (parallel over positions)
       b. Compute h via parallel scan
       c. Update z = f_θ(z, h_shifted, x) (parallel over positions)
    3. Compute output from final z, h

    Note: h_shifted means h_{t-1} for position t, so we shift h right.

    Args:
        x: Integer token ids, shape (seq,).

    Returns:
        Output features, shape (seq, d_model).
    """
    x = jax.vmap(self.embedding)(x)  # (seq, d_model)
    seq_len = x.shape[0]

    def cond_fn(state):
        i, z, z_prev = state
        # Global (all-positions) convergence test on the iterate delta.
        converged = jnp.linalg.norm(z - z_prev) < self.tol
        return (i < self.max_iters) & ~converged

    def body_fn(state):
        i, z, _ = state

        # Compute λ, u
        lambda_vals = self.compute_lambda(z, x)
        u_vals = self.compute_u(z, x)

        # Parallel scan for h
        h = self.parallel_scan_h(lambda_vals, u_vals)

        # Shift h right so position t sees h_{t-1} (h_0 = 0)
        h_prev = jnp.concatenate([jnp.zeros((1, self.d_state)), h[:-1]], axis=0)

        # Update z
        z_new = jax.vmap(self.f_theta)(z, h_prev, x)

        return (i + 1, z_new, z)

    # Initialize; z_prev starts at +inf so cond_fn's first test never
    # reports convergence.
    z_init = jnp.zeros((seq_len, self.d_latent))
    z_prev_init = jnp.full((seq_len, self.d_latent), jnp.inf)

    # Run fixed-point iteration (bounded, differentiable while-loop).
    _, z_star, _ = eqx.internal.while_loop(
        cond_fn,
        body_fn,
        (0, z_init, z_prev_init),
        max_steps=self.max_iters,
        kind="bounded",
    )

    # Final h computation with converged z
    lambda_vals = self.compute_lambda(z_star, x)
    u_vals = self.compute_u(z_star, x)
    h_star = self.parallel_scan_h(lambda_vals, u_vals)

    # Output
    zh = jnp.concatenate([z_star, h_star], axis=-1)
    y_star = jax.vmap(self.out_net)(zh)

    return y_star
||||
# ==================== Sequential mode (inference) ====================

def find_fixed_point(
    self,
    h_prev: Float[Array, " d_state"],
    x: Float[Array, " d_model"],
) -> Float[Array, " d_latent"]:
    """
    Find z* such that z* = f_θ(z*, h_prev, x)
    using fixed-point iteration.

    Args:
        h_prev: Previous hidden state h_{t-1}.
        x: Current (embedded) input token.

    Returns:
        The converged latent z*, or the iterate after ``max_iters`` steps
        if the tolerance was not reached.
    """

    def cond_fn(state):
        i, z, z_prev = state
        converged = jnp.linalg.norm(z - z_prev) < self.tol
        return (i < self.max_iters) & ~converged

    def body_fn(state):
        i, z, _ = state
        z_new = self.f_theta(z, h_prev, x)
        return (i + 1, z_new, z)

    # Initialize z; z_prev = +inf guarantees at least one iteration.
    z_init = jnp.zeros(self.d_latent)
    z_prev_init = jnp.full(self.d_latent, jnp.inf)

    _, z_star, _ = eqx.internal.while_loop(
        cond_fn,
        body_fn,
        (0, z_init, z_prev_init),
        max_steps=self.max_iters,
        kind="bounded",
    )

    return z_star
||||
def step(
    self,
    h_prev: Float[Array, " d_state"],
    x: Float[Array, " d_model"],
) -> Tuple[Float[Array, " d_state"], Float[Array, " d_model"]]:
    """
    Single step of the implicit SSM.

    1. Find z* via fixed-point iteration
    2. Compute h* = Λ(z*, x) · h_prev + u(z*, x)
    3. Compute output y

    Args:
        h_prev: Previous hidden state h_{t-1}.
        x: Current (embedded) input token.

    Returns:
        Tuple (h_new, y) — the updated hidden state and the output for
        this position. Shape-compatible with ``lax.scan``.
    """
    # Find fixed point z*
    z_star = self.find_fixed_point(h_prev, x)

    # Compute new hidden state; compute_lambda/compute_u vmap over a
    # leading axis, so wrap the single step in a length-1 batch and
    # unwrap with [0].
    lambda_val = self.compute_lambda(z_star[jnp.newaxis, :], x[jnp.newaxis, :])[
        0
    ]  # (d_state,)
    u_val = self.compute_u(z_star[jnp.newaxis, :], x[jnp.newaxis, :])[
        0
    ]  # (d_state,)

    h_new = lambda_val * h_prev + u_val

    # Compute output
    zh = jnp.concatenate([z_star, h_new])
    y = self.out_net(zh)

    return h_new, y
||||
def forward_sequential(
    self,
    x: Float[Array, " seq"],
) -> Float[Array, "seq d_model"]:
    """
    Sequential forward pass for inference.

    Embeds the tokens, then folds ``step`` over the sequence with
    ``lax.scan``, carrying the hidden state one token at a time.
    """
    embedded = jax.vmap(self.embedding)(x)

    initial_state = jnp.zeros(self.d_state)
    _, outputs = lax.scan(
        lambda carry, token: self.step(carry, token), initial_state, embedded
    )
    return outputs
||||
def __call__(
    self,
    x: Float[Array, "seq d_model"],
    mode: str = "parallel",
) -> Float[Array, "seq d_model"]:
    """
    Forward pass over a token sequence.

    Dispatches to the fixed-point parallel path when ``mode == "parallel"``;
    any other value selects the token-by-token sequential path.
    """
    forward = self.forward_parallel if mode == "parallel" else self.forward_sequential
    return forward(x)
||||
if __name__ == "__main__":
    import time

    key = jrandom.key(42)
    keys = jrandom.split(key, 2)

    d_model, d_state, d_latent = 64, 16, 32
    seq_len = 1080

    model = ImplicitSNN(
        vocab_size=16,
        d_model=d_model,
        d_state=d_state,
        d_latent=d_latent,
        max_iters=10,
        key=keys[0],
    )

    # Test input. NOTE(review): maxval=15 is exclusive, so tokens span
    # 0..14 while vocab_size is 16 — presumably intentional headroom.
    x = jrandom.randint(keys[1], (seq_len,), 0, 15)

    # Test parallel mode
    print("Testing parallel mode...")
    y_parallel = model(x, mode="parallel")
    print(f" Input: {x.shape}, Output: {y_parallel.shape}")

    # Test sequential mode
    print("\nTesting sequential mode...")
    y_sequential = model(x, mode="sequential")
    print(f" Input: {x.shape}, Output: {y_sequential.shape}")

    # Compare outputs (should be close but not exact due to different convergence paths)
    diff = jnp.linalg.norm(y_parallel - y_sequential) / jnp.linalg.norm(y_sequential)
    print(f"\nRelative difference: {diff:.6f}")

    # Test gradients in parallel mode
    def loss_fn(model, x, mode):
        return jnp.mean(model(x, mode=mode) ** 2)

    # NOTE(review): flatten_util is presumably `jax.flatten_util` imported
    # at the top of this file (not visible here) — confirm.
    print("\nTesting gradients (parallel mode)...")
    grads_par = eqx.filter_grad(loss_fn)(model, x, "parallel")
    grads_par = flatten_util.ravel_pytree(grads_par)[0]
    print("\nTesting gradients (sequential mode)...")
    grads_seq = eqx.filter_grad(loss_fn)(model, x, "sequential")
    grads_seq = flatten_util.ravel_pytree(grads_seq)[0]

    grad_diff = jnp.linalg.norm(grads_par - grads_seq) / jnp.linalg.norm(grads_seq)
    print(" Gradients computed successfully!")
    print(f"\nRelative difference: {grad_diff:.6f}")

    # Benchmark
    print("\nBenchmarking...")

    # JIT compile
    forward_parallel_jit = eqx.filter_jit(lambda m, x: m(x, mode="parallel"))
    forward_sequential_jit = eqx.filter_jit(lambda m, x: m(x, mode="sequential"))

    # Warmup (triggers compilation so the timed loops measure execution only)
    _ = forward_parallel_jit(model, x)
    _ = forward_sequential_jit(model, x)

    # Time parallel
    start = time.time()
    for _ in range(100):
        _ = forward_parallel_jit(model, x).block_until_ready()
    parallel_time = (time.time() - start) / 100

    # Time sequential
    start = time.time()
    for _ in range(100):
        _ = forward_sequential_jit(model, x).block_until_ready()
    sequential_time = (time.time() - start) / 100

    print(f" Parallel: {parallel_time * 1000:.3f} ms")
    print(f" Sequential: {sequential_time * 1000:.3f} ms")
    print(f" Speedup: {sequential_time / parallel_time:.2f}x")
137
src/felice/networks/ssm.py
Normal file
137
src/felice/networks/ssm.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import math
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrandom
|
||||
from einops import repeat
|
||||
from jaxtyping import Array, Float, PRNGKeyArray
|
||||
|
||||
from .utils import binary_op
|
||||
|
||||
|
||||
class SelectiveSSM(eqx.Module):
    """Selective (input-dependent) state-space scan, S6/Mamba-style.

    Projects the input to per-step (dt, B, C), discretizes the diagonal
    state matrix A with dt, runs the linear recurrence
    h_t = dA_t · h_{t-1} + dB_x_t via an associative scan, and adds a
    learned skip connection D.
    """

    d_model: int = eqx.field(static=True)
    d_state: int = eqx.field(static=True)
    dt_rank: int = eqx.field(static=True)

    x_proj: eqx.nn.Linear  # x → (dt, B, C), unbiased
    dt_proj: eqx.nn.Linear  # low-rank dt → per-channel dt

    # NOTE(review): annotated "d_inner" but sized with d_model below —
    # the caller (Mamba) passes its d_inner as this module's d_model.
    A_log: Float[Array, "d_inner d_state"]
    D: Float[Array, " d_inner"]

    def __init__(
        self,
        d_model: int,
        dt_rank: int,
        d_state: int = 16,
        *,
        key: PRNGKeyArray,  # FIX: was `key=PRNGKeyArray`, which made the *type object* the default value
    ):
        """Initialize the selective SSM.

        Args:
            d_model: Feature size of the input/output (the inner width).
            dt_rank: Rank of the low-dimensional dt projection.
            d_state: State dimension N of the diagonal SSM.
            key: PRNG key for parameter initialization (required).
        """
        self.d_model = d_model
        self.d_state = d_state
        self.dt_rank = dt_rank

        # Split into 5 even though only keys[3] and keys[4] are consumed,
        # to keep the original initialization streams reproducible.
        keys = jrandom.split(key, 5)

        self.x_proj = eqx.nn.Linear(
            self.d_model, self.dt_rank + self.d_state * 2, use_bias=False, key=keys[3]
        )
        self.dt_proj = eqx.nn.Linear(
            self.dt_rank, self.d_model, use_bias=True, key=keys[4]
        )

        # S4D-Real initialization: A_n = -n for n = 1..d_state, per channel.
        A = repeat(
            jnp.arange(1, d_state + 1, dtype=jnp.float32),
            "n -> d n",
            d=self.d_model,
        )
        self.A_log = jnp.log(A)

        self.D = jnp.ones(self.d_model)

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Apply the selective scan along the sequence axis.

        Args:
            x: Input features, shape (seq, d_model).

        Returns:
            Output features, shape (seq, d_model).
        """
        # Per-step (dt, B, C) from the input — the "selective" part.
        x_dbl = jax.vmap(self.x_proj)(x)
        dt, B, C = jnp.split(
            x_dbl, [self.dt_rank, self.dt_rank + self.d_state], axis=-1
        )

        # Softplus keeps the step size positive.
        dt = jax.vmap(self.dt_proj)(dt)
        dt = jax.nn.softplus(dt)

        # Zero-order-hold style discretization of the diagonal A.
        A = -jnp.exp(self.A_log)  # (d_inner, d_state)
        dA = jnp.exp(jnp.einsum("ld,dn->ldn", dt, A))
        dB_x = jnp.einsum("ld,ln,ld->ldn", dt, B, x)

        # Linear recurrence h_t = dA_t * h_{t-1} + dB_x_t via parallel scan.
        _, h = jax.lax.associative_scan(binary_op, (dA, dB_x), axis=0)
        y = jnp.einsum("ldn,ln->ld", h, C)

        # Learned skip connection.
        y = y + x * self.D

        return y
||||
|
||||
class Mamba(eqx.Module):
    """Mamba block: in-projection → gated (causal conv + selective SSM) →
    out-projection."""

    d_model: int = eqx.field(static=True)
    d_conv: int = eqx.field(static=True)
    d_inner: int = eqx.field(static=True)

    in_proj: eqx.nn.Linear
    out_proj: eqx.nn.Linear
    conv1d: eqx.nn.Conv1d  # depthwise causal conv over the sequence axis

    ssm: SelectiveSSM

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_inner: int = 32,
        d_conv: int = 4,
        *,
        key: PRNGKeyArray,  # FIX: was `key=PRNGKeyArray`, which made the *type object* the default value
    ):
        """Initialize the Mamba block.

        Args:
            d_model: External feature size.
            d_state: State dimension of the inner selective SSM.
            d_inner: Inner (expanded) feature size.
            d_conv: Kernel size of the depthwise causal convolution.
            key: PRNG key for parameter initialization (required).
        """
        self.d_model = d_model
        self.d_conv = d_conv
        self.d_inner = d_inner

        keys = jrandom.split(key, 4)

        self.in_proj = eqx.nn.Linear(self.d_model, self.d_inner * 2, key=keys[0])
        self.conv1d = eqx.nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,  # depthwise
            padding=d_conv - 1,  # left-pad then truncate below → causal
            key=keys[1],
        )

        self.ssm = SelectiveSSM(
            d_model=d_inner,
            dt_rank=math.ceil(self.d_model / 16),
            d_state=d_state,
            key=keys[2],
        )
        self.out_proj = eqx.nn.Linear(self.d_inner, self.d_model, key=keys[3])

    def __call__(self, x: Float[Array, "seq d_model"]) -> Float[Array, "seq d_model"]:
        """Apply the block to a sequence.

        Args:
            x: Input features, shape (seq, d_model).

        Returns:
            Output features, shape (seq, d_model).
        """
        seq_len, _ = x.shape

        # Project the input into the convolution branch and the gate branch.
        xz = jax.vmap(self.in_proj)(x)
        x, z = jnp.split(xz, 2, axis=-1)

        # Depthwise 1D convolution; truncate the right overhang so the
        # padding acts causally.
        x = x.T
        x = self.conv1d(x)[:, :seq_len]
        x = x.T
        x = jax.nn.silu(x)

        # Selective SSM, gated by the silu of the residual branch.
        y = self.ssm(x)
        y = y * jax.nn.silu(z)

        logits = jax.vmap(self.out_proj)(y)
        return logits
|
||||
13
src/felice/networks/utils.py
Normal file
13
src/felice/networks/utils.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jaxtyping import Array
|
||||
|
||||
|
||||
@jax.jit
def binary_op(
    left: Tuple[Array, Array], right: Tuple[Array, Array]
) -> Tuple[Array, Array]:
    """Associative combine for the linear recurrence h_t = a_t·h_{t-1} + b_t.

    Composing two affine maps (a1, b1) then (a2, b2) yields
    (a1·a2, a2·b1 + b2), which lets ``lax.associative_scan`` evaluate the
    recurrence in parallel.
    """
    a_left, b_left = left
    a_right, b_right = right
    combined_scale = a_left * a_right
    combined_shift = b_left * a_right + b_right
    return (combined_scale, combined_shift)
|
||||
5
src/felice/neuron_models/__init__.py
Normal file
5
src/felice/neuron_models/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .boomerang import Boomerang
|
||||
from .fhn import FHNRS
|
||||
from .wererabbit import WereRabbit
|
||||
|
||||
__all__ = ["WereRabbit", "FHNRS", "Boomerang"]
|
||||
171
src/felice/neuron_models/boomerang.py
Normal file
171
src/felice/neuron_models/boomerang.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optimistix as optx
|
||||
from jaxtyping import Array, DTypeLike, Float
|
||||
|
||||
|
||||
class Boomerang(eqx.Module):
    r"""Boomerang neuron model.

    A two-variable (u, v) predator-prey-style dynamical neuron. The resting
    fixpoint (u0, v0) is solved once at construction with a Newton
    root-find, and a "spike" is detected when the trajectory settles back
    onto a fixpoint (vector-field norm below tolerance).
    """

    rtol: float = eqx.field(static=True)
    atol: float = eqx.field(static=True)

    # Resting fixpoint, solved at construction time.
    u0: float = eqx.field(static=True)
    v0: float = eqx.field(static=True)

    alpha: float = eqx.field(static=True)  # I_n0 / I_bias ratio
    beta: float = eqx.field(static=True)  # k / U_t (inverse thermal scale)
    gamma: float = eqx.field(static=True)  # coupling coefficient
    rho: float = eqx.field(static=True)  # tanh steepness
    sigma: float = eqx.field(static=True)  # bias scaling (s * I_bias)

    dtype: DTypeLike = eqx.field(static=True)

    def __init__(
        self,
        *,
        atol: float = 1e-6,
        rtol: float = 1e-4,
        alpha: float = 0.0129,
        beta: float = 15.6,
        gamma: float = 0.26,
        rho: float = 30.0,
        sigma: float = 0.6,
        dtype: DTypeLike = jnp.float32,
    ):
        r"""Initialize the Boomerang neuron model.

        The resting fixpoint (u0, v0) is found here with a Newton
        root-find on ``vector_field``, starting from (0.3, 0.3).

        Args:
            atol: Absolute tolerance for the spiking fixpoint calculation.
            rtol: Relative tolerance for the spiking fixpoint calculation.
            alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
            beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
            gamma: Coupling parameter $\gamma$ (default: 0.26)
            rho: Steepness of the tanh function $\rho$ (default: 30)
            sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
            dtype: Data type for arrays (default: float32).
        """
        self.dtype = dtype

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.rho = rho
        self.sigma = sigma

        self.rtol = rtol
        self.atol = atol

        def fn(y, _):
            # Root-find target: the vector field evaluated at (u, v).
            return self.vector_field(y[0], y[1])

        solver: optx.AbstractRootFinder = optx.Newton(rtol=1e-8, atol=1e-8)
        y0 = (jnp.array(0.3), jnp.array(0.3))
        u0, v0 = optx.root_find(fn, solver, y0).value
        self.u0 = u0.item()
        self.v0 = v0.item()

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 2) containing [u, v],
            where u and v are the predator/prey membrane voltages, both
            started at the resting fixpoint (u0, v0).
        """

        u = jnp.full((n_neurons,), self.u0, dtype=self.dtype)
        v = jnp.full((n_neurons,), self.v0, dtype=self.dtype)
        x = jnp.stack([u, v], axis=1)
        return x

    def vector_field(
        self, u: Float[Array, "..."], v: Float[Array, "..."]
    ) -> Tuple[Float[Array, "..."], Float[Array, "..."]]:
        """Evaluate (du/dt, dv/dt) for the Boomerang dynamics at (u, v)."""
        alpha = self.alpha
        beta = self.beta
        gamma = self.gamma
        sigma = self.sigma
        rho = self.rho

        # Smooth switching term between the two variables.
        z = jax.nn.tanh(rho * (v - u))
        du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z
        dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z

        return du, dv

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute time derivatives of the neuron state variables.

        This implements the Boomerang dynamics

        - du/dt: Predator dynamics
        - dv/dt: Prey dynamics

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 2) containing [u, v].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        u = y[:, 0]
        v = y[:, 1]

        du, dv = self.vector_field(u, v)
        dxdt = jnp.stack([du, dv], axis=1)

        return dxdt

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        A spike is triggered when the system reaches a fixpoint: the
        per-neuron vector-field norm drops below the tolerance-scaled
        state norm, making the condition cross zero.

        Args:
            t: Current simulation time (unused but required by the framework).
            y: State array of shape (neurons, 2) containing [u, v].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Spike condition array; positive values indicate a spike.
            NOTE(review): ``.repeat(2)`` duplicates each neuron's condition,
            presumably one entry per state component as expected by the
            event framework — confirm against the solver's event API.
        """
        _atol = self.atol
        _rtol = self.rtol
        _norm = optx.rms_norm

        vf = self.dynamics(t, y, {})

        @jax.vmap
        def calculate_norm(vf, y):
            return _atol + _rtol * _norm(y) - _norm(vf)

        base_cond = calculate_norm(vf, y).repeat(2)

        return base_cond
||||
235
src/felice/neuron_models/fhn.py
Normal file
235
src/felice/neuron_models/fhn.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import equinox as eqx
|
||||
import jax.numpy as jnp
|
||||
from jaxtyping import Array, DTypeLike, Float
|
||||
|
||||
|
||||
class FHNRS(eqx.Module):
    r"""FitzHugh-Nagumo neuron model

    Model for FitzHugh-Nagumo neuron, with a hardware implementation proposed by
    Ribar-Sepulchre. This implementation uses a dual-timescale dynamics with fast
    and slow currents to produce oscillatory spiking behavior.

    The dynamics are governed by:

    $$
    \begin{align}
    C\frac{dv}{dt} &= I_{app} - I_{passive} - I_{fast} - I_{slow} \\
    \frac{dv_{slow}}{dt} &= \frac{v - v_{slow}}{\tau_{slow}} \\
    \frac{dI_{app}}{dt} &= -\frac{I_{app}}{\tau_{syn}}
    \end{align}
    $$

    where the currents are:

    - $I_{passive} = g_{max}(v - E_{rev})$
    - $I_{fast} = a_{fast} \tanh(v - v_{off,fast})$
    - $I_{slow} = a_{slow} \tanh(v_{slow} - v_{off,slow})$

    References:
        - Ribar, L., & Sepulchre, R. (2019). Neuromodulation of neuromorphic circuits. IEEE Transactions on Circuits and Systems I: Regular Papers, 66(8), 3028-3040.

    Attributes:
        gmax_pasive: Maximal conductance of the passive current.
        Erev_pasive: Reversal potential for the passive current.
        a_fast: Amplitude parameter for the fast current dynamics.
        voff_fast: Voltage offset for the fast current activation.
        tau_fast: Time constant for the fast current (typically zero for instantaneous).
        a_slow: Amplitude parameter for the slow current dynamics.
        voff_slow: Voltage offset for the slow current activation.
        tau_slow: Time constant for the slow recovery variable.
        vthr: Voltage threshold for spike generation.
        C: Membrane capacitance.
        tsyn: Synaptic time constant for input current decay.
    """

    # Passive parameters (field names keep the original "pasive" spelling
    # for compatibility with existing callers)
    gmax_pasive: float = eqx.field(static=True)
    Erev_pasive: float = eqx.field(static=True)

    # Fast current
    a_fast: float = eqx.field(static=True)
    voff_fast: float = eqx.field(static=True)
    tau_fast: float = eqx.field(static=True)

    # Slow current
    a_slow: float = eqx.field(static=True)
    voff_slow: float = eqx.field(static=True)
    tau_slow: float = eqx.field(static=True)

    # Neuron threshold
    vthr: float = eqx.field(static=True)
    C: float = eqx.field(static=True, default=1.0)

    # Input synaptic time constant
    tsyn: float = eqx.field(static=True)

    dtype: DTypeLike = eqx.field(static=True)

    def __init__(
        self,
        *,
        tsyn: Union[int, float, jnp.ndarray] = 1.0,
        C: Union[int, float, jnp.ndarray] = 1.0,
        gmax_pasive: Union[int, float, jnp.ndarray] = 1.0,
        Erev_pasive: Union[int, float, jnp.ndarray] = 0.0,
        a_fast: Union[int, float, jnp.ndarray] = -2.0,
        voff_fast: Union[int, float, jnp.ndarray] = 0.0,
        tau_fast: Union[int, float, jnp.ndarray] = 0.0,
        a_slow: Union[int, float, jnp.ndarray] = 2.0,
        voff_slow: Union[int, float, jnp.ndarray] = 0.0,
        tau_slow: Union[int, float, jnp.ndarray] = 50.0,
        vthr: Union[int, float, jnp.ndarray] = 2.0,
        dtype: DTypeLike = jnp.float32,
    ):
        """Initialize the FitzHugh-Nagumo neuron model.

        Args:
            tsyn: Synaptic time constant for input current decay. Can be scalar or per-neuron array.
            C: Membrane capacitance. Can be scalar or per-neuron array.
            gmax_pasive: Maximal conductance of passive current. Can be scalar or per-neuron array.
            Erev_pasive: Reversal potential for passive current. Can be scalar or per-neuron array.
            a_fast: Amplitude of fast current. Can be scalar or per-neuron array.
            voff_fast: Voltage offset for fast current activation. Can be scalar or per-neuron array.
            tau_fast: Time constant for fast current (typically 0 for instantaneous). Can be scalar or per-neuron array.
            a_slow: Amplitude of slow current. Can be scalar or per-neuron array.
            voff_slow: Voltage offset for slow current activation. Can be scalar or per-neuron array.
            tau_slow: Time constant for slow recovery variable. Can be scalar or per-neuron array.
            vthr: Voltage threshold for spike generation. Can be scalar or per-neuron array.
            dtype: Data type for arrays (default: float32).
        """
        self.dtype = dtype

        self.tsyn = tsyn
        self.C = C
        self.gmax_pasive = gmax_pasive
        self.Erev_pasive = Erev_pasive
        self.a_fast = a_fast
        self.voff_fast = voff_fast
        self.tau_fast = tau_fast
        self.a_slow = a_slow
        self.voff_slow = voff_slow
        self.tau_slow = tau_slow
        self.vthr = vthr

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 3"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 3) containing [v, v_slow, i_app],
            where v is membrane voltage, v_slow is the slow recovery variable,
            and i_app is the applied synaptic current.
        """
        return jnp.zeros((n_neurons, 3), dtype=self.dtype)

    def IV_inst(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Compute instantaneous I-V relationship with fast and slow currents at rest.

        Args:
            v: Membrane voltage.
            Vrest: Resting voltage for both fast and slow currents (default: 0).

        Returns:
            Total current at voltage v with both fast and slow currents evaluated at Vrest.
        """
        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(Vrest - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)

        return I_pasive + I_fast + I_slow

    def IV_fast(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Compute I-V relationship with fast current at voltage v and slow current at rest.

        Args:
            v: Membrane voltage for passive and fast currents.
            Vrest: Resting voltage for slow current (default: 0).

        Returns:
            Total current with fast dynamics responding to v and slow current at Vrest.
        """
        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(Vrest - self.voff_slow)

        return I_pasive + I_fast + I_slow

    def IV_slow(self, v: Float[Array, "..."], Vrest: float = 0) -> Float[Array, "..."]:
        """Compute steady-state I-V relationship with all currents at voltage v.

        Args:
            v: Membrane voltage for all currents.
            Vrest: Unused parameter for API consistency (default: 0).

        Returns:
            Total steady-state current with all currents responding to v.
        """
        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(v - self.voff_slow)

        return I_pasive + I_fast + I_slow

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 3"]:
        """Compute time derivatives of the neuron state variables.

        This implements the FitzHugh-Nagumo dynamics with passive, fast, and slow currents:
        - dv/dt: Fast membrane voltage dynamics
        - dv_slow/dt: Slow recovery variable dynamics
        - di_app/dt: Synaptic current decay

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 3) containing [dv/dt, dv_slow/dt, di_app/dt].
        """
        v = y[:, 0]
        v_slow = y[:, 1]
        i_app = y[:, 2]

        I_pasive = self.gmax_pasive * (v - self.Erev_pasive)
        I_fast = self.a_fast * jnp.tanh(v - self.voff_fast)
        I_slow = self.a_slow * jnp.tanh(v_slow - self.voff_slow)

        i_sum = I_pasive + I_fast + I_slow

        dv_dt = (i_app - i_sum) / self.C
        dvslow_dt = (v - v_slow) / self.tau_slow
        di_dt = -i_app / self.tsyn

        return jnp.stack([dv_dt, dvslow_dt, di_dt], axis=1)

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute spike condition for event detection.

        A spike is triggered when this function crosses zero (v >= vthr).

        Args:
            t: Current simulation time (unused but required by event detection).
            y: State array of shape (neurons, 3) containing [v, v_slow, i_app].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Spike condition array of shape (neurons,). Positive values indicate v > vthr.
        """
        return y[:, 0] - self.vthr
|
||||
194
src/felice/neuron_models/wererabbit.py
Normal file
194
src/felice/neuron_models/wererabbit.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optimistix as optx
|
||||
from jaxtyping import Array, DTypeLike, Float
|
||||
|
||||
|
||||
class WereRabbit(eqx.Module):
    r"""
    WereRabbit Neuron Model.

    The WereRabbit model implements a predator-prey dynamic with bistable
    switching behavior controlled by a "moon phase" variable $z$.

    The dynamics (as implemented in :meth:`vector_field`) are:

    $$
    \begin{align}
    z &= \tanh(\rho (u-v)) \\
    \frac{du}{dt} &= z - z \alpha e^{\beta v} [1 + \gamma (0.5 - u)] - \sigma \\
    \frac{dv}{dt} &= -z + z \alpha e^{\beta u} [1 + \gamma (0.5 - v)] - \sigma
    \end{align}
    $$

    where $z$ represents the "moon phase" that switches the predator-prey
    roles.

    NOTE(review): the sign of the $\alpha$ term in $dv/dt$ above is written
    to match the code (`-1 + alpha * ...`); an earlier docstring had it
    negative — confirm which is intended.

    Attributes:
        alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
        beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
        gamma: Coupling parameter $\gamma = 26 \times 10^{-2}$
        rho: Steepness of the tanh function $\rho$ (default: 5)
        sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)

        rtol: Relative tolerance for the spiking fixpoint calculation.
        atol: Absolute tolerance for the spiking fixpoint calculation.

        dtype: Data type used for the state arrays.
    """

    dtype: DTypeLike = eqx.field(static=True)
    rtol: float = eqx.field(static=True)
    atol: float = eqx.field(static=True)

    alpha: float = eqx.field(static=True)  # I_n0 / I_bias ratio
    beta: float = eqx.field(static=True)  # k / U_t (inverse thermal scale)
    gamma: float = eqx.field(static=True)  # coupling coefficient
    rho: float = eqx.field(static=True)  # tanh steepness
    sigma: float = eqx.field(static=True)  # bias scaling (s * I_bias)

    def __init__(
        self,
        *,
        atol: float = 1e-3,
        rtol: float = 1e-3,
        alpha: float = 0.0129,
        beta: float = 15.6,
        gamma: float = 0.26,
        rho: float = 5.0,
        sigma: float = 0.6,
        dtype: DTypeLike = jnp.float32,
    ):
        r"""Initialize the WereRabbit neuron model.

        Args:
            rtol: Relative tolerance for the spiking fixpoint calculation.
            atol: Absolute tolerance for the spiking fixpoint calculation.
            alpha: Current scaling parameter $\alpha = I_{n0}/I_{bias}$ (default: 0.0129)
            beta: Exponential slope $\beta = \kappa/U_t$ (default: 15.6)
            gamma: Coupling parameter $\gamma = 26 \times 10^{-2}$ (default: 0.26)
            rho: Steepness of the tanh function $\rho$ (default: 5)
            sigma: Fixpoint distance scaling $\sigma$ (default: 0.6)
            dtype: Data type for arrays (default: float32).
        """
        self.dtype = dtype
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.rho = rho
        self.sigma = sigma

        self.rtol = rtol
        self.atol = atol

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Initialize the neuron state variables.

        Args:
            n_neurons: Number of neurons to initialize.

        Returns:
            Initial state array of shape (neurons, 2) containing [u, v],
            where u and v are the predator/prey membrane voltages, both
            zero-initialized.

            NOTE(review): other docstrings in this class mention a third
            ``has_spiked`` component that is not present in the state built
            here — confirm the intended state layout (see also the ``[:-1]``
            slicing in ``spike_condition``).
        """
        x1 = jnp.zeros((n_neurons,), dtype=self.dtype)
        x2 = jnp.zeros((n_neurons,), dtype=self.dtype)
        return jnp.stack([x1, x2], axis=1)

    def vector_field(self, y: Float[Array, "neurons 2"]) -> Float[Array, "neurons 2"]:
        """Compute the vector field of the neuron state variables.

        This implements the WereRabbit dynamics:

        - du/dt: Predator dynamics
        - dv/dt: WerePrey dynamics

        Args:
            y: State array of shape (neurons, 2) containing [u, v].

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        u = y[:, 0]  # predator membrane variable
        v = y[:, 1]  # prey membrane variable

        # "Moon phase": smooth sign of (u - v); flips the predator/prey roles.
        z = jax.nn.tanh(self.rho * (u - v))
        du = (
            z * (1 - self.alpha * jnp.exp(self.beta * v) * (1 + self.gamma * (0.5 - u)))
            - self.sigma
        )
        dv = (
            z
            * (-1 + self.alpha * jnp.exp(self.beta * u) * (1 + self.gamma * (0.5 - v)))
            - self.sigma
        )

        # When z is (numerically) zero everywhere, re-sign the drift by the
        # sign of the state so both variables move toward zero instead of
        # drifting with the constant -sigma term.
        # NOTE(review): jnp.allclose reduces over ALL neurons to one scalar,
        # so this branch switches for every neuron at once — confirm that a
        # population-wide (rather than per-neuron) switch is intended.
        dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)
        du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)

        return jnp.stack([du, dv], axis=1)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute time derivatives of the neuron state variables.

        Thin ODE right-hand-side wrapper around :meth:`vector_field`, with
        the (t, y, args) signature expected by the solver framework.

        Args:
            t: Current simulation time (unused but required by framework).
            y: State array of shape (neurons, 2) containing [u, v].
            args: Additional arguments (unused but required by framework).

        Returns:
            Time derivatives of shape (neurons, 2) containing [du/dt, dv/dt].
        """
        dxdt = self.vector_field(y)

        return dxdt

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, " neurons"]:
        """Compute the spike condition used for event detection.

        A spike is triggered when the system reaches a fixpoint, i.e. when
        the norm of the vector field drops below a mixed absolute/relative
        tolerance derived from the state norm.

        NOTE(review): an earlier docstring referenced a ``has_spiked`` flag
        used to avoid re-detecting a spike while sitting at the fixpoint;
        the current 2-component state carries no such flag — confirm whether
        repeated event triggers at a fixpoint are handled elsewhere.

        Args:
            t: Current simulation time (unused but required by the framework).
            y: State array of shape (neurons, 2) containing [u, v].
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Spike condition array of shape (neurons,). Positive values
            indicate a spike (the state has settled to within tolerance).
        """
        _atol = self.atol
        _rtol = self.rtol
        _norm = optx.rms_norm

        vf = self.dynamics(t, y, {})

        # Per-neuron test: atol + rtol * ||state|| - ||vector field||.
        # Positive once the dynamics are slower than the tolerance.
        @jax.vmap
        def calculate_norm(vf, y):
            # NOTE(review): with a 2-component per-neuron state, [:-1] drops
            # v and measures only u — this looks like a leftover from a
            # 3-component [u, v, has_spiked] state; confirm intended.
            return _atol + _rtol * _norm(y[:-1]) - _norm(vf[:-1])

        base_cond = calculate_norm(vf, y)

        return base_cond
|
||||
64
src/felice/solver.py
Normal file
64
src/felice/solver.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
from diffrax import AbstractSolver, AbstractTerm
|
||||
from diffrax._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, Y
|
||||
from diffrax._solution import RESULTS
|
||||
from diffrax._solver.base import _SolverState
|
||||
from jaxtyping import PyTree
|
||||
|
||||
|
||||
class ClipSolver(eqx.Module):
    """Wrapper around a diffrax solver that clamps every accepted step to
    the non-negative orthant.

    All solver behavior is delegated to the wrapped ``solver``; only the
    result of each ``step`` is post-processed with ``jax.nn.relu``.
    """

    solver: AbstractSolver

    def __getattr__(self, name):
        # Anything not defined on the wrapper resolves on the wrapped solver
        # (order, interpolation class, init, func, ...).
        return getattr(self.solver, name)

    def step(
        self,
        terms: PyTree[AbstractTerm],
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
        solver_state: _SolverState,
        made_jump: BoolScalarLike,
    ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
        """Make a single solver step over $[t_0, t_1]$ and clip the result.

        Every argument is forwarded unchanged to ``self.solver.step``; see
        ``diffrax.AbstractSolver.step`` for their full meaning:

        - `terms`: PyTree of terms representing the vector fields/controls.
        - `t0` / `t1`: start and end of the step interval.
        - `y0`: value of the solution at `t0`.
        - `args`: extra arguments passed to the vector field.
        - `solver_state`: evolving solver state at `t0`.
        - `made_jump`: whether the vector field was discontinuous at `t0`.

        **Returns:**

        The same 5-tuple as the wrapped solver — (solution at `t1`, local
        error estimate or `None`, dense-output info, solver state at `t1`,
        a `diffrax.RESULTS` code) — except that every leaf of the new
        solution is passed through ``jax.nn.relu``, forcing the state to
        stay non-negative.
        """
        stepped = self.solver.step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
        new_y, err_estimate, dense, new_state, outcome = stepped
        clipped = jax.tree_util.tree_map(jax.nn.relu, new_y)
        return clipped, err_estimate, dense, new_state, outcome
|
||||
Reference in New Issue
Block a user