# Files
# F.M. Quintana Velazquez 9fabbdefc0 Initial commit
# Co-authored-by: Aradhana Dube <a.dube@rug.nl>
# Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl>
# Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl>
# Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
# 2026-02-27 17:43:31 +01:00
#
# 189 lines
# 4.7 KiB
# Python
import marimo
# marimo release that generated this notebook (file metadata written by marimo).
__generated_with = "0.19.4"
# The notebook object: every `@app.cell` function below registers one cell on it.
app = marimo.App(width="medium")
@app.cell
def _():
    # Shared third-party modules. Returning them publishes the names to every
    # downstream cell through marimo's dataflow graph.
    import marimo as mo
    import numpy as np
    import matplotlib.pyplot as plt
    import jax.numpy as jnp
    from jax import jit
    import diffrax as dfx
    return dfx, jit, jnp, mo, np, plt
@app.cell
def _(dfx, jnp):
    def vector_field(t, y, args):
        # Right-hand side of the two-variable model, state y = (v, v_slow):
        #   dv/dt      = -(gmax*(v - Erev) + af*tanh(v - Ef) + as*tanh(v_slow - Es)) / C
        #   dv_slow/dt = (v - v_slow) / ts
        # so v_slow is a first-order low-pass of v with time constant ts.
        # `t` is unused (the system is autonomous) but is part of the
        # f(t, y, args) signature that dfx.ODETerm calls.
        # All operations are elementwise, so `y` may be any array whose leading
        # axis has length 2 (e.g. a stacked grid of states).
        gmax, erev = args["gmax"], args["Erev"]
        af, ef = args["af"], args["Ef"]
        a_s, es = args["as"], args["Es"]
        cap, ts = args["C"], args["ts"]

        v, v_slow = y
        i_passive = gmax * (v - erev)  # linear term
        i_fast = af * jnp.tanh(v - ef)  # nonlinearity driven by v itself
        i_slow = a_s * jnp.tanh(v_slow - es)  # nonlinearity driven by the slow variable

        dv_dt = (-i_passive - i_fast - i_slow) / cap
        dvslow_dt = (v - v_slow) / ts
        return jnp.array([dv_dt, dvslow_dt])

    term = dfx.ODETerm(vector_field)
    return term, vector_field
@app.cell
def _(mo):
    # Model-parameter sliders. The simulation cell reads their `.value`s into
    # the args dict (keys gmax, Erev, af, Ef, as, Es, ts, C) given to vector_field.
    p1 = mo.ui.slider(0.0, 5.0, value=1.0, step=0.1, label="gmax")
    p2 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.1, label="Erev")
    p3 = mo.ui.slider(-5.0, 5.0, value=-2.0, step=0.05, label="af")
    p4 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Ef")
    p5 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.05, label="as")
    p6 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Es")
    p7 = mo.ui.slider(1.0, 100.0, value=50.0, step=0.1, label="ts")
    # C divides dv/dt, so it must stay strictly positive: C = 0 yields inf/NaN
    # derivatives and a failed solve. The lower bound 0.05 also bounds the fast
    # rate (gmax + |af|) / C <= 10 / C <= 200 for these slider ranges. Times the
    # solver's dt0 = 0.01 (no step-size controller, i.e. constant steps) that is
    # <= 2, roughly within the stability range of an explicit Runge-Kutta step
    # such as Tsit5.
    p8 = mo.ui.slider(0.05, 1.0, value=1.0, step=0.01, label="C")
    mo.hstack(
        [
            mo.vstack([p1, p2, p3, p4], justify="start", gap=1),
            mo.vstack([p5, p6, p7, p8], justify="start", gap=1),
        ]
    )
    return p1, p2, p3, p4, p5, p6, p7, p8
@app.cell
def _(mo):
    mo.md(r"""
### Initial Conditions & Simulation
""")
    return
@app.cell
def _(mo):
    # Initial state (x0, y0) and integration horizon t_max for the simulation.
    x0 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.1, label="x₀")
    y0 = mo.ui.slider(-5.0, 5.0, value=0.0, step=0.1, label="y₀")
    t_max = mo.ui.slider(10, 100, value=30, step=5, label="t_max")
    _controls = [x0, y0, t_max]  # underscore: private to this cell in marimo
    mo.hstack(_controls, justify="start", gap=2)
    return t_max, x0, y0
@app.cell
def _(dfx, jit, jnp, term):
    # The jitted solver lives in its own cell on purpose. It depends only on
    # `term`, so marimo defines it once and JAX reuses the compiled program on
    # every slider change. Defining it inside a slider-dependent cell would
    # create a fresh function object on each re-run, forcing `jit` to retrace
    # and recompile diffeqsolve every time a slider moves.
    @jit
    def solve_ode(y_init, args, t_end):
        """Integrate the model with Tsit5 from t=0 to t=t_end.

        Args:
            y_init: initial state, a length-2 array.
            args: parameter dict forwarded by diffeqsolve to the term's
                vector field.
            t_end: final time. It is a traced (non-static) argument, so
                changing it reuses the compiled program.

        Returns:
            (ts, ys): 2000 evenly spaced save times on [0, t_end] and the
            states at those times; ys[:, 0] and ys[:, 1] are the two state
            variables.
        """
        solver = dfx.Tsit5()
        saveat = dfx.SaveAt(ts=jnp.linspace(0, t_end, 2000))
        sol = dfx.diffeqsolve(
            term,
            solver,
            t0=0,
            t1=t_end,
            dt0=0.01,
            y0=y_init,
            args=args,
            saveat=saveat,
            max_steps=100000,
        )
        return sol.ts, sol.ys

    return (solve_ode,)


@app.cell
def _(jnp, p1, p2, p3, p4, p5, p6, p7, p8, solve_ode, t_max, x0, y0):
    # Re-runs on every slider change. It collects the current values and calls
    # the already-compiled solver. All values are Python scalars with a fixed
    # dict structure, so the jit cache is hit.
    args = {
        "gmax": p1.value,
        "Erev": p2.value,
        "af": p3.value,
        "Ef": p4.value,
        "as": p5.value,
        "Es": p6.value,
        "ts": p7.value,
        "C": p8.value,
    }
    y_init = jnp.array([x0.value, y0.value])
    t, ys = solve_ode(y_init, args, float(t_max.value))
    x_sol = ys[:, 0]
    y_sol = ys[:, 1]
    return args, t, x_sol, y_sol
@app.cell
def _(args, jnp, np, plt, t, vector_field, x0, x_sol, y0, y_sol):
    # Left panel: time series of both state variables.
    # Right panel: phase plane with a normalized direction field, both
    # nullclines and the simulated trajectory.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Time series
    ax1.plot(t, x_sol, "b-", lw=1.5, label="x(t)")
    ax1.plot(t, y_sol, "r-", lw=1.5, label="y(t)")
    ax1.set_xlabel("Time t")
    ax1.set_ylabel("State")
    ax1.set_title("Transient Response")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Phase plane bounds: a fixed default window that is widened only when the
    # initial condition (its sliders reach +/-5) would fall outside it, so the
    # "Start" marker and the start of the trajectory stay visible.
    pad = 0.5
    xmin = min(-4.0, x0.value - pad)
    xmax = max(4.0, x0.value + pad)
    ymin = min(-2.5, y0.value - pad)
    ymax = max(2.5, y0.value + pad)

    # Direction field: unit-length arrows (0 where the field vanishes). The
    # vector field is elementwise, so a stacked (2, 20, 20) grid of states
    # evaluates in one call.
    X, Y = jnp.meshgrid(jnp.linspace(xmin, xmax, 20), jnp.linspace(ymin, ymax, 20))
    deriv = vector_field(0.0, jnp.stack([X, Y], axis=0), args)
    dx, dy = deriv[0], deriv[1]
    mag = jnp.sqrt(dx**2 + dy**2)
    U = jnp.where(mag > 0, dx / mag, 0.0)
    V = jnp.where(mag > 0, dy / mag, 0.0)
    ax2.quiver(X, Y, U, V, alpha=0.4, color="gray", scale=25)

    # Nullclines: zero-level contours of each derivative on a finer grid.
    Xf, Yf = jnp.meshgrid(jnp.linspace(xmin, xmax, 150), jnp.linspace(ymin, ymax, 150))
    deriv_f = vector_field(0.0, jnp.stack([Xf, Yf], axis=0), args)
    DX, DY = deriv_f[0], deriv_f[1]
    ax2.contour(
        Xf,
        Yf,
        DX,
        levels=[0],
        colors="blue",
        linestyles="--",
        linewidths=1.5,
        alpha=0.7,
    )
    ax2.contour(
        Xf, Yf, DY, levels=[0], colors="red", linestyles="--", linewidths=1.5, alpha=0.7
    )

    # Trajectory
    ax2.plot(x_sol, y_sol, "b-", lw=2)
    ax2.plot(x0.value, y0.value, "go", ms=10, label="Start")
    ax2.plot(x_sol[-1], y_sol[-1], "r*", ms=12, label="End")
    ax2.set_xlabel("x")
    ax2.set_ylabel("y")
    ax2.set_title("Phase Plane")
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(xmin, xmax)
    ax2.set_ylim(ymin, ymax)

    plt.tight_layout()
    fig
    return
@app.cell
def _():
    # Empty placeholder cell: no code, no outputs.
    return
# Allow running the notebook as a plain script (`python <this file>`).
if __name__ == "__main__":
    app.run()