mirror of
https://github.com/bics-rug/felice-models.git
synced 2026-03-10 13:07:40 +01:00
Co-authored-by: Aradhana Dube <a.dube@rug.nl> Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl> Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl> Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
189 lines
4.7 KiB
Python
189 lines
4.7 KiB
Python
import marimo
|
|
|
|
# marimo version recorded in this notebook file when it was generated.
__generated_with = "0.19.4"

# The notebook application object; every `@app.cell` function below is
# registered on it as a cell.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def _():
    # Shared dependencies for the rest of the notebook.
    import marimo as mo  # UI widgets, layout and markdown output

    # Numerical stack: JAX for the model and solver inputs, NumPy for plotting.
    import jax.numpy as jnp
    import numpy as np
    from jax import jit

    import diffrax as dfx  # ODE solvers (Tsit5 is used below)
    import matplotlib.pyplot as plt  # figures

    return dfx, jit, jnp, mo, np, plt
|
|
|
|
|
|
@app.cell
def _(dfx, jnp):
    def vector_field(t, y, args):
        """Right-hand side of the two-variable (v, v_slow) model.

            dv/dt      = (-I_passive - I_fast - I_slow) / C
            dv_slow/dt = (v - v_slow) / ts

        where
            I_passive = gmax * (v - Erev)
            I_fast    = af * tanh(v - Ef)
            I_slow    = as * tanh(v_slow - Es)

        so v_slow relaxes towards v with time constant ts.

        Args:
            t: Time. Not used (the system is autonomous) but part of the
                (t, y, args) signature that `dfx.ODETerm` calls.
            y: State whose first axis holds (v, v_slow). Trailing dimensions
                are carried through elementwise, so a stacked meshgrid of
                shape (2, ny, nx) can be evaluated in one call.
            args: Mapping with keys "gmax", "Erev", "af", "Ef", "as", "Es",
                "ts" and "C".

        Returns:
            Array stacking (dv/dt, dv_slow/dt) along a new leading axis.
        """
        v, v_slow = y
        p = args

        # Membrane currents.
        i_passive = p["gmax"] * (v - p["Erev"])
        i_fast = p["af"] * jnp.tanh(v - p["Ef"])
        i_slow = p["as"] * jnp.tanh(v_slow - p["Es"])

        dv_dt = (-i_passive - i_fast - i_slow) / p["C"]
        dvslow_dt = (v - v_slow) / p["ts"]
        return jnp.array([dv_dt, dvslow_dt])

    # Wrap once so the solver cell can reuse the same term object.
    term = dfx.ODETerm(vector_field)
    return term, vector_field
|
|
|
|
|
|
@app.cell
def _(mo):
    # Model-parameter sliders. Each label matches the key under which the
    # solver cell stores the value in `args` for `vector_field`.
    p1 = mo.ui.slider(0.0, 5.0, value=1.0, step=0.1, label="gmax")
    p2 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.1, label="Erev")
    p3 = mo.ui.slider(-5.0, 5.0, value=-2.0, step=0.05, label="af")
    p4 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Ef")
    p5 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.05, label="as")
    p6 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Es")
    # `ts` and `C` are divisors in the vector field, so their ranges must stay
    # strictly positive; a value of 0 would make the derivatives inf/NaN and
    # break the solve. (C previously started at 0.0.)
    p7 = mo.ui.slider(1.0, 100.0, value=50.0, step=0.1, label="ts")
    p8 = mo.ui.slider(0.01, 1.0, value=1.0, step=0.01, label="C")

    # Two columns of four sliders each; displayed as the cell output.
    mo.hstack(
        [
            mo.vstack([p1, p2, p3, p4], justify="start", gap=1),
            mo.vstack([p5, p6, p7, p8], justify="start", gap=1),
        ]
    )
    return p1, p2, p3, p4, p5, p6, p7, p8
|
|
|
|
|
|
@app.cell
def _(mo):
    # Section header shown above the simulation controls.
    _header = mo.md("""
    ### Initial Conditions & Simulation
    """)
    _header
    return
|
|
|
|
|
|
@app.cell
def _(mo):
    # Initial state and integration horizon. The solver cell packs
    # (x0, y0) into the initial state (v, v_slow).
    x0 = mo.ui.slider(start=-5.0, stop=5.0, step=0.1, value=2.0, label="x₀")
    y0 = mo.ui.slider(start=-5.0, stop=5.0, step=0.1, value=0.0, label="y₀")
    t_max = mo.ui.slider(start=10, stop=100, step=5, value=30, label="t_max")

    _controls = [x0, y0, t_max]
    mo.hstack(_controls, justify="start", gap=2)
    return t_max, x0, y0
|
|
|
|
|
|
@app.cell
def _(dfx, jit, jnp, p1, p2, p3, p4, p5, p6, p7, p8, t_max, term, x0, y0):
    # Integrate the model for the current slider settings.
    #
    # NOTE: this cell re-runs on every slider change. A `@jit` function defined
    # here is a brand-new function object each time, so JAX re-traced and
    # re-compiled the whole solve on every move. `dfx.diffeqsolve` is itself
    # wrapped in `equinox.filter_jit`, which traces array inputs and treats
    # non-array inputs (functions, Python scalars) as static cache keys. By
    # passing every numeric input as a JAX array and reusing the same `term`
    # (created once in another cell), the compiled solve can be reused across
    # slider moves. `jit` remains in the signature for compatibility but is
    # intentionally unused now.

    def solve_ode(y_init, args, t_end):
        """Integrate `term` from t=0 to `t_end`, starting at `y_init`.

        Uses Tsit5 with dt0=0.01 and no step-size controller argument, and
        records the state at 2000 evenly spaced times in [0, t_end].

        Args:
            y_init: Initial state (v, v_slow).
            args: Parameter mapping forwarded to `vector_field`.
            t_end: Final integration time.

        Returns:
            (ts, ys): the 2000 save times and the states at those times
            (one row per save time, columns v and v_slow).
        """
        sol = dfx.diffeqsolve(
            term,
            dfx.Tsit5(),
            t0=0,
            t1=t_end,
            dt0=0.01,
            y0=y_init,
            args=args,
            saveat=dfx.SaveAt(ts=jnp.linspace(0, t_end, 2000)),
            max_steps=100000,
        )
        return sol.ts, sol.ys

    # Slider values -> model parameters. Stored as JAX scalars (not Python
    # floats) so they are traced, not baked into the compiled solve.
    args = {
        "gmax": jnp.asarray(p1.value),
        "Erev": jnp.asarray(p2.value),
        "af": jnp.asarray(p3.value),
        "Ef": jnp.asarray(p4.value),
        "as": jnp.asarray(p5.value),
        "Es": jnp.asarray(p6.value),
        "ts": jnp.asarray(p7.value),
        "C": jnp.asarray(p8.value),
    }
    y_init = jnp.array([x0.value, y0.value])

    # The horizon is likewise passed as an array so changing t_max does not
    # trigger a recompile.
    t, ys = solve_ode(y_init, args, jnp.asarray(float(t_max.value)))
    x_sol = ys[:, 0]
    y_sol = ys[:, 1]
    return args, t, x_sol, y_sol
|
|
|
|
|
|
@app.cell
def _(args, jnp, np, plt, t, vector_field, x0, x_sol, y0, y_sol):
    # Left panel: x(t) and y(t) over time. Right panel: phase plane with the
    # normalised direction field, both nullclines and the simulated trajectory.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Time series
    ax1.plot(t, x_sol, "b-", lw=1.5, label="x(t)")
    ax1.plot(t, y_sol, "r-", lw=1.5, label="y(t)")
    ax1.set_xlabel("Time t")
    ax1.set_ylabel("State")
    ax1.set_title("Transient Response")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    def _bounds(values, lo, hi, pad=0.5):
        # Widen the default interval [lo, hi] so it also covers every finite
        # entry of `values`, plus `pad`. Non-finite entries (e.g. from a
        # diverged solve) are ignored so the axis limits stay valid.
        vals = np.asarray(values, dtype=float).ravel()
        vals = vals[np.isfinite(vals)]
        if vals.size == 0:
            return lo, hi
        return min(lo, float(vals.min()) - pad), max(hi, float(vals.max()) + pad)

    # Phase-plane bounds: the fixed default window is kept, but widened when
    # the initial condition or trajectory leaves it. The x0/y0 sliders span
    # [-5, 5], so a purely fixed window could clip the start marker and orbit.
    xmin, xmax = _bounds(np.append(np.asarray(x_sol), x0.value), -4.0, 4.0)
    ymin, ymax = _bounds(np.append(np.asarray(y_sol), y0.value), -2.5, 2.5)

    def _field(n):
        # Evaluate `vector_field` on an n-by-n grid spanning the bounds above.
        # Returns the grid coordinates and the two derivative components.
        gx, gy = jnp.meshgrid(
            jnp.linspace(xmin, xmax, n), jnp.linspace(ymin, ymax, n)
        )
        deriv = vector_field(0.0, jnp.stack([gx, gy], axis=0), args)
        return gx, gy, deriv[0], deriv[1]

    # Direction field: unit-length arrows; zero where the field vanishes.
    X, Y, dx, dy = _field(20)
    mag = jnp.sqrt(dx**2 + dy**2)
    U = jnp.where(mag > 0, dx / mag, 0.0)
    V = jnp.where(mag > 0, dy / mag, 0.0)
    ax2.quiver(X, Y, U, V, alpha=0.4, color="gray", scale=25)

    # Nullclines: zero-level contours of each derivative on a finer grid.
    Xf, Yf, DX, DY = _field(150)
    ax2.contour(
        Xf, Yf, DX, levels=[0], colors="blue", linestyles="--", linewidths=1.5, alpha=0.7
    )
    ax2.contour(
        Xf, Yf, DY, levels=[0], colors="red", linestyles="--", linewidths=1.5, alpha=0.7
    )

    # Trajectory with start and end markers.
    ax2.plot(x_sol, y_sol, "b-", lw=2)
    ax2.plot(x0.value, y0.value, "go", ms=10, label="Start")
    ax2.plot(x_sol[-1], y_sol[-1], "r*", ms=12, label="End")

    ax2.set_xlabel("x")
    ax2.set_ylabel("y")
    ax2.set_title("Phase Plane")
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(xmin, xmax)
    ax2.set_ylim(ymin, ymax)

    plt.tight_layout()
    fig
    return
|
|
|
|
|
|
@app.cell
def _():
    # Empty cell: defines nothing and produces no output.
    return
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow the notebook file to be executed directly as a Python script.
    app.run()
|