mirror of
https://github.com/bics-rug/felice-models.git
synced 2026-03-10 21:14:15 +01:00
Initial commit
Co-authored-by: Aradhana Dube <a.dube@rug.nl> Co-authored-by: Renzo I. Barraza Altamirano <r.i.barraza.altamirano@rug.nl> Co-authored-by: Paolo Gibertini <p.gibertini@rug.nl> Co-authored-by: Luca D. Fehlings <l.d.fehlings@rug.nl>
This commit is contained in:
240
scripts/examples/neuron_models/fhn/fhnrs.ipynb
Normal file
240
scripts/examples/neuron_models/fhn/fhnrs.ipynb
Normal file
File diff suppressed because one or more lines are too long
188
scripts/examples/neuron_models/fhn/fhnrs.py
Normal file
188
scripts/examples/neuron_models/fhn/fhnrs.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import marimo
|
||||
|
||||
__generated_with = "0.19.4"
|
||||
app = marimo.App(width="medium")
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import diffrax as dfx
|
||||
import jax.numpy as jnp
|
||||
import marimo as mo
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from jax import jit
|
||||
|
||||
return dfx, jit, jnp, mo, np, plt
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(dfx, jnp):
|
||||
def vector_field(t, y, args):
|
||||
v, vslow = y
|
||||
|
||||
ipasive = args["gmax"] * (v - args["Erev"])
|
||||
ifast = args["af"] * jnp.tanh(v - args["Ef"])
|
||||
islow = args["as"] * jnp.tanh(vslow - args["Es"])
|
||||
|
||||
dv = (-ipasive - ifast - islow) / args["C"]
|
||||
dvs = (v - vslow) / args["ts"]
|
||||
|
||||
return jnp.array([dv, dvs])
|
||||
|
||||
term = dfx.ODETerm(vector_field)
|
||||
return term, vector_field
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(mo):
|
||||
p1 = mo.ui.slider(0.0, 5.0, value=1.0, step=0.1, label="gmax")
|
||||
p2 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.1, label="Erev")
|
||||
p3 = mo.ui.slider(-5.0, 5.0, value=-2.0, step=0.05, label="af")
|
||||
p4 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Ef")
|
||||
p5 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.05, label="as")
|
||||
p6 = mo.ui.slider(-1.0, 1.0, value=0.0, step=0.05, label="Es")
|
||||
p7 = mo.ui.slider(1.0, 100.0, value=50.0, step=0.1, label="ts")
|
||||
p8 = mo.ui.slider(0.0, 1.0, value=1.0, step=0.01, label="C")
|
||||
|
||||
mo.hstack(
|
||||
[
|
||||
mo.vstack([p1, p2, p3, p4], justify="start", gap=1),
|
||||
mo.vstack([p5, p6, p7, p8], justify="start", gap=1),
|
||||
]
|
||||
)
|
||||
return p1, p2, p3, p4, p5, p6, p7, p8
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(mo):
|
||||
mo.md("""
|
||||
### Initial Conditions & Simulation
|
||||
""")
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(mo):
|
||||
x0 = mo.ui.slider(-5.0, 5.0, value=2.0, step=0.1, label="x₀")
|
||||
y0 = mo.ui.slider(-5.0, 5.0, value=0.0, step=0.1, label="y₀")
|
||||
t_max = mo.ui.slider(10, 100, value=30, step=5, label="t_max")
|
||||
|
||||
mo.hstack([x0, y0, t_max], justify="start", gap=2)
|
||||
return t_max, x0, y0
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(dfx, jit, jnp, p1, p2, p3, p4, p5, p6, p7, p8, t_max, term, x0, y0):
|
||||
@jit
|
||||
def solve_ode(y_init, args, t_end):
|
||||
solver = dfx.Tsit5()
|
||||
saveat = dfx.SaveAt(ts=jnp.linspace(0, t_end, 2000))
|
||||
sol = dfx.diffeqsolve(
|
||||
term,
|
||||
solver,
|
||||
t0=0,
|
||||
t1=t_end,
|
||||
dt0=0.01,
|
||||
y0=y_init,
|
||||
args=args,
|
||||
saveat=saveat,
|
||||
max_steps=100000,
|
||||
)
|
||||
return sol.ts, sol.ys
|
||||
|
||||
args = {
|
||||
"gmax": p1.value,
|
||||
"Erev": p2.value,
|
||||
"af": p3.value,
|
||||
"Ef": p4.value,
|
||||
"as": p5.value,
|
||||
"Es": p6.value,
|
||||
"ts": p7.value,
|
||||
"C": p8.value,
|
||||
}
|
||||
y_init = jnp.array([x0.value, y0.value])
|
||||
|
||||
t, ys = solve_ode(y_init, args, float(t_max.value))
|
||||
x_sol = ys[:, 0]
|
||||
y_sol = ys[:, 1]
|
||||
return args, t, x_sol, y_sol
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(args, jnp, np, plt, t, vector_field, x0, x_sol, y0, y_sol):
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
# Time series
|
||||
|
||||
ax1.plot(t, x_sol, "b-", lw=1.5, label="x(t)")
|
||||
ax1.plot(t, y_sol, "r-", lw=1.5, label="y(t)")
|
||||
ax1.set_xlabel("Time t")
|
||||
ax1.set_ylabel("State")
|
||||
ax1.set_title("Transient Response")
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Phase plane bounds
|
||||
# pad = 1.0
|
||||
xmin, xmax = -4, 4
|
||||
ymin, ymax = -2.5, 2.5
|
||||
|
||||
# Vector field
|
||||
X, Y = jnp.meshgrid(jnp.linspace(xmin, xmax, 20), jnp.linspace(ymin, ymax, 20))
|
||||
U, V = jnp.zeros_like(X), np.zeros_like(Y)
|
||||
|
||||
state = jnp.stack([X, Y], axis=0)
|
||||
deriv = vector_field(0.0, state, args)
|
||||
dx, dy = deriv[0], deriv[1]
|
||||
mag = jnp.sqrt(dx**2 + dy**2)
|
||||
U = jnp.where(mag > 0, dx / mag, U)
|
||||
V = jnp.where(mag > 0, dy / mag, V)
|
||||
|
||||
ax2.quiver(X, Y, U, V, alpha=0.4, color="gray", scale=25)
|
||||
|
||||
# Nullclines
|
||||
Xf, Yf = jnp.meshgrid(jnp.linspace(xmin, xmax, 150), jnp.linspace(ymin, ymax, 150))
|
||||
DX, DY = jnp.zeros_like(Xf), jnp.zeros_like(Yf)
|
||||
state = jnp.stack([Xf, Yf], axis=0)
|
||||
deriv = vector_field(0.0, state, args)
|
||||
DX, DY = deriv[0], deriv[1]
|
||||
ax2.contour(
|
||||
Xf,
|
||||
Yf,
|
||||
DX,
|
||||
levels=[0],
|
||||
colors="blue",
|
||||
linestyles="--",
|
||||
linewidths=1.5,
|
||||
alpha=0.7,
|
||||
)
|
||||
ax2.contour(
|
||||
Xf, Yf, DY, levels=[0], colors="red", linestyles="--", linewidths=1.5, alpha=0.7
|
||||
)
|
||||
|
||||
# Trajectory
|
||||
ax2.plot(x_sol, y_sol, "b-", lw=2)
|
||||
ax2.plot(x0.value, y0.value, "go", ms=10, label="Start")
|
||||
ax2.plot(x_sol[-1], y_sol[-1], "r*", ms=12, label="End")
|
||||
|
||||
ax2.set_xlabel("x")
|
||||
ax2.set_ylabel("y")
|
||||
ax2.set_title("Phase Plane")
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
ax2.set_xlim(xmin, xmax)
|
||||
ax2.set_ylim(ymin, ymax)
|
||||
|
||||
plt.tight_layout()
|
||||
fig
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
Reference in New Issue
Block a user