mirror of
https://github.com/bics-rug/felice-models.git
synced 2026-03-10 13:07:40 +01:00
Initial commit
Co-authored-by: Aradhana Dube <a.dube@rug.nl> Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl> Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl> Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
This commit is contained in:
437049
scripts/examples/neuron_models/boomerang/boomerang.ipynb
Normal file
437049
scripts/examples/neuron_models/boomerang/boomerang.ipynb
Normal file
File diff suppressed because one or more lines are too long
BIN
scripts/examples/neuron_models/boomerang/boomerang.mp4
Normal file
BIN
scripts/examples/neuron_models/boomerang/boomerang.mp4
Normal file
Binary file not shown.
247
scripts/examples/neuron_models/boomerang/boomerang.py
Normal file
247
scripts/examples/neuron_models/boomerang/boomerang.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import marimo
|
||||
|
||||
__generated_with = "0.19.4"
|
||||
app = marimo.App(width="medium")
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import marimo as mo
|
||||
from wigglystuff import Slider2D
|
||||
|
||||
return Slider2D, mo
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import diffrax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
return diffrax, jax, jnp, np, plt
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(diffrax, jax, jnp):
|
||||
def vector_field(t, state, args):
|
||||
u, v = state
|
||||
alpha, beta, gamma, kappa, sigma, delta = args
|
||||
|
||||
z = jax.nn.tanh(kappa * (v - u))
|
||||
|
||||
# Prey dynamics
|
||||
du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z
|
||||
|
||||
# Predator dynamics
|
||||
dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z
|
||||
|
||||
return jnp.array([du, dv])
|
||||
|
||||
def compute_nullclines(vector_field, u_range, v_range, args, resolution=200):
|
||||
"""
|
||||
Compute nullclines
|
||||
du/dt = 0 (u-nullcline)
|
||||
dv/dt = 0 (v-nullcline)
|
||||
"""
|
||||
alpha, beta, gamma, kappa, sigma, delta = args
|
||||
u_vals = jnp.linspace(u_range[0], u_range[1], resolution)
|
||||
v_vals = jnp.linspace(v_range[0], v_range[1], resolution)
|
||||
U, V = jnp.meshgrid(u_vals, v_vals)
|
||||
|
||||
dU, dV = vector_field(0, [U, V], args)
|
||||
|
||||
return U, V, dU, dV
|
||||
|
||||
def solve(dyn, y0, p, T, n=500):
|
||||
sol = diffrax.diffeqsolve(
|
||||
diffrax.ODETerm(dyn),
|
||||
diffrax.Tsit5(),
|
||||
t0=0.0,
|
||||
t1=T,
|
||||
dt0=0.0001,
|
||||
y0=y0,
|
||||
args=p,
|
||||
saveat=diffrax.SaveAt(ts=jnp.linspace(0, T, n)),
|
||||
stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-8),
|
||||
max_steps=50000,
|
||||
)
|
||||
return sol.ts, sol.ys
|
||||
|
||||
return compute_nullclines, solve, vector_field
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(Slider2D, mo):
|
||||
# alpha = 0.5 # I_n0 / I_bias ratio
|
||||
# beta = 0.39/0.025 # k / U_t (inverse thermal scale)
|
||||
# gamma = 0.26 # coupling coefficient
|
||||
# kappa = 5.0 # tanh steepness
|
||||
# sigma = 0.6 # bias scaling (s * I_bias normalized)
|
||||
# y0 = jnp.array([0.2, 0.4])
|
||||
# ts, ys = solve(vector_field, y0, params, 140)
|
||||
|
||||
alpha = mo.ui.slider(
|
||||
0.0004, 0.012, 0.00001, 0.00129, label="alpha", orientation="vertical"
|
||||
)
|
||||
beta = mo.ui.slider(
|
||||
0.0, 30, 0.00001, 0.39 / 0.025, label="beta", orientation="vertical"
|
||||
)
|
||||
gamma = mo.ui.slider(0, 1, 0.01, 0.26, label="gamma", orientation="vertical")
|
||||
kappa = mo.ui.slider(0, 30, 1.0, 10.0, label="kappa", orientation="vertical")
|
||||
sigma = mo.ui.slider(0, 1, 0.01, 0.6, label="sigma", orientation="vertical")
|
||||
delta = mo.ui.slider(1, 100.0, 1, 10, label="delta", orientation="vertical")
|
||||
|
||||
# v0 = mo.ui.slider(0, 1.0, 0.01, 0.3, label="v0")
|
||||
# u0 = mo.ui.slider(0, 1.0, 0.01, 0.2, label="u0", orientation="vertical")
|
||||
|
||||
state0 = mo.ui.anywidget(
|
||||
Slider2D(
|
||||
x=0.34,
|
||||
y=0.38,
|
||||
width=150,
|
||||
height=150,
|
||||
x_bounds=(0.0, 0.6),
|
||||
y_bounds=(0.0, 0.6),
|
||||
)
|
||||
)
|
||||
|
||||
mo.hstack(
|
||||
[
|
||||
mo.plain_text("""
|
||||
alpha: I_n0 / I_bias ratio
|
||||
beta: k / U_t ratio
|
||||
gamma: coupling coefficient
|
||||
kappa: tanh steepness
|
||||
sigma: bias scaling (s * I_bias)
|
||||
"""),
|
||||
mo.hstack(
|
||||
[state0, alpha, beta, gamma, kappa, sigma, delta], justify="start"
|
||||
),
|
||||
]
|
||||
)
|
||||
return alpha, beta, delta, gamma, kappa, sigma, state0
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(
|
||||
alpha,
|
||||
beta,
|
||||
compute_nullclines,
|
||||
delta,
|
||||
gamma,
|
||||
jnp,
|
||||
kappa,
|
||||
np,
|
||||
plt,
|
||||
sigma,
|
||||
solve,
|
||||
state0,
|
||||
vector_field,
|
||||
):
|
||||
params = (
|
||||
alpha.value,
|
||||
beta.value,
|
||||
gamma.value,
|
||||
kappa.value,
|
||||
sigma.value,
|
||||
delta.value,
|
||||
)
|
||||
ic_neuro = [state0.x, state0.y]
|
||||
|
||||
u_range = [0.0, 0.6]
|
||||
v_range = [0.0, 0.6]
|
||||
|
||||
u_sparse = jnp.linspace(u_range[0], u_range[1], 20)
|
||||
v_sparse = jnp.linspace(v_range[0], v_range[1], 20)
|
||||
|
||||
Us, Vs = jnp.meshgrid(u_sparse, v_sparse)
|
||||
|
||||
def plot_vf(ax, vector_field):
|
||||
U, V, dU, dV = compute_nullclines(vector_field, u_range, v_range, params)
|
||||
dUs, dVs = vector_field(0, [Us, Vs], params)
|
||||
|
||||
# Normalize for visualization
|
||||
magnitude = np.sqrt(dUs**2 + dVs**2)
|
||||
magnitude[magnitude == 0] = 1
|
||||
dUs_norm = dUs / magnitude
|
||||
dVs_norm = dVs / magnitude
|
||||
|
||||
# Nullclines
|
||||
ax.contour(U, V, dU, levels=[0], colors="blue", linewidths=2, linestyles="-")
|
||||
ax.contour(U, V, dV, levels=[0], colors="red", linewidths=2, linestyles="-")
|
||||
|
||||
ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap="viridis", alpha=0.6)
|
||||
|
||||
# Trajectories
|
||||
color = plt.cm.plasma(0.2)
|
||||
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, delta.value)
|
||||
ax.plot(ys[:, 0], ys[:, 1], "-", color=color, linewidth=1.5, alpha=0.8)
|
||||
ax.plot(ic_neuro[0], ic_neuro[1], "o", color=color, markersize=6)
|
||||
|
||||
ax.set_xlabel("u (Prey)")
|
||||
ax.set_ylabel("v (Predator)")
|
||||
ax.set_title("Wererabbit: Phase Portrait")
|
||||
ax.legend(["u-nullcline (du/dt=0)", "v-nullcline (dv/dt=0)"], loc="upper right")
|
||||
ax.set_xlim(u_range)
|
||||
ax.set_ylim(v_range)
|
||||
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
|
||||
ax.axvline(x=0, color="gray", linestyle="--", alpha=0.3)
|
||||
|
||||
def plot_trj(ax, vector_field):
|
||||
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, delta.value)
|
||||
ax.plot(ts, ys[:, 0], "b-", linewidth=2, label="u (Prey)")
|
||||
ax.plot(ts, ys[:, 1], "r-", linewidth=2, label="v (Predator)")
|
||||
|
||||
ax.set_xlabel("Time τ")
|
||||
ax.set_ylabel("Population")
|
||||
ax.set_title(
|
||||
f"Wererabbit: Time Series (IC: u₀={ic_neuro[0]:.2f}, v₀={ic_neuro[1]:.2f})"
|
||||
)
|
||||
ax.legend()
|
||||
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
|
||||
|
||||
fig = plt.figure(figsize=(10, 4))
|
||||
|
||||
# --- Plot 1: Wererabbit Phase Portrait ---
|
||||
ax1 = fig.add_subplot(1, 2, 1)
|
||||
plot_vf(ax1, vector_field)
|
||||
|
||||
ax2 = fig.add_subplot(1, 2, 2)
|
||||
plot_trj(ax2, vector_field)
|
||||
|
||||
# ax3 = fig.add_subplot(3, 2, 3)
|
||||
# plot_vf(ax3, vector_field_prod)
|
||||
|
||||
# ax4 = fig.add_subplot(3, 2, 4)
|
||||
# plot_trj(ax4, vector_field_prod)
|
||||
|
||||
# ax5 = fig.add_subplot(3, 2, 5)
|
||||
# plot_vf(ax5, vector_field_exp)
|
||||
|
||||
# ax6 = fig.add_subplot(3, 2, 6)
|
||||
# plot_trj(ax6, vector_field_exp)
|
||||
|
||||
plt.tight_layout()
|
||||
fig
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
240
scripts/examples/neuron_models/fhn/fhnrs.ipynb
Normal file
240
scripts/examples/neuron_models/fhn/fhnrs.ipynb
Normal file
File diff suppressed because one or more lines are too long
188
scripts/examples/neuron_models/fhn/fhnrs.py
Normal file
188
scripts/examples/neuron_models/fhn/fhnrs.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import marimo
|
||||
|
||||
__generated_with = "0.19.4"
|
||||
app = marimo.App(width="medium")
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import diffrax as dfx
|
||||
import jax.numpy as jnp
|
||||
import marimo as mo
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from jax import jit
|
||||
|
||||
return dfx, jit, jnp, mo, np, plt
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(dfx, jnp):
|
||||
def vector_field(t, y, args):
|
||||
v, vslow = y
|
||||
|
||||
ipasive = args["gmax"] * (v - args["Erev"])
|
||||
ifast = args["af"] * jnp.tanh(v - args["Ef"])
|
||||
islow = args["as"] * jnp.tanh(vslow - args["Es"])
|
||||
|
||||
dv = (-ipasive - ifast - islow) / args["C"]
|
||||
dvs = (v - vslow) / args["ts"]
|
||||
|
||||
return jnp.array([dv, dvs])
|
||||
|
||||
term = dfx.ODETerm(vector_field)
|
||||
return term, vector_field
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(mo):
|
||||
p1 = mo.ui.slider(0.0, 5.0, value=1.0, step=0.1, label="gmax")
|
||||
p2 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.1, label="Erev")
|
||||
p3 = mo.ui.slider(-5.0, 5.0, value=-2.0, step=0.05, label="af")
|
||||
p4 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Ef")
|
||||
p5 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.05, label="as")
|
||||
p6 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Es")
|
||||
p7 = mo.ui.slider(1.0, 100.0, value=50.0, step=0.1, label="ts")
|
||||
p8 = mo.ui.slider(0.0, 1.0, value=1.0, step=0.01, label="C")
|
||||
|
||||
mo.hstack(
|
||||
[
|
||||
mo.vstack([p1, p2, p3, p4], justify="start", gap=1),
|
||||
mo.vstack([p5, p6, p7, p8], justify="start", gap=1),
|
||||
]
|
||||
)
|
||||
return p1, p2, p3, p4, p5, p6, p7, p8
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(mo):
|
||||
mo.md("""
|
||||
### Initial Conditions & Simulation
|
||||
""")
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(mo):
|
||||
x0 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.1, label="x₀")
|
||||
y0 = mo.ui.slider(-5.0, 5.0, value=0.0, step=0.1, label="y₀")
|
||||
t_max = mo.ui.slider(10, 100, value=30, step=5, label="t_max")
|
||||
|
||||
mo.hstack([x0, y0, t_max], justify="start", gap=2)
|
||||
return t_max, x0, y0
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(dfx, jit, jnp, p1, p2, p3, p4, p5, p6, p7, p8, t_max, term, x0, y0):
|
||||
@jit
|
||||
def solve_ode(y_init, args, t_end):
|
||||
solver = dfx.Tsit5()
|
||||
saveat = dfx.SaveAt(ts=jnp.linspace(0, t_end, 2000))
|
||||
sol = dfx.diffeqsolve(
|
||||
term,
|
||||
solver,
|
||||
t0=0,
|
||||
t1=t_end,
|
||||
dt0=0.01,
|
||||
y0=y_init,
|
||||
args=args,
|
||||
saveat=saveat,
|
||||
max_steps=100000,
|
||||
)
|
||||
return sol.ts, sol.ys
|
||||
|
||||
args = {
|
||||
"gmax": p1.value,
|
||||
"Erev": p2.value,
|
||||
"af": p3.value,
|
||||
"Ef": p4.value,
|
||||
"as": p5.value,
|
||||
"Es": p6.value,
|
||||
"ts": p7.value,
|
||||
"C": p8.value,
|
||||
}
|
||||
y_init = jnp.array([x0.value, y0.value])
|
||||
|
||||
t, ys = solve_ode(y_init, args, float(t_max.value))
|
||||
x_sol = ys[:, 0]
|
||||
y_sol = ys[:, 1]
|
||||
return args, t, x_sol, y_sol
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(args, jnp, np, plt, t, vector_field, x0, x_sol, y0, y_sol):
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
# Time series
|
||||
|
||||
ax1.plot(t, x_sol, "b-", lw=1.5, label="x(t)")
|
||||
ax1.plot(t, y_sol, "r-", lw=1.5, label="y(t)")
|
||||
ax1.set_xlabel("Time t")
|
||||
ax1.set_ylabel("State")
|
||||
ax1.set_title("Transient Response")
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Phase plane bounds
|
||||
# pad = 1.0
|
||||
xmin, xmax = -4, 4
|
||||
ymin, ymax = -2.5, 2.5
|
||||
|
||||
# Vector field
|
||||
X, Y = jnp.meshgrid(jnp.linspace(xmin, xmax, 20), jnp.linspace(ymin, ymax, 20))
|
||||
U, V = jnp.zeros_like(X), np.zeros_like(Y)
|
||||
|
||||
state = jnp.stack([X, Y], axis=0)
|
||||
deriv = vector_field(0.0, state, args)
|
||||
dx, dy = deriv[0], deriv[1]
|
||||
mag = jnp.sqrt(dx**2 + dy**2)
|
||||
U = jnp.where(mag > 0, dx / mag, U)
|
||||
V = jnp.where(mag > 0, dy / mag, V)
|
||||
|
||||
ax2.quiver(X, Y, U, V, alpha=0.4, color="gray", scale=25)
|
||||
|
||||
# Nullclines
|
||||
Xf, Yf = jnp.meshgrid(jnp.linspace(xmin, xmax, 150), jnp.linspace(ymin, ymax, 150))
|
||||
DX, DY = jnp.zeros_like(Xf), jnp.zeros_like(Yf)
|
||||
state = jnp.stack([Xf, Yf], axis=0)
|
||||
deriv = vector_field(0.0, state, args)
|
||||
DX, DY = deriv[0], deriv[1]
|
||||
ax2.contour(
|
||||
Xf,
|
||||
Yf,
|
||||
DX,
|
||||
levels=[0],
|
||||
colors="blue",
|
||||
linestyles="--",
|
||||
linewidths=1.5,
|
||||
alpha=0.7,
|
||||
)
|
||||
ax2.contour(
|
||||
Xf, Yf, DY, levels=[0], colors="red", linestyles="--", linewidths=1.5, alpha=0.7
|
||||
)
|
||||
|
||||
# Trajectory
|
||||
ax2.plot(x_sol, y_sol, "b-", lw=2)
|
||||
ax2.plot(x0.value, y0.value, "go", ms=10, label="Start")
|
||||
ax2.plot(x_sol[-1], y_sol[-1], "r*", ms=12, label="End")
|
||||
|
||||
ax2.set_xlabel("x")
|
||||
ax2.set_ylabel("y")
|
||||
ax2.set_title("Phase Plane")
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
ax2.set_xlim(xmin, xmax)
|
||||
ax2.set_ylim(ymin, ymax)
|
||||
|
||||
plt.tight_layout()
|
||||
fig
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
847
scripts/examples/neuron_models/wererabbit/example.ipynb
Normal file
847
scripts/examples/neuron_models/wererabbit/example.ipynb
Normal file
File diff suppressed because one or more lines are too long
189
scripts/examples/neuron_models/wererabbit/wererabbit.ipynb
Normal file
189
scripts/examples/neuron_models/wererabbit/wererabbit.ipynb
Normal file
File diff suppressed because one or more lines are too long
318
scripts/examples/neuron_models/wererabbit/wererabbit.py
Normal file
318
scripts/examples/neuron_models/wererabbit/wererabbit.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import marimo
|
||||
|
||||
__generated_with = "0.19.4"
|
||||
app = marimo.App(width="medium")
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import marimo as mo
|
||||
from wigglystuff import Slider2D
|
||||
|
||||
return Slider2D, mo
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import diffrax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
return diffrax, jax, jnp, np, plt
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(diffrax, jax, jnp):
|
||||
def vector_field(t, state, args):
|
||||
u, v = state
|
||||
alpha, beta, gamma, kappa, sigma, delta = args
|
||||
|
||||
z = jax.nn.tanh(kappa * (u - v))
|
||||
|
||||
# Prey dynamics
|
||||
du = z * (1 - alpha * jnp.exp(beta * v) * (1 + gamma * (0.5 - u))) - sigma
|
||||
|
||||
# Predator dynamics
|
||||
dv = z * (-1 + alpha * jnp.exp(beta * u) * u * (1 + gamma * (0.5 - v))) - sigma
|
||||
|
||||
return jnp.array([du, dv])
|
||||
|
||||
def vector_field_prod(t, state, args):
|
||||
u, v = state
|
||||
alpha, beta, gamma, kappa, sigma, delta = args
|
||||
|
||||
z = jax.nn.tanh(kappa * (u - v))
|
||||
|
||||
# Prey dynamics
|
||||
du = (
|
||||
z * (1 - alpha * jnp.exp(beta * v) * (1 + gamma * (0.5 - u))) - sigma
|
||||
# + sigma * jnp.maximum(0, delta - u) / (delta + 1e-16)
|
||||
)
|
||||
|
||||
# Predator dynamics
|
||||
dv = (
|
||||
z * (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.5 - v))) - sigma
|
||||
# + sigma * jnp.maximum(0, delta - v) / (delta + 1e-16)
|
||||
)
|
||||
|
||||
dv = jnp.where(jnp.allclose(z, 0.0), dv * jnp.sign(v), dv)
|
||||
du = jnp.where(jnp.allclose(z, 0.0), du * jnp.sign(u), du)
|
||||
|
||||
return jnp.array([du, dv])
|
||||
|
||||
def vector_field_exp(t, state, args):
|
||||
u, v = state
|
||||
alpha, beta, gamma, kappa, sigma, delta = args
|
||||
|
||||
z = jax.nn.tanh(kappa * (u - v))
|
||||
|
||||
# Prey dynamics
|
||||
du = (
|
||||
z * (1 - alpha * jnp.exp(beta * v) * (1 + gamma * (0.5 - u)))
|
||||
- sigma
|
||||
+ sigma * jnp.exp(-u / delta)
|
||||
)
|
||||
|
||||
# Predator dynamics
|
||||
dv = (
|
||||
z * (-1 + alpha * jnp.exp(beta * u) * u * (1 + gamma * (0.5 - v)))
|
||||
- sigma
|
||||
+ sigma * jnp.exp(-v / delta)
|
||||
)
|
||||
|
||||
return jnp.array([du, dv])
|
||||
|
||||
def physical_vector_field(t, state, args):
|
||||
x1, x2 = state
|
||||
alpha, beta, gamma, kappa, sigma, delta = args
|
||||
|
||||
In0 = 129e-15 # fixed by design
|
||||
C = 0.1e-12 # fixed by design
|
||||
kk = 0.39 # fixed by tech
|
||||
Ut = 0.025 # temperature dependent
|
||||
|
||||
Ibias = In0 / alpha
|
||||
|
||||
Ia = Ibias * sigma
|
||||
x3 = jax.nn.tanh(kappa * (x1 - x2))
|
||||
|
||||
dx1 = (
|
||||
x3 * Ibias
|
||||
- (In0 * jnp.exp(kk * x2 / Ut)) * (x3 + 26e-2 * (0.5 - x1) * x3)
|
||||
- Ia
|
||||
) / C
|
||||
dx2 = (
|
||||
-x3 * Ibias
|
||||
+ In0 * jnp.exp(kk * x1 / Ut) * (x3 + 26e-2 * (0.5 - x2) * x3)
|
||||
- Ia
|
||||
) / C
|
||||
|
||||
return jnp.array([dx1, dx2])
|
||||
|
||||
def compute_nullclines(vector_field, u_range, v_range, args, resolution=200):
|
||||
"""
|
||||
Compute nullclines
|
||||
du/dt = 0 (u-nullcline)
|
||||
dv/dt = 0 (v-nullcline)
|
||||
"""
|
||||
alpha, beta, gamma, kappa, sigma, delta = args
|
||||
u_vals = jnp.linspace(u_range[0], u_range[1], resolution)
|
||||
v_vals = jnp.linspace(v_range[0], v_range[1], resolution)
|
||||
U, V = jnp.meshgrid(u_vals, v_vals)
|
||||
|
||||
dU, dV = vector_field(0, [U, V], args)
|
||||
|
||||
return U, V, dU, dV
|
||||
|
||||
def solve(dyn, y0, p, T, n=1000):
|
||||
sol = diffrax.diffeqsolve(
|
||||
diffrax.ODETerm(dyn),
|
||||
diffrax.Tsit5(),
|
||||
t0=0.0,
|
||||
t1=T,
|
||||
dt0=0.01,
|
||||
y0=y0,
|
||||
args=p,
|
||||
saveat=diffrax.SaveAt(ts=jnp.linspace(0, T, n)),
|
||||
stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-8),
|
||||
max_steps=50000,
|
||||
)
|
||||
return sol.ts, sol.ys
|
||||
|
||||
return compute_nullclines, solve, vector_field_prod
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(Slider2D, mo):
|
||||
# alpha = 0.5 # I_n0 / I_bias ratio
|
||||
# beta = 0.39/0.025 # k / U_t (inverse thermal scale)
|
||||
# gamma = 0.26 # coupling coefficient
|
||||
# kappa = 5.0 # tanh steepness
|
||||
# sigma = 0.6 # bias scaling (s * I_bias normalized)
|
||||
# y0 = jnp.array([0.2, 0.4])
|
||||
# ts, ys = solve(vector_field, y0, params, 140)
|
||||
|
||||
alpha = mo.ui.slider(
|
||||
0.0004, 0.012, 0.00001, 0.00129, label="alpha", orientation="vertical"
|
||||
)
|
||||
beta = mo.ui.slider(
|
||||
0.0, 30, 0.00001, 0.39 / 0.025, label="beta", orientation="vertical"
|
||||
)
|
||||
gamma = mo.ui.slider(0, 1, 0.01, 0.26, label="gamma", orientation="vertical")
|
||||
kappa = mo.ui.slider(0, 10, 0.1, 5.0, label="kappa", orientation="vertical")
|
||||
sigma = mo.ui.slider(0, 1, 0.01, 0.6, label="sigma", orientation="vertical")
|
||||
delta = mo.ui.slider(0, 0.1, 0.001, 0.02, label="delta", orientation="vertical")
|
||||
|
||||
# v0 = mo.ui.slider(0, 1.0, 0.01, 0.3, label="v0")
|
||||
# u0 = mo.ui.slider(0, 1.0, 0.01, 0.2, label="u0", orientation="vertical")
|
||||
|
||||
state0 = mo.ui.anywidget(
|
||||
Slider2D(
|
||||
width=150,
|
||||
height=150,
|
||||
x_bounds=(-1.0, 1.5),
|
||||
y_bounds=(-1.0, 1.5),
|
||||
)
|
||||
)
|
||||
|
||||
mo.hstack(
|
||||
[
|
||||
mo.plain_text("""
|
||||
alpha: I_n0 / I_bias ratio
|
||||
beta: k / U_t ratio
|
||||
gamma: coupling coefficient
|
||||
kappa: tanh steepness
|
||||
sigma: bias scaling (s * I_bias)
|
||||
"""),
|
||||
mo.hstack(
|
||||
[state0, alpha, beta, gamma, kappa, sigma, delta], justify="start"
|
||||
),
|
||||
]
|
||||
)
|
||||
return alpha, beta, delta, gamma, kappa, sigma, state0
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(
|
||||
alpha,
|
||||
beta,
|
||||
compute_nullclines,
|
||||
delta,
|
||||
gamma,
|
||||
jnp,
|
||||
kappa,
|
||||
np,
|
||||
plt,
|
||||
sigma,
|
||||
solve,
|
||||
state0,
|
||||
vector_field_prod,
|
||||
):
|
||||
params = (
|
||||
alpha.value,
|
||||
beta.value,
|
||||
gamma.value,
|
||||
kappa.value,
|
||||
sigma.value,
|
||||
delta.value,
|
||||
)
|
||||
ic_neuro = [state0.x, state0.y]
|
||||
|
||||
u_range = [-1.0, 1.5]
|
||||
v_range = [-1.0, 1.5]
|
||||
|
||||
u_sparse = jnp.linspace(u_range[0], u_range[1], 20)
|
||||
v_sparse = jnp.linspace(v_range[0], v_range[1], 20)
|
||||
|
||||
Us, Vs = jnp.meshgrid(u_sparse, v_sparse)
|
||||
|
||||
def plot_vf(ax, vector_field):
|
||||
U, V, dU, dV = compute_nullclines(vector_field, u_range, v_range, params)
|
||||
dUs, dVs = vector_field(0, [Us, Vs], params)
|
||||
|
||||
# Normalize for visualization
|
||||
magnitude = np.sqrt(dUs**2 + dVs**2)
|
||||
magnitude[magnitude == 0] = 1
|
||||
dUs_norm = dUs / magnitude
|
||||
dVs_norm = dVs / magnitude
|
||||
|
||||
# Nullclines
|
||||
ax.contour(U, V, dU, levels=[0], colors="blue", linewidths=2, linestyles="-")
|
||||
ax.contour(U, V, dV, levels=[0], colors="red", linewidths=2, linestyles="-")
|
||||
|
||||
ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap="viridis", alpha=0.6)
|
||||
|
||||
# Trajectories
|
||||
color = plt.cm.plasma(0.2)
|
||||
|
||||
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, 50)
|
||||
ax.plot(ys[:, 0], ys[:, 1], "-", color=color, linewidth=1.5, alpha=0.8)
|
||||
ax.plot(ic_neuro[0], ic_neuro[1], "o", color=color, markersize=6)
|
||||
|
||||
ax.set_xlabel("u (Prey)")
|
||||
ax.set_ylabel("v (Predator)")
|
||||
ax.set_title("Wererabbit: Phase Portrait")
|
||||
ax.legend(["u-nullcline (du/dt=0)", "v-nullcline (dv/dt=0)"], loc="upper right")
|
||||
ax.set_xlim(u_range)
|
||||
ax.set_ylim(v_range)
|
||||
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
|
||||
ax.axvline(x=0, color="gray", linestyle="--", alpha=0.3)
|
||||
|
||||
def plot_trj(ax, vector_field):
|
||||
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, 50)
|
||||
ax.plot(ts, ys[:, 0], "b-", linewidth=2, label="u (Prey)")
|
||||
ax.plot(ts, ys[:, 1], "r-", linewidth=2, label="v (Predator)")
|
||||
|
||||
ax.set_xlabel("Time τ")
|
||||
ax.set_ylabel("Population")
|
||||
ax.set_title(
|
||||
f"Wererabbit: Time Series (IC: u₀={ic_neuro[0]:.2f}, v₀={ic_neuro[1]:.2f})"
|
||||
)
|
||||
ax.legend()
|
||||
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
|
||||
|
||||
fig = plt.figure(figsize=(10, 4))
|
||||
|
||||
# --- Plot 1: Wererabbit Phase Portrait ---
|
||||
ax1 = fig.add_subplot(1, 2, 1)
|
||||
plot_vf(ax1, vector_field_prod)
|
||||
|
||||
ax2 = fig.add_subplot(1, 2, 2)
|
||||
plot_trj(ax2, vector_field_prod)
|
||||
|
||||
# ax3 = fig.add_subplot(3, 2, 3)
|
||||
# plot_vf(ax3, vector_field_prod)
|
||||
|
||||
# ax4 = fig.add_subplot(3, 2, 4)
|
||||
# plot_trj(ax4, vector_field_prod)
|
||||
|
||||
# ax5 = fig.add_subplot(3, 2, 5)
|
||||
# plot_vf(ax5, vector_field_exp)
|
||||
|
||||
# ax6 = fig.add_subplot(3, 2, 6)
|
||||
# plot_trj(ax6, vector_field_exp)
|
||||
|
||||
plt.tight_layout()
|
||||
fig
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
158
scripts/networks/plot.ipynb
Normal file
158
scripts/networks/plot.ipynb
Normal file
File diff suppressed because one or more lines are too long
BIN
scripts/networks/results/task1-2000-boomerang-False
Normal file
BIN
scripts/networks/results/task1-2000-boomerang-False
Normal file
Binary file not shown.
BIN
scripts/networks/results/task1-60000-boomerang-False
Normal file
BIN
scripts/networks/results/task1-60000-boomerang-False
Normal file
Binary file not shown.
BIN
scripts/networks/results/task1-60000-boomerang-False.eqx
Normal file
BIN
scripts/networks/results/task1-60000-boomerang-False.eqx
Normal file
Binary file not shown.
144
scripts/networks/test_methods.ipynb
Normal file
144
scripts/networks/test_methods.ipynb
Normal file
File diff suppressed because one or more lines are too long
BIN
scripts/networks/tmp/task1-100000-boomerang-False
Normal file
BIN
scripts/networks/tmp/task1-100000-boomerang-False
Normal file
Binary file not shown.
BIN
scripts/networks/tmp/task1-100000-boomerang-False.eqx
Normal file
BIN
scripts/networks/tmp/task1-100000-boomerang-False.eqx
Normal file
Binary file not shown.
BIN
scripts/networks/tmp/task1-2000-boomerang-False
Normal file
BIN
scripts/networks/tmp/task1-2000-boomerang-False
Normal file
Binary file not shown.
BIN
scripts/networks/tmp/task1-2000-boomerang-False.eqx
Normal file
BIN
scripts/networks/tmp/task1-2000-boomerang-False.eqx
Normal file
Binary file not shown.
BIN
scripts/networks/tmp/task1-20000-boomerang-False
Normal file
BIN
scripts/networks/tmp/task1-20000-boomerang-False
Normal file
Binary file not shown.
BIN
scripts/networks/tmp/task1-20000-boomerang-False.eqx
Normal file
BIN
scripts/networks/tmp/task1-20000-boomerang-False.eqx
Normal file
Binary file not shown.
281
scripts/networks/train.py
Normal file
281
scripts/networks/train.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import Any, Tuple
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrandom
|
||||
import optax
|
||||
import pandas as pd
|
||||
from jaxtyping import Array, Float
|
||||
from optax import OptState
|
||||
from tqdm import trange
|
||||
|
||||
from felice.datasets.reasoning import ReasoningDataset
|
||||
from felice.networks import Implicit, Mamba, SequenceClassifier
|
||||
from felice.networks.implicit.boomerang import ImplicitBoomerang
|
||||
|
||||
|
||||
def compute_loss(
|
||||
model: eqx.Module, inputs: Array, targets: Array, masks: Array
|
||||
) -> Float[Array, ""]:
|
||||
def forward_single(inp, tgt, msk):
|
||||
logits = model(inp)
|
||||
loss = optax.softmax_cross_entropy_with_integer_labels(logits, tgt)
|
||||
return (loss * msk).sum() / (msk.sum() + 1e-8)
|
||||
|
||||
losses = jax.vmap(forward_single)(inputs, targets, masks)
|
||||
return losses.mean()
|
||||
|
||||
|
||||
v_and_grad = eqx.filter_value_and_grad(compute_loss)
|
||||
|
||||
|
||||
@eqx.filter_jit
|
||||
def compute_accuracy(
|
||||
model: eqx.Module, inputs: Array, targets: Array, masks: Array
|
||||
) -> Float[Array, ""]:
|
||||
def forward_single(inp, tgt, msk):
|
||||
logits = model(inp)
|
||||
preds = jnp.argmax(logits, axis=-1)
|
||||
correct = (preds == tgt) * msk
|
||||
return correct.sum(), msk.sum()
|
||||
|
||||
correct, total = jax.vmap(forward_single)(inputs, targets, masks)
|
||||
return correct.sum() / (total.sum() + 1e-8)
|
||||
|
||||
|
||||
@eqx.filter_jit
|
||||
def train_step(
|
||||
model: eqx.Module,
|
||||
opt_state: OptState,
|
||||
optimizer: Any,
|
||||
inputs: Array,
|
||||
targets: Array,
|
||||
masks: Array,
|
||||
) -> Tuple[eqx.Module, OptState, Array]:
|
||||
loss, grads = v_and_grad(model, inputs, targets, masks)
|
||||
updates, opt_state = optimizer.update(grads, opt_state, model)
|
||||
model = eqx.apply_updates(model, updates)
|
||||
return model, opt_state, loss
|
||||
|
||||
|
||||
def train_and_compare(
|
||||
model_type: Any,
|
||||
logdir: str,
|
||||
task_type: str = "simple",
|
||||
n_epochs: int = 1000,
|
||||
batch_size: int = 64,
|
||||
d_model: int = 64,
|
||||
d_state: int = 16,
|
||||
d_inner: int = 32,
|
||||
dt: float = 1.0,
|
||||
max_iters: int = 8,
|
||||
lr: float = 1e-3,
|
||||
seed: int = 42,
|
||||
# with_thr: bool = True,
|
||||
) -> Tuple[eqx.Module, eqx.Module, Array, Array, pd.DataFrame]:
|
||||
r"""Train Mamba and implicit model on the reasoning synthetic dataset.
|
||||
|
||||
Args:
|
||||
model_type: The type of the implicit model to train (Boomerang, Mamba Implicit).
|
||||
logdir: Directory and filenmae of the log.
|
||||
task_type: Type of task to solve from the reasoning synthetic dataset (simple, accumulation).
|
||||
n_epochs: Number of epochs to train.
|
||||
batch_size: Training batch size.
|
||||
d_model: Model dimensions including output.
|
||||
d_state: Model state dimension.
|
||||
d_inner: Model latent dimension.
|
||||
max_iters: Maximum number of iterations in the implicit model.
|
||||
lr: Learning rate.
|
||||
seed: Random seed.
|
||||
with_thr: For the Boomerang model, if using threshold for dual fixpoints.
|
||||
|
||||
Returns:
|
||||
The trained models (mamba and implicit) with the respective final accuracy and
|
||||
a pandas dataframe with the loss and accuracy per epoch.
|
||||
"""
|
||||
key = jrandom.key(seed)
|
||||
keys = jrandom.split(key, 4)
|
||||
|
||||
dataset = ReasoningDataset()
|
||||
|
||||
standard_model = SequenceClassifier(
|
||||
vocab_size=dataset.VOCAB_SIZE,
|
||||
d_model=d_model,
|
||||
d_state=d_state,
|
||||
d_inner=d_inner,
|
||||
model_class=Mamba,
|
||||
key=keys[0],
|
||||
)
|
||||
|
||||
implicit_model = SequenceClassifier(
|
||||
vocab_size=dataset.VOCAB_SIZE,
|
||||
d_model=d_model,
|
||||
d_state=d_state,
|
||||
d_inner=d_inner,
|
||||
model_class=model_type,
|
||||
max_iters=max_iters,
|
||||
dt=dt,
|
||||
# with_thr=with_thr,
|
||||
key=keys[1],
|
||||
)
|
||||
|
||||
# implicit_model = ImplicitBoomerang(
|
||||
# vocab_size=dataset.VOCAB_SIZE,
|
||||
# d_model=d_model,
|
||||
# d_state=d_state,
|
||||
# d_inner=d_inner,
|
||||
# max_iters=max_iters,
|
||||
# dt=dt,
|
||||
# # with_thr=with_thr,
|
||||
# key=keys[1],
|
||||
# )
|
||||
# Count parameters
|
||||
def count_params(model):
|
||||
return sum(
|
||||
x.size for x in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))
|
||||
)
|
||||
|
||||
print(f"Mamba SSM params: {count_params(standard_model):,}")
|
||||
print(f"Implicit SSM params: {count_params(implicit_model):,}")
|
||||
|
||||
# Optimizers
|
||||
optimizer = optax.adam(lr)
|
||||
standard_opt_state = optimizer.init(eqx.filter(standard_model, eqx.is_array))
|
||||
implicit_opt_state = optimizer.init(eqx.filter(implicit_model, eqx.is_array))
|
||||
|
||||
# Training loop
|
||||
print(f"\nTraining on task: {task_type} with {max_iters} steps")
|
||||
print("=" * 60)
|
||||
|
||||
train_key = keys[2]
|
||||
|
||||
df = pd.DataFrame({"Epoch": [], "Loss": [], "Acc": [], "Model": []})
|
||||
pbar = trange(n_epochs)
|
||||
for epoch in pbar:
|
||||
train_key, batch_key = jrandom.split(train_key)
|
||||
inputs, targets, masks = dataset.generate_batch(
|
||||
batch_key, batch_size, task_type
|
||||
)
|
||||
|
||||
# Train standard model
|
||||
standard_model, standard_opt_state, standard_loss = train_step(
|
||||
standard_model, standard_opt_state, optimizer, inputs, targets, masks
|
||||
)
|
||||
|
||||
# Train implicit model
|
||||
implicit_model, implicit_opt_state, implicit_loss = train_step(
|
||||
implicit_model, implicit_opt_state, optimizer, inputs, targets, masks
|
||||
)
|
||||
|
||||
if (epoch + 1) % 10 == 0:
|
||||
# Evaluate on fresh batch
|
||||
eval_key = jrandom.fold_in(keys[3], epoch)
|
||||
eval_inputs, eval_targets, eval_masks = dataset.generate_batch(
|
||||
eval_key, batch_size, task_type
|
||||
)
|
||||
|
||||
standard_acc = compute_accuracy(
|
||||
standard_model, eval_inputs, eval_targets, eval_masks
|
||||
)
|
||||
implicit_acc = compute_accuracy(
|
||||
implicit_model, eval_inputs, eval_targets, eval_masks
|
||||
)
|
||||
|
||||
new_df = pd.DataFrame(
|
||||
{
|
||||
"Epoch": [epoch, epoch],
|
||||
"Loss": [standard_loss.item(), implicit_loss.item()],
|
||||
"Acc": [standard_acc.item(), implicit_acc.item()],
|
||||
"Model": ["Mamba", "Implicit"],
|
||||
}
|
||||
)
|
||||
df = pd.concat([df, new_df], ignore_index=True)
|
||||
df.to_pickle(logdir)
|
||||
pbar.write(
|
||||
f"Epoch {epoch + 1:4d} | "
|
||||
f"Standard: loss={standard_loss:.4f}, acc={standard_acc:.4f} | "
|
||||
f"Implicit: loss={implicit_loss:.4f}, acc={implicit_acc:.4f}"
|
||||
)
|
||||
|
||||
# Final evaluation
|
||||
print("\n" + "=" * 60)
|
||||
print("Final Evaluation (1000 samples)")
|
||||
print("=" * 60)
|
||||
|
||||
eval_inputs, eval_targets, eval_masks = dataset.generate_batch(
|
||||
keys[4], 1000, task_type
|
||||
)
|
||||
|
||||
standard_acc = compute_accuracy(
|
||||
standard_model, eval_inputs, eval_targets, eval_masks
|
||||
)
|
||||
implicit_acc = compute_accuracy(
|
||||
implicit_model, eval_inputs, eval_targets, eval_masks
|
||||
)
|
||||
|
||||
print(f"Mamba SSM accuracy: {standard_acc:.4f}")
|
||||
print(f"Implicit SSM accuracy: {implicit_acc:.4f}")
|
||||
print(f"Improvement: {(implicit_acc - standard_acc) * 100:.2f}%")
|
||||
|
||||
return standard_model, implicit_model, standard_acc, implicit_acc, df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-t", type=int, choices=[1, 2], default=1, help="Task to perform"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
type=str,
|
||||
choices=["boomerang", "implicit"],
|
||||
default="implicit",
|
||||
help="Neuron model to use",
|
||||
)
|
||||
parser.add_argument("--dt", type=float, default=0.001, help="Simulation timestep")
|
||||
parser.add_argument("-i", type=int, default=8, help="Maximum number of iterations")
|
||||
parser.add_argument("-b", type=int, default=64, help="Batch size")
|
||||
parser.add_argument(
|
||||
"--thr", action="store_true", help="Using threshold on the boomerang neuron"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.m == "boomerang":
|
||||
model_type = ImplicitBoomerang
|
||||
elif args.m == "implicit":
|
||||
model_type = Implicit
|
||||
else:
|
||||
raise NotImplementedError(f"{args.t} model type not implemented")
|
||||
|
||||
logdir = os.path.join("tmp", f"task{args.t}-{args.i}-{args.m}-{args.thr}")
|
||||
if not os.path.exists("tmp"):
|
||||
os.makedirs("tmp")
|
||||
|
||||
print(f"Saving at {logdir}")
|
||||
_, implicit_model, std_acc1, imp_acc1, df = train_and_compare(
|
||||
model_type,
|
||||
logdir,
|
||||
task_type="simple" if args.t == 1 else "accumulation",
|
||||
n_epochs=1000,
|
||||
batch_size=64,
|
||||
d_model=ReasoningDataset.NUM_OUTPUT,
|
||||
d_state=16,
|
||||
d_inner=128,
|
||||
dt=args.dt,
|
||||
max_iters=args.i,
|
||||
# with_thr=args.thr,
|
||||
)
|
||||
eqx.tree_serialise_leaves(f"{logdir}.eqx", implicit_model)
|
||||
df.to_pickle(logdir)
|
||||
|
||||
print("=" * 70)
|
||||
print("SUMMARY")
|
||||
print("=" * 70)
|
||||
print(f"{'Task':<25} {'Mamba SSM':<15} {'Implicit SSM':<15} {'Delta':<10}")
|
||||
print("-" * 70)
|
||||
print(
|
||||
f"{'Simple Comparison':<25} {std_acc1:<15.4f} {imp_acc1:<15.4f} {(imp_acc1 - std_acc1) * 100:>+.2f}%"
|
||||
)
|
||||
305
scripts/wererabbit_stability.ipynb
Normal file
305
scripts/wererabbit_stability.ipynb
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user