Initial commit

Co-authored-by: Aradhana Dube <a.dube@rug.nl>
Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl>
Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl>
Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
This commit is contained in:
2026-02-26 18:30:32 +01:00
commit 9fabbdefc0
75 changed files with 447515 additions and 0 deletions

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,247 @@
import marimo
__generated_with = "0.19.4"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
from wigglystuff import Slider2D
return Slider2D, mo
@app.cell
def _():
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
return diffrax, jax, jnp, np, plt
@app.cell
def _(diffrax, jax, jnp):
def vector_field(t, state, args):
u, v = state
alpha, beta, gamma, kappa, sigma, delta = args
z = jax.nn.tanh(kappa * (v - u))
# Prey dynamics
du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z
# Predator dynamics
dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z
return jnp.array([du, dv])
def compute_nullclines(vector_field, u_range, v_range, args, resolution=200):
"""
Compute nullclines
du/dt = 0 (u-nullcline)
dv/dt = 0 (v-nullcline)
"""
alpha, beta, gamma, kappa, sigma, delta = args
u_vals = jnp.linspace(u_range[0], u_range[1], resolution)
v_vals = jnp.linspace(v_range[0], v_range[1], resolution)
U, V = jnp.meshgrid(u_vals, v_vals)
dU, dV = vector_field(0, [U, V], args)
return U, V, dU, dV
def solve(dyn, y0, p, T, n=500):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(dyn),
diffrax.Tsit5(),
t0=0.0,
t1=T,
dt0=0.0001,
y0=y0,
args=p,
saveat=diffrax.SaveAt(ts=jnp.linspace(0, T, n)),
stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-8),
max_steps=50000,
)
return sol.ts, sol.ys
return compute_nullclines, solve, vector_field
@app.cell
def _(Slider2D, mo):
# alpha = 0.5 # I_n0 / I_bias ratio
# beta = 0.39/0.025 # k / U_t (inverse thermal scale)
# gamma = 0.26 # coupling coefficient
# kappa = 5.0 # tanh steepness
# sigma = 0.6 # bias scaling (s * I_bias normalized)
# y0 = jnp.array([0.2, 0.4])
# ts, ys = solve(vector_field, y0, params, 140)
alpha = mo.ui.slider(
0.0004, 0.012, 0.00001, 0.00129, label="alpha", orientation="vertical"
)
beta = mo.ui.slider(
0.0, 30, 0.00001, 0.39 / 0.025, label="beta", orientation="vertical"
)
gamma = mo.ui.slider(0, 1, 0.01, 0.26, label="gamma", orientation="vertical")
kappa = mo.ui.slider(0, 30, 1.0, 10.0, label="kappa", orientation="vertical")
sigma = mo.ui.slider(0, 1, 0.01, 0.6, label="sigma", orientation="vertical")
delta = mo.ui.slider(1, 100.0, 1, 10, label="delta", orientation="vertical")
# v0 = mo.ui.slider(0, 1.0, 0.01, 0.3, label="v0")
# u0 = mo.ui.slider(0, 1.0, 0.01, 0.2, label="u0", orientation="vertical")
state0 = mo.ui.anywidget(
Slider2D(
x=0.34,
y=0.38,
width=150,
height=150,
x_bounds=(0.0, 0.6),
y_bounds=(0.0, 0.6),
)
)
mo.hstack(
[
mo.plain_text("""
alpha: I_n0 / I_bias ratio
beta: k / U_t ratio
gamma: coupling coefficient
kappa: tanh steepness
sigma: bias scaling (s * I_bias)
"""),
mo.hstack(
[state0, alpha, beta, gamma, kappa, sigma, delta], justify="start"
),
]
)
return alpha, beta, delta, gamma, kappa, sigma, state0
@app.cell
def _(
alpha,
beta,
compute_nullclines,
delta,
gamma,
jnp,
kappa,
np,
plt,
sigma,
solve,
state0,
vector_field,
):
params = (
alpha.value,
beta.value,
gamma.value,
kappa.value,
sigma.value,
delta.value,
)
ic_neuro = [state0.x, state0.y]
u_range = [0.0, 0.6]
v_range = [0.0, 0.6]
u_sparse = jnp.linspace(u_range[0], u_range[1], 20)
v_sparse = jnp.linspace(v_range[0], v_range[1], 20)
Us, Vs = jnp.meshgrid(u_sparse, v_sparse)
def plot_vf(ax, vector_field):
U, V, dU, dV = compute_nullclines(vector_field, u_range, v_range, params)
dUs, dVs = vector_field(0, [Us, Vs], params)
# Normalize for visualization
magnitude = np.sqrt(dUs**2 + dVs**2)
magnitude[magnitude == 0] = 1
dUs_norm = dUs / magnitude
dVs_norm = dVs / magnitude
# Nullclines
ax.contour(U, V, dU, levels=[0], colors="blue", linewidths=2, linestyles="-")
ax.contour(U, V, dV, levels=[0], colors="red", linewidths=2, linestyles="-")
ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap="viridis", alpha=0.6)
# Trajectories
color = plt.cm.plasma(0.2)
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, delta.value)
ax.plot(ys[:, 0], ys[:, 1], "-", color=color, linewidth=1.5, alpha=0.8)
ax.plot(ic_neuro[0], ic_neuro[1], "o", color=color, markersize=6)
ax.set_xlabel("u (Prey)")
ax.set_ylabel("v (Predator)")
ax.set_title("Wererabbit: Phase Portrait")
ax.legend(["u-nullcline (du/dt=0)", "v-nullcline (dv/dt=0)"], loc="upper right")
ax.set_xlim(u_range)
ax.set_ylim(v_range)
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
ax.axvline(x=0, color="gray", linestyle="--", alpha=0.3)
def plot_trj(ax, vector_field):
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, delta.value)
ax.plot(ts, ys[:, 0], "b-", linewidth=2, label="u (Prey)")
ax.plot(ts, ys[:, 1], "r-", linewidth=2, label="v (Predator)")
ax.set_xlabel("Time τ")
ax.set_ylabel("Population")
ax.set_title(
f"Wererabbit: Time Series (IC: u₀={ic_neuro[0]:.2f}, v₀={ic_neuro[1]:.2f})"
)
ax.legend()
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
fig = plt.figure(figsize=(10, 4))
# --- Plot 1: Wererabbit Phase Portrait ---
ax1 = fig.add_subplot(1, 2, 1)
plot_vf(ax1, vector_field)
ax2 = fig.add_subplot(1, 2, 2)
plot_trj(ax2, vector_field)
# ax3 = fig.add_subplot(3, 2, 3)
# plot_vf(ax3, vector_field_prod)
# ax4 = fig.add_subplot(3, 2, 4)
# plot_trj(ax4, vector_field_prod)
# ax5 = fig.add_subplot(3, 2, 5)
# plot_vf(ax5, vector_field_exp)
# ax6 = fig.add_subplot(3, 2, 6)
# plot_trj(ax6, vector_field_exp)
plt.tight_layout()
fig
return
@app.cell
def _():
return
@app.cell
def _():
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,188 @@
import marimo
__generated_with = "0.19.4"
app = marimo.App(width="medium")
@app.cell
def _():
import diffrax as dfx
import jax.numpy as jnp
import marimo as mo
import matplotlib.pyplot as plt
import numpy as np
from jax import jit
return dfx, jit, jnp, mo, np, plt
@app.cell
def _(dfx, jnp):
def vector_field(t, y, args):
v, vslow = y
ipasive = args["gmax"] * (v - args["Erev"])
ifast = args["af"] * jnp.tanh(v - args["Ef"])
islow = args["as"] * jnp.tanh(vslow - args["Es"])
dv = (-ipasive - ifast - islow) / args["C"]
dvs = (v - vslow) / args["ts"]
return jnp.array([dv, dvs])
term = dfx.ODETerm(vector_field)
return term, vector_field
@app.cell
def _(mo):
p1 = mo.ui.slider(0.0, 5.0, value=1.0, step=0.1, label="gmax")
p2 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.1, label="Erev")
p3 = mo.ui.slider(-5.0, 5.0, value=-2.0, step=0.05, label="af")
p4 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Ef")
p5 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.05, label="as")
p6 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Es")
p7 = mo.ui.slider(1.0, 100.0, value=50.0, step=0.1, label="ts")
p8 = mo.ui.slider(0.0, 1.0, value=1.0, step=0.01, label="C")
mo.hstack(
[
mo.vstack([p1, p2, p3, p4], justify="start", gap=1),
mo.vstack([p5, p6, p7, p8], justify="start", gap=1),
]
)
return p1, p2, p3, p4, p5, p6, p7, p8
@app.cell
def _(mo):
mo.md("""
### Initial Conditions & Simulation
""")
return
@app.cell
def _(mo):
x0 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.1, label="x₀")
y0 = mo.ui.slider(-5.0, 5.0, value=0.0, step=0.1, label="y₀")
t_max = mo.ui.slider(10, 100, value=30, step=5, label="t_max")
mo.hstack([x0, y0, t_max], justify="start", gap=2)
return t_max, x0, y0
@app.cell
def _(dfx, jit, jnp, p1, p2, p3, p4, p5, p6, p7, p8, t_max, term, x0, y0):
@jit
def solve_ode(y_init, args, t_end):
solver = dfx.Tsit5()
saveat = dfx.SaveAt(ts=jnp.linspace(0, t_end, 2000))
sol = dfx.diffeqsolve(
term,
solver,
t0=0,
t1=t_end,
dt0=0.01,
y0=y_init,
args=args,
saveat=saveat,
max_steps=100000,
)
return sol.ts, sol.ys
args = {
"gmax": p1.value,
"Erev": p2.value,
"af": p3.value,
"Ef": p4.value,
"as": p5.value,
"Es": p6.value,
"ts": p7.value,
"C": p8.value,
}
y_init = jnp.array([x0.value, y0.value])
t, ys = solve_ode(y_init, args, float(t_max.value))
x_sol = ys[:, 0]
y_sol = ys[:, 1]
return args, t, x_sol, y_sol
@app.cell
def _(args, jnp, np, plt, t, vector_field, x0, x_sol, y0, y_sol):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# Time series
ax1.plot(t, x_sol, "b-", lw=1.5, label="x(t)")
ax1.plot(t, y_sol, "r-", lw=1.5, label="y(t)")
ax1.set_xlabel("Time t")
ax1.set_ylabel("State")
ax1.set_title("Transient Response")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Phase plane bounds
# pad = 1.0
xmin, xmax = -4, 4
ymin, ymax = -2.5, 2.5
# Vector field
X, Y = jnp.meshgrid(jnp.linspace(xmin, xmax, 20), jnp.linspace(ymin, ymax, 20))
U, V = jnp.zeros_like(X), np.zeros_like(Y)
state = jnp.stack([X, Y], axis=0)
deriv = vector_field(0.0, state, args)
dx, dy = deriv[0], deriv[1]
mag = jnp.sqrt(dx**2 + dy**2)
U = jnp.where(mag > 0, dx / mag, U)
V = jnp.where(mag > 0, dy / mag, V)
ax2.quiver(X, Y, U, V, alpha=0.4, color="gray", scale=25)
# Nullclines
Xf, Yf = jnp.meshgrid(jnp.linspace(xmin, xmax, 150), jnp.linspace(ymin, ymax, 150))
DX, DY = jnp.zeros_like(Xf), jnp.zeros_like(Yf)
state = jnp.stack([Xf, Yf], axis=0)
deriv = vector_field(0.0, state, args)
DX, DY = deriv[0], deriv[1]
ax2.contour(
Xf,
Yf,
DX,
levels=[0],
colors="blue",
linestyles="--",
linewidths=1.5,
alpha=0.7,
)
ax2.contour(
Xf, Yf, DY, levels=[0], colors="red", linestyles="--", linewidths=1.5, alpha=0.7
)
# Trajectory
ax2.plot(x_sol, y_sol, "b-", lw=2)
ax2.plot(x0.value, y0.value, "go", ms=10, label="Start")
ax2.plot(x_sol[-1], y_sol[-1], "r*", ms=12, label="End")
ax2.set_xlabel("x")
ax2.set_ylabel("y")
ax2.set_title("Phase Plane")
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_xlim(xmin, xmax)
ax2.set_ylim(ymin, ymax)
plt.tight_layout()
fig
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,318 @@
import marimo
__generated_with = "0.19.4"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
from wigglystuff import Slider2D
return Slider2D, mo
@app.cell
def _():
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
return diffrax, jax, jnp, np, plt
@app.cell
def _(diffrax, jax, jnp):
def vector_field(t, state, args):
u, v = state
alpha, beta, gamma, kappa, sigma, delta = args
z = jax.nn.tanh(kappa * (u - v))
# Prey dynamics
du = z * (1 - alpha * jnp.exp(beta * v) * (1 + gamma * (0.5 - u))) - sigma
# Predator dynamics
dv = z * (-1 + alpha * jnp.exp(beta * u) * u * (1 + gamma * (0.5 - v))) - sigma
return jnp.array([du, dv])
def vector_field_prod(t, state, args):
u, v = state
alpha, beta, gamma, kappa, sigma, delta = args
z = jax.nn.tanh(kappa * (u - v))
# Prey dynamics
du = (
z * (1 - alpha * jnp.exp(beta * v) * (1 + gamma * (0.5 - u))) - sigma
# + sigma * jnp.maximum(0, delta - u) / (delta + 1e-16)
)
# Predator dynamics
dv = (
z * (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.5 - v))) - sigma
# + sigma * jnp.maximum(0, delta - v) / (delta + 1e-16)
)
dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)
du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)
return jnp.array([du, dv])
def vector_field_exp(t, state, args):
u, v = state
alpha, beta, gamma, kappa, sigma, delta = args
z = jax.nn.tanh(kappa * (u - v))
# Prey dynamics
du = (
z * (1 - alpha * jnp.exp(beta * v) * (1 + gamma * (0.5 - u)))
- sigma
+ sigma * jnp.exp(-u / delta)
)
# Predator dynamics
dv = (
z * (-1 + alpha * jnp.exp(beta * u) * u * (1 + gamma * (0.5 - v)))
- sigma
+ sigma * jnp.exp(-v / delta)
)
return jnp.array([du, dv])
def physical_vector_field(t, state, args):
x1, x2 = state
alpha, beta, gamma, kappa, sigma, delta = args
In0 = 129e-15 # fixed by design
C = 0.1e-12 # fixed by design
kk = 0.39 # fixed by tech
Ut = 0.025 # temperature dependent
Ibias = In0 / alpha
Ia = Ibias * sigma
x3 = jax.nn.tanh(kappa * (x1 - x2))
dx1 = (
x3 * Ibias
- (In0 * jnp.exp(kk * x2 / Ut)) * (x3 + 26e-2 * (0.5 - x1) * x3)
- Ia
) / C
dx2 = (
-x3 * Ibias
+ In0 * jnp.exp(kk * x1 / Ut) * (x3 + 26e-2 * (0.5 - x2) * x3)
- Ia
) / C
return jnp.array([dx1, dx2])
def compute_nullclines(vector_field, u_range, v_range, args, resolution=200):
"""
Compute nullclines
du/dt = 0 (u-nullcline)
dv/dt = 0 (v-nullcline)
"""
alpha, beta, gamma, kappa, sigma, delta = args
u_vals = jnp.linspace(u_range[0], u_range[1], resolution)
v_vals = jnp.linspace(v_range[0], v_range[1], resolution)
U, V = jnp.meshgrid(u_vals, v_vals)
dU, dV = vector_field(0, [U, V], args)
return U, V, dU, dV
def solve(dyn, y0, p, T, n=1000):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(dyn),
diffrax.Tsit5(),
t0=0.0,
t1=T,
dt0=0.01,
y0=y0,
args=p,
saveat=diffrax.SaveAt(ts=jnp.linspace(0, T, n)),
stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-8),
max_steps=50000,
)
return sol.ts, sol.ys
return compute_nullclines, solve, vector_field_prod
@app.cell
def _(Slider2D, mo):
# alpha = 0.5 # I_n0 / I_bias ratio
# beta = 0.39/0.025 # k / U_t (inverse thermal scale)
# gamma = 0.26 # coupling coefficient
# kappa = 5.0 # tanh steepness
# sigma = 0.6 # bias scaling (s * I_bias normalized)
# y0 = jnp.array([0.2, 0.4])
# ts, ys = solve(vector_field, y0, params, 140)
alpha = mo.ui.slider(
0.0004, 0.012, 0.00001, 0.00129, label="alpha", orientation="vertical"
)
beta = mo.ui.slider(
0.0, 30, 0.00001, 0.39 / 0.025, label="beta", orientation="vertical"
)
gamma = mo.ui.slider(0, 1, 0.01, 0.26, label="gamma", orientation="vertical")
kappa = mo.ui.slider(0, 10, 0.1, 5.0, label="kappa", orientation="vertical")
sigma = mo.ui.slider(0, 1, 0.01, 0.6, label="sigma", orientation="vertical")
delta = mo.ui.slider(0, 0.1, 0.001, 0.02, label="delta", orientation="vertical")
# v0 = mo.ui.slider(0, 1.0, 0.01, 0.3, label="v0")
# u0 = mo.ui.slider(0, 1.0, 0.01, 0.2, label="u0", orientation="vertical")
state0 = mo.ui.anywidget(
Slider2D(
width=150,
height=150,
x_bounds=(-1.0, 1.5),
y_bounds=(-1.0, 1.5),
)
)
mo.hstack(
[
mo.plain_text("""
alpha: I_n0 / I_bias ratio
beta: k / U_t ratio
gamma: coupling coefficient
kappa: tanh steepness
sigma: bias scaling (s * I_bias)
"""),
mo.hstack(
[state0, alpha, beta, gamma, kappa, sigma, delta], justify="start"
),
]
)
return alpha, beta, delta, gamma, kappa, sigma, state0
@app.cell
def _(
alpha,
beta,
compute_nullclines,
delta,
gamma,
jnp,
kappa,
np,
plt,
sigma,
solve,
state0,
vector_field_prod,
):
params = (
alpha.value,
beta.value,
gamma.value,
kappa.value,
sigma.value,
delta.value,
)
ic_neuro = [state0.x, state0.y]
u_range = [-1.0, 1.5]
v_range = [-1.0, 1.5]
u_sparse = jnp.linspace(u_range[0], u_range[1], 20)
v_sparse = jnp.linspace(v_range[0], v_range[1], 20)
Us, Vs = jnp.meshgrid(u_sparse, v_sparse)
def plot_vf(ax, vector_field):
U, V, dU, dV = compute_nullclines(vector_field, u_range, v_range, params)
dUs, dVs = vector_field(0, [Us, Vs], params)
# Normalize for visualization
magnitude = np.sqrt(dUs**2 + dVs**2)
magnitude[magnitude == 0] = 1
dUs_norm = dUs / magnitude
dVs_norm = dVs / magnitude
# Nullclines
ax.contour(U, V, dU, levels=[0], colors="blue", linewidths=2, linestyles="-")
ax.contour(U, V, dV, levels=[0], colors="red", linewidths=2, linestyles="-")
ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap="viridis", alpha=0.6)
# Trajectories
color = plt.cm.plasma(0.2)
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, 50)
ax.plot(ys[:, 0], ys[:, 1], "-", color=color, linewidth=1.5, alpha=0.8)
ax.plot(ic_neuro[0], ic_neuro[1], "o", color=color, markersize=6)
ax.set_xlabel("u (Prey)")
ax.set_ylabel("v (Predator)")
ax.set_title("Wererabbit: Phase Portrait")
ax.legend(["u-nullcline (du/dt=0)", "v-nullcline (dv/dt=0)"], loc="upper right")
ax.set_xlim(u_range)
ax.set_ylim(v_range)
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
ax.axvline(x=0, color="gray", linestyle="--", alpha=0.3)
def plot_trj(ax, vector_field):
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, 50)
ax.plot(ts, ys[:, 0], "b-", linewidth=2, label="u (Prey)")
ax.plot(ts, ys[:, 1], "r-", linewidth=2, label="v (Predator)")
ax.set_xlabel("Time τ")
ax.set_ylabel("Population")
ax.set_title(
f"Wererabbit: Time Series (IC: u₀={ic_neuro[0]:.2f}, v₀={ic_neuro[1]:.2f})"
)
ax.legend()
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
fig = plt.figure(figsize=(10, 4))
# --- Plot 1: Wererabbit Phase Portrait ---
ax1 = fig.add_subplot(1, 2, 1)
plot_vf(ax1, vector_field_prod)
ax2 = fig.add_subplot(1, 2, 2)
plot_trj(ax2, vector_field_prod)
# ax3 = fig.add_subplot(3, 2, 3)
# plot_vf(ax3, vector_field_prod)
# ax4 = fig.add_subplot(3, 2, 4)
# plot_trj(ax4, vector_field_prod)
# ax5 = fig.add_subplot(3, 2, 5)
# plot_vf(ax5, vector_field_exp)
# ax6 = fig.add_subplot(3, 2, 6)
# plot_trj(ax6, vector_field_exp)
plt.tight_layout()
fig
return
@app.cell
def _():
return
@app.cell
def _():
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

158
scripts/networks/plot.ipynb Normal file

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

281
scripts/networks/train.py Normal file
View File

@@ -0,0 +1,281 @@
import argparse
import os
from typing import Any, Tuple
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
import pandas as pd
from jaxtyping import Array, Float
from optax import OptState
from tqdm import trange
from felice.datasets.reasoning import ReasoningDataset
from felice.networks import Implicit, Mamba, SequenceClassifier
from felice.networks.implicit.boomerang import ImplicitBoomerang
def compute_loss(
model: eqx.Module, inputs: Array, targets: Array, masks: Array
) -> Float[Array, ""]:
def forward_single(inp, tgt, msk):
logits = model(inp)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, tgt)
return (loss * msk).sum() / (msk.sum() + 1e-8)
losses = jax.vmap(forward_single)(inputs, targets, masks)
return losses.mean()
v_and_grad = eqx.filter_value_and_grad(compute_loss)
@eqx.filter_jit
def compute_accuracy(
model: eqx.Module, inputs: Array, targets: Array, masks: Array
) -> Float[Array, ""]:
def forward_single(inp, tgt, msk):
logits = model(inp)
preds = jnp.argmax(logits, axis=-1)
correct = (preds == tgt) * msk
return correct.sum(), msk.sum()
correct, total = jax.vmap(forward_single)(inputs, targets, masks)
return correct.sum() / (total.sum() + 1e-8)
@eqx.filter_jit
def train_step(
model: eqx.Module,
opt_state: OptState,
optimizer: Any,
inputs: Array,
targets: Array,
masks: Array,
) -> Tuple[eqx.Module, OptState, Array]:
loss, grads = v_and_grad(model, inputs, targets, masks)
updates, opt_state = optimizer.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
def train_and_compare(
model_type: Any,
logdir: str,
task_type: str = "simple",
n_epochs: int = 1000,
batch_size: int = 64,
d_model: int = 64,
d_state: int = 16,
d_inner: int = 32,
dt: float = 1.0,
max_iters: int = 8,
lr: float = 1e-3,
seed: int = 42,
# with_thr: bool = True,
) -> Tuple[eqx.Module, eqx.Module, Array, Array, pd.DataFrame]:
r"""Train Mamba and implicit model on the reasoning synthetic dataset.
Args:
model_type: The type of the implicit model to train (Boomerang, Mamba Implicit).
logdir: Directory and filenmae of the log.
task_type: Type of task to solve from the reasoning synthetic dataset (simple, accumulation).
n_epochs: Number of epochs to train.
batch_size: Training batch size.
d_model: Model dimensions including output.
d_state: Model state dimension.
d_inner: Model latent dimension.
max_iters: Maximum number of iterations in the implicit model.
lr: Learning rate.
seed: Random seed.
with_thr: For the Boomerang model, if using threshold for dual fixpoints.
Returns:
The trained models (mamba and implicit) with the respective final accuracy and
a pandas dataframe with the loss and accuracy per epoch.
"""
key = jrandom.key(seed)
keys = jrandom.split(key, 4)
dataset = ReasoningDataset()
standard_model = SequenceClassifier(
vocab_size=dataset.VOCAB_SIZE,
d_model=d_model,
d_state=d_state,
d_inner=d_inner,
model_class=Mamba,
key=keys[0],
)
implicit_model = SequenceClassifier(
vocab_size=dataset.VOCAB_SIZE,
d_model=d_model,
d_state=d_state,
d_inner=d_inner,
model_class=model_type,
max_iters=max_iters,
dt=dt,
# with_thr=with_thr,
key=keys[1],
)
# implicit_model = ImplicitBoomerang(
# vocab_size=dataset.VOCAB_SIZE,
# d_model=d_model,
# d_state=d_state,
# d_inner=d_inner,
# max_iters=max_iters,
# dt=dt,
# # with_thr=with_thr,
# key=keys[1],
# )
# Count parameters
def count_params(model):
return sum(
x.size for x in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))
)
print(f"Mamba SSM params: {count_params(standard_model):,}")
print(f"Implicit SSM params: {count_params(implicit_model):,}")
# Optimizers
optimizer = optax.adam(lr)
standard_opt_state = optimizer.init(eqx.filter(standard_model, eqx.is_array))
implicit_opt_state = optimizer.init(eqx.filter(implicit_model, eqx.is_array))
# Training loop
print(f"\nTraining on task: {task_type} with {max_iters} steps")
print("=" * 60)
train_key = keys[2]
df = pd.DataFrame({"Epoch": [], "Loss": [], "Acc": [], "Model": []})
pbar = trange(n_epochs)
for epoch in pbar:
train_key, batch_key = jrandom.split(train_key)
inputs, targets, masks = dataset.generate_batch(
batch_key, batch_size, task_type
)
# Train standard model
standard_model, standard_opt_state, standard_loss = train_step(
standard_model, standard_opt_state, optimizer, inputs, targets, masks
)
# Train implicit model
implicit_model, implicit_opt_state, implicit_loss = train_step(
implicit_model, implicit_opt_state, optimizer, inputs, targets, masks
)
if (epoch + 1) % 10 == 0:
# Evaluate on fresh batch
eval_key = jrandom.fold_in(keys[3], epoch)
eval_inputs, eval_targets, eval_masks = dataset.generate_batch(
eval_key, batch_size, task_type
)
standard_acc = compute_accuracy(
standard_model, eval_inputs, eval_targets, eval_masks
)
implicit_acc = compute_accuracy(
implicit_model, eval_inputs, eval_targets, eval_masks
)
new_df = pd.DataFrame(
{
"Epoch": [epoch, epoch],
"Loss": [standard_loss.item(), implicit_loss.item()],
"Acc": [standard_acc.item(), implicit_acc.item()],
"Model": ["Mamba", "Implicit"],
}
)
df = pd.concat([df, new_df], ignore_index=True)
df.to_pickle(logdir)
pbar.write(
f"Epoch {epoch + 1:4d} | "
f"Standard: loss={standard_loss:.4f}, acc={standard_acc:.4f} | "
f"Implicit: loss={implicit_loss:.4f}, acc={implicit_acc:.4f}"
)
# Final evaluation
print("\n" + "=" * 60)
print("Final Evaluation (1000 samples)")
print("=" * 60)
eval_inputs, eval_targets, eval_masks = dataset.generate_batch(
keys[4], 1000, task_type
)
standard_acc = compute_accuracy(
standard_model, eval_inputs, eval_targets, eval_masks
)
implicit_acc = compute_accuracy(
implicit_model, eval_inputs, eval_targets, eval_masks
)
print(f"Mamba SSM accuracy: {standard_acc:.4f}")
print(f"Implicit SSM accuracy: {implicit_acc:.4f}")
print(f"Improvement: {(implicit_acc - standard_acc) * 100:.2f}%")
return standard_model, implicit_model, standard_acc, implicit_acc, df
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-t", type=int, choices=[1, 2], default=1, help="Task to perform"
)
parser.add_argument(
"-m",
type=str,
choices=["boomerang", "implicit"],
default="implicit",
help="Neuron model to use",
)
parser.add_argument("--dt", type=float, default=0.001, help="Simulation timestep")
parser.add_argument("-i", type=int, default=8, help="Maximum number of iterations")
parser.add_argument("-b", type=int, default=64, help="Batch size")
parser.add_argument(
"--thr", action="store_true", help="Using threshold on the boomerang neuron"
)
args = parser.parse_args()
if args.m == "boomerang":
model_type = ImplicitBoomerang
elif args.m == "implicit":
model_type = Implicit
else:
raise NotImplementedError(f"{args.t} model type not implemented")
logdir = os.path.join("tmp", f"task{args.t}-{args.i}-{args.m}-{args.thr}")
if not os.path.exists("tmp"):
os.makedirs("tmp")
print(f"Saving at {logdir}")
_, implicit_model, std_acc1, imp_acc1, df = train_and_compare(
model_type,
logdir,
task_type="simple" if args.t == 1 else "accumulation",
n_epochs=1000,
batch_size=64,
d_model=ReasoningDataset.NUM_OUTPUT,
d_state=16,
d_inner=128,
dt=args.dt,
max_iters=args.i,
# with_thr=args.thr,
)
eqx.tree_serialise_leaves(f"{logdir}.eqx", implicit_model)
df.to_pickle(logdir)
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"{'Task':<25} {'Mamba SSM':<15} {'Implicit SSM':<15} {'Delta':<10}")
print("-" * 70)
print(
f"{'Simple Comparison':<25} {std_acc1:<15.4f} {imp_acc1:<15.4f} {(imp_acc1 - std_acc1) * 100:>+.2f}%"
)

File diff suppressed because one or more lines are too long