Files
felice-models/scripts/examples/neuron_models/boomerang/boomerang.py
F.M. Quintana Velazquez 9fabbdefc0 Initial commit
Co-authored-by: Aradhana Dube <a.dube@rug.nl>
Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl>
Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl>
Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
2026-02-27 17:43:31 +01:00

248 lines
6.5 KiB
Python

import marimo
__generated_with = "0.19.4"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
from wigglystuff import Slider2D
return Slider2D, mo
@app.cell
def _():
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
return diffrax, jax, jnp, np, plt
@app.cell
def _(diffrax, jax, jnp):
def vector_field(t, state, args):
u, v = state
alpha, beta, gamma, kappa, sigma, delta = args
z = jax.nn.tanh(kappa * (v - u))
# Prey dynamics
du = (1 - alpha * jnp.exp(beta * v) * (1 - gamma * (0.3 - u))) + sigma * z
# Predator dynamics
dv = (-1 + alpha * jnp.exp(beta * u) * (1 + gamma * (0.3 - v))) + sigma * z
return jnp.array([du, dv])
def compute_nullclines(vector_field, u_range, v_range, args, resolution=200):
"""
Compute nullclines
du/dt = 0 (u-nullcline)
dv/dt = 0 (v-nullcline)
"""
alpha, beta, gamma, kappa, sigma, delta = args
u_vals = jnp.linspace(u_range[0], u_range[1], resolution)
v_vals = jnp.linspace(v_range[0], v_range[1], resolution)
U, V = jnp.meshgrid(u_vals, v_vals)
dU, dV = vector_field(0, [U, V], args)
return U, V, dU, dV
def solve(dyn, y0, p, T, n=500):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(dyn),
diffrax.Tsit5(),
t0=0.0,
t1=T,
dt0=0.0001,
y0=y0,
args=p,
saveat=diffrax.SaveAt(ts=jnp.linspace(0, T, n)),
stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-8),
max_steps=50000,
)
return sol.ts, sol.ys
return compute_nullclines, solve, vector_field
@app.cell
def _(Slider2D, mo):
# alpha = 0.5 # I_n0 / I_bias ratio
# beta = 0.39/0.025 # k / U_t (inverse thermal scale)
# gamma = 0.26 # coupling coefficient
# kappa = 5.0 # tanh steepness
# sigma = 0.6 # bias scaling (s * I_bias normalized)
# y0 = jnp.array([0.2, 0.4])
# ts, ys = solve(vector_field, y0, params, 140)
alpha = mo.ui.slider(
0.0004, 0.012, 0.00001, 0.00129, label="alpha", orientation="vertical"
)
beta = mo.ui.slider(
0.0, 30, 0.00001, 0.39 / 0.025, label="beta", orientation="vertical"
)
gamma = mo.ui.slider(0, 1, 0.01, 0.26, label="gamma", orientation="vertical")
kappa = mo.ui.slider(0, 30, 1.0, 10.0, label="kappa", orientation="vertical")
sigma = mo.ui.slider(0, 1, 0.01, 0.6, label="sigma", orientation="vertical")
delta = mo.ui.slider(1, 100.0, 1, 10, label="delta", orientation="vertical")
# v0 = mo.ui.slider(0, 1.0, 0.01, 0.3, label="v0")
# u0 = mo.ui.slider(0, 1.0, 0.01, 0.2, label="u0", orientation="vertical")
state0 = mo.ui.anywidget(
Slider2D(
x=0.34,
y=0.38,
width=150,
height=150,
x_bounds=(0.0, 0.6),
y_bounds=(0.0, 0.6),
)
)
mo.hstack(
[
mo.plain_text("""
alpha: I_n0 / I_bias ratio
beta: k / U_t ratio
gamma: coupling coefficient
kappa: tanh steepness
sigma: bias scaling (s * I_bias)
"""),
mo.hstack(
[state0, alpha, beta, gamma, kappa, sigma, delta], justify="start"
),
]
)
return alpha, beta, delta, gamma, kappa, sigma, state0
@app.cell
def _(
alpha,
beta,
compute_nullclines,
delta,
gamma,
jnp,
kappa,
np,
plt,
sigma,
solve,
state0,
vector_field,
):
params = (
alpha.value,
beta.value,
gamma.value,
kappa.value,
sigma.value,
delta.value,
)
ic_neuro = [state0.x, state0.y]
u_range = [0.0, 0.6]
v_range = [0.0, 0.6]
u_sparse = jnp.linspace(u_range[0], u_range[1], 20)
v_sparse = jnp.linspace(v_range[0], v_range[1], 20)
Us, Vs = jnp.meshgrid(u_sparse, v_sparse)
def plot_vf(ax, vector_field):
U, V, dU, dV = compute_nullclines(vector_field, u_range, v_range, params)
dUs, dVs = vector_field(0, [Us, Vs], params)
# Normalize for visualization
magnitude = np.sqrt(dUs**2 + dVs**2)
magnitude[magnitude == 0] = 1
dUs_norm = dUs / magnitude
dVs_norm = dVs / magnitude
# Nullclines
ax.contour(U, V, dU, levels=[0], colors="blue", linewidths=2, linestyles="-")
ax.contour(U, V, dV, levels=[0], colors="red", linewidths=2, linestyles="-")
ax.quiver(Us, Vs, dUs_norm, dVs_norm, magnitude, cmap="viridis", alpha=0.6)
# Trajectories
color = plt.cm.plasma(0.2)
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, delta.value)
ax.plot(ys[:, 0], ys[:, 1], "-", color=color, linewidth=1.5, alpha=0.8)
ax.plot(ic_neuro[0], ic_neuro[1], "o", color=color, markersize=6)
ax.set_xlabel("u (Prey)")
ax.set_ylabel("v (Predator)")
ax.set_title("Wererabbit: Phase Portrait")
ax.legend(["u-nullcline (du/dt=0)", "v-nullcline (dv/dt=0)"], loc="upper right")
ax.set_xlim(u_range)
ax.set_ylim(v_range)
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
ax.axvline(x=0, color="gray", linestyle="--", alpha=0.3)
def plot_trj(ax, vector_field):
ts, ys = solve(vector_field, jnp.array(ic_neuro), params, delta.value)
ax.plot(ts, ys[:, 0], "b-", linewidth=2, label="u (Prey)")
ax.plot(ts, ys[:, 1], "r-", linewidth=2, label="v (Predator)")
ax.set_xlabel("Time τ")
ax.set_ylabel("Population")
ax.set_title(
f"Wererabbit: Time Series (IC: u₀={ic_neuro[0]:.2f}, v₀={ic_neuro[1]:.2f})"
)
ax.legend()
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
fig = plt.figure(figsize=(10, 4))
# --- Plot 1: Wererabbit Phase Portrait ---
ax1 = fig.add_subplot(1, 2, 1)
plot_vf(ax1, vector_field)
ax2 = fig.add_subplot(1, 2, 2)
plot_trj(ax2, vector_field)
# ax3 = fig.add_subplot(3, 2, 3)
# plot_vf(ax3, vector_field_prod)
# ax4 = fig.add_subplot(3, 2, 4)
# plot_trj(ax4, vector_field_prod)
# ax5 = fig.add_subplot(3, 2, 5)
# plot_vf(ax5, vector_field_exp)
# ax6 = fig.add_subplot(3, 2, 6)
# plot_trj(ax6, vector_field_exp)
plt.tight_layout()
fig
return
@app.cell
def _():
return
@app.cell
def _():
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()