import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.integrate import odeint as scipy_odeint
import matplotlib.pyplot as plt
torch.manual_seed(42)
np.random.seed(42)
plt.rcParams.update({
"figure.dpi": 130,
"axes.spines.top": False,
"axes.spines.right": False,
"axes.grid": True,
"grid.alpha": 0.3,
"font.size": 8,
"axes.titlesize": 8,
"axes.labelsize": 8,
"xtick.labelsize": 7,
"ytick.labelsize": 7,
"legend.fontsize": 7,
})
C_S = "#2196F3"
C_I = "#F44336"
C_R = "#4CAF50"
C_OBS = "black"
C_POST = "#9C27B0" # purple — posterior median
C_PI = "#CE93D8" # light purple — credible bandBayesian Physics-Informed Neural Networks
Posterior uncertainty over epidemic parameters via variational inference
1 From point estimates to posteriors
The SIR PINN tutorial recovered point estimates of \(\beta\) and \(\gamma\). Here we exlore the distribution of \(\beta\) consistent with these data.
This distinction matters for three concrete reasons:
This post extends the PINN to a Bayesian PINN (B-PINN) (1): \(\beta\) and \(\gamma\) become random variables, and we learn a variational posterior from the data using the reparameterisation trick (2).
What you need
pip install torch scipy matplotlib numpyTested with torch 2.11, scipy 1.17, Python 3.11. Training runs in ~10 min on CPU.
2 The Bayesian model
2.1 Prior
We place weakly informative log-normal priors on both epidemic parameters:
\[ \beta \sim \text{LogNormal}(\log 0.3,\; 0.5^2), \qquad \gamma \sim \text{LogNormal}(\log 0.1,\; 0.5^2) \]
These priors place 95% of prior mass on \(\beta \in [0.11,\, 0.81]\) and \(\gamma \in [0.04,\, 0.27]\) — weakly informative over the plausible range for short-duration respiratory outbreaks, while excluding biologically implausible values.
2.2 Likelihood
Given \((\beta, \gamma)\), the PINN network \(f_\theta(t,\,\beta,\,\gamma)\) maps time to compartment fractions \((S, I, R)\). We observe \(I(t)\) with additive Gaussian noise:
\[ I(t_j) \mid \beta,\gamma \;\sim\; \mathcal{N}\!\bigl(N \cdot f^I_\theta(t_j,\beta,\gamma),\; \sigma_\text{obs}^2 \bigr) \]
The physics constraints act as an implicit regulariser on the likelihood: network outputs that violate the SIR equations are penalised regardless of data fit.
2.3 Variational posterior
Exact posterior inference is intractable. We approximate with a mean-field variational family (3):
\[ q(\beta, \gamma) = \text{LogNormal}(\mu_\beta,\,\sigma_\beta^2) \times \text{LogNormal}(\mu_\gamma,\,\sigma_\gamma^2) \]
where \((\mu_\beta, \sigma_\beta, \mu_\gamma, \sigma_\gamma)\) are learnable variational parameters, with \(\sigma = \operatorname{softplus}(\rho)\) enforcing positivity.
3 Evidence Lower BOund (ELBO)
Maximising \(\log p(\text{data})\) is equivalent to maximising the ELBO (3):
\[ \mathcal{L}_\text{ELBO} = \underbrace{\mathbb{E}_q[\log p(\text{data} \mid \beta, \gamma, \theta)]}_{\text{expected data fit}} \;-\; \underbrace{\mathrm{KL}[q(\beta,\gamma) \;\|\; p(\beta,\gamma)]}_{\text{regularise toward prior}} \]
We augment with the physics residual and initial-condition terms from the standard PINN:
\[ \mathcal{L} = \mathbb{E}_q[\mathcal{L}_\text{data}] + \lambda_\phi\,\mathbb{E}_q[\mathcal{L}_\text{physics}] + \lambda_0\,\mathcal{L}_\text{IC} + \lambda_\text{KL}\,\mathrm{KL}[q \| p] \]
The expectations are estimated with \(K\) Monte Carlo samples per step using the reparameterisation trick (2):
\[ \log\beta^{(k)} = \mu_\beta + \sigma_\beta \cdot \varepsilon^{(k)}, \qquad \varepsilon^{(k)} \sim \mathcal{N}(0,1) \]
This keeps gradients flowing back through \(\mu_\beta\) and \(\sigma_\beta\).
KL between two log-normals
For \(q = \mathcal{N}(\mu_q, \sigma_q^2)\) and \(p = \mathcal{N}(\mu_p, \sigma_p^2)\) in log-parameter space:
\[ \mathrm{KL}[q \| p] = \log\frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\,\sigma_p^2} - \frac{1}{2} \]
3.1 Why a conditional network?
In the standard PINN, the network maps \(t \to (S, I, R)\) and \(\beta\), \(\gamma\) are scalar parameters. If we made those scalars variational without changing the network, every Monte Carlo sample of \((\beta, \gamma)\) would drive the same network output, producing identical trajectories — no predictive uncertainty.
The fix: the network takes \((\,t,\;\log\beta,\;\log\gamma\,)\) as a three-input conditional ODE solver. For any \((\beta, \gamma)\) in the prior support, it outputs the corresponding SIR trajectory. Sampling \(K\) parameter vectors per training step forces the network to be simultaneously consistent with the physics for all sampled parameters.
| Standard PINN | Bayesian PINN (this post) | |
|---|---|---|
| Network input | \(t\) | \((t,\,\log\beta,\,\log\gamma)\) |
| \(\beta\), \(\gamma\) | Learnable scalars | Variational distributions |
| Trajectory uncertainty | None | Posterior predictive intervals |
| Parameter uncertainty | None | Marginal posteriors |
| Training cost | \(1\times\) | \(\approx K\times\) (K=2 here) |
4 Setup
4.1 Ground truth and observations
Identical setup to the PINN tutorial: \(N = 1{,}000\), \(\beta = 0.3\), \(\gamma = 0.1\) (\(R_0 = 3\)), 30 noisy observations of \(I(t)\) only.
N = 1_000
beta_true = 0.3
gamma_true = 0.1
T_max = 60.0
y0 = np.array([990.0, 10.0, 0.0])
def sir_rhs(y, t, beta, gamma, N):
S, I, R = y
dS = -beta * S * I / N
dI = beta * S * I / N - gamma * I
dR = gamma * I
return [dS, dI, dR]
n_grid = 100
t_grid = np.linspace(0, T_max, n_grid)
sol = scipy_odeint(sir_rhs, y0, t_grid, args=(beta_true, gamma_true, N))
S_true, I_true, R_true = sol.T
n_obs = 30
obs_idx = np.sort(np.random.choice(np.arange(1, n_grid), n_obs, replace=False))
t_obs = t_grid[obs_idx]
I_obs = np.clip(I_true[obs_idx] + np.random.normal(0, 20.0, n_obs), 0, N)
# Scaled tensors
t_obs_s = torch.tensor(t_obs / T_max, dtype=torch.float32)
I_obs_s = torch.tensor(I_obs / N, dtype=torch.float32)
y0_s = torch.tensor(y0 / N, dtype=torch.float32)
t_col = torch.linspace(0, 1, 150) # collocation grid for physics loss5 Bayesian PINN
5.1 Architecture
class BayesianPINN(nn.Module):
"""
Conditional PINN: maps (t_scaled, log_beta, log_gamma) -> (S, I, R) fractions.
Variational distribution
------------------------
q(log beta) = N(mu_lb, softplus(rho_lb)^2)
q(log gamma) = N(mu_lg, softplus(rho_lg)^2)
Initialised intentionally away from the true values so convergence is visible:
beta_init = 0.5 (true: 0.3), gamma_init = 0.15 (true: 0.1)
"""
def __init__(self, hidden: int = 64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(3, hidden), nn.Tanh(),
nn.Linear(hidden, hidden), nn.Tanh(),
nn.Linear(hidden, hidden), nn.Tanh(),
nn.Linear(hidden, 3),
nn.Sigmoid(), # keeps outputs in (0, 1)
)
# Variational means — initialised away from truth
self.mu_lb = nn.Parameter(torch.tensor(float(np.log(0.50))))
self.rho_lb = nn.Parameter(torch.tensor(-1.5)) # softplus(-1.5) ≈ 0.20
self.mu_lg = nn.Parameter(torch.tensor(float(np.log(0.15))))
self.rho_lg = nn.Parameter(torch.tensor(-1.5))
# ---- variational standard deviations (log-space) -----------------------
@property
def sigma_lb(self): return F.softplus(self.rho_lb) + 1e-4
@property
def sigma_lg(self): return F.softplus(self.rho_lg) + 1e-4
# ---- posterior summaries (physical units) ------------------------------
@property
def beta_mean(self): return float(torch.exp(self.mu_lb))
@property
def gamma_mean(self): return float(torch.exp(self.mu_lg))
@property
def R0_mean(self): return self.beta_mean / self.gamma_mean
# ---- reparameterisation sampling ---------------------------------------
def sample_log_params(self, K: int = 1):
"""Draw K samples via the reparameterisation trick."""
lb = self.mu_lb + self.sigma_lb * torch.randn(K) # log beta ~ N(mu, sig^2)
lg = self.mu_lg + self.sigma_lg * torch.randn(K)
return lb, lg # shape (K,)
# ---- forward -----------------------------------------------------------
def forward(self, t_scaled, log_beta, log_gamma):
"""
t_scaled : (n,) tensor
log_beta : scalar tensor (0-d) — broadcast to all time points
log_gamma : scalar tensor (0-d)
Returns : (n, 3) — S, I, R fractions
"""
n = t_scaled.shape[0]
t_in = t_scaled.view(n, 1)
lb = log_beta.reshape(1).expand(n).view(n, 1)
lg = log_gamma.reshape(1).expand(n).view(n, 1)
return self.net(torch.cat([t_in, lb, lg], dim=1))5.2 Loss functions
def data_loss_fn(model, t_obs_s, I_obs_s, lb, lg):
"""MSE on the I compartment at observation times."""
pred = model(t_obs_s, lb, lg)
return ((pred[:, 1] - I_obs_s) ** 2).mean()
def physics_loss_fn(model, t_col, lb, lg):
"""
SIR ODE residuals computed via automatic differentiation.
Working in scaled time tau = t / T_max, the physical β and γ enter as
beta_s = exp(lb) * T_max
gamma_s = exp(lg) * T_max
so that d(compartment)/d(tau) = T_max * d(compartment)/dt_physical.
"""
t = t_col.detach().requires_grad_(True)
sir = model(t, lb, lg)
S, I, R = sir[:, 0], sir[:, 1], sir[:, 2]
ones = torch.ones(len(t))
dS = torch.autograd.grad(S, t, grad_outputs=ones, create_graph=True)[0]
dI = torch.autograd.grad(I, t, grad_outputs=ones, create_graph=True)[0]
dR = torch.autograd.grad(R, t, grad_outputs=ones, create_graph=True)[0]
b = torch.exp(lb) * T_max
g = torch.exp(lg) * T_max
res_S = dS + b * S * I
res_I = dI - b * S * I + g * I
res_R = dR - g * I
return (res_S ** 2 + res_I ** 2 + res_R ** 2).mean()
def ic_loss_fn(model, y0_s, lb, lg):
"""Match compartment fractions at t = 0."""
t0 = torch.tensor([0.0])
pred = model(t0, lb, lg)[0]
return ((pred - y0_s) ** 2).mean()
def kl_normal(mu_q, sig_q, mu_p, sig_p):
"""KL[ N(mu_q, sig_q^2) || N(mu_p, sig_p^2) ]."""
return (torch.log(sig_p / sig_q)
+ (sig_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sig_p ** 2)
- 0.5)5.3 Training
# Weakly informative priors in log-parameter space
prior_mu_lb = torch.tensor(float(np.log(0.3)))
prior_mu_lg = torch.tensor(float(np.log(0.1)))
prior_sig = torch.tensor(0.5) # 95% prior mass: β ∈ [0.11, 0.81], γ ∈ [0.04, 0.27]
# Hyperparameters
lam_phys = 0.1
lam_ic = 5.0
K = 4 # MC samples per step
model = BayesianPINN(hidden=64)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2_500, gamma=0.5)
losses = []
mu_b_hist = []
mu_g_hist = []
for epoch in range(8_000):
opt.zero_grad()
lb_k, lg_k = model.sample_log_params(K)
L_data_list, L_phys_list, L_ic_list = [], [], []
for k in range(K):
L_data_list.append(data_loss_fn (model, t_obs_s, I_obs_s, lb_k[k], lg_k[k]))
L_phys_list.append(physics_loss_fn(model, t_col, lb_k[k], lg_k[k]))
L_ic_list.append (ic_loss_fn (model, y0_s, lb_k[k], lg_k[k]))
L_data = torch.stack(L_data_list).mean()
L_phys = torch.stack(L_phys_list).mean()
L_ic = torch.stack(L_ic_list).mean()
# KL annealing: ramp lambda_KL from 0 → 0.3 over the first 3 000 epochs.
# This lets the data and physics terms establish a good trajectory before
# the prior regularisation tightens the posterior.
kl_w = min(epoch / 3_000, 1.0) * 0.3
KL = (kl_normal(model.mu_lb, model.sigma_lb, prior_mu_lb, prior_sig)
+ kl_normal(model.mu_lg, model.sigma_lg, prior_mu_lg, prior_sig))
loss = L_data + lam_phys * L_phys + lam_ic * L_ic + kl_w * KL
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()
sched.step()
losses.append(loss.item())
mu_b_hist.append(model.beta_mean)
mu_g_hist.append(model.gamma_mean)
print(f"True: β = {beta_true:.3f}, γ = {gamma_true:.3f}, R₀ = {beta_true/gamma_true:.2f}")
print(f"Posterior: β ≈ {model.beta_mean:.3f}, γ ≈ {model.gamma_mean:.3f}, R₀ ≈ {model.R0_mean:.2f}")
print(f"Post. σ: σ_β = {model.sigma_lb.item():.3f} (log scale), "
f"σ_γ = {model.sigma_lg.item():.3f} (log scale)")True: β = 0.300, γ = 0.100, R₀ = 3.00
Posterior: β ≈ 0.299, γ ≈ 0.100, R₀ ≈ 2.99
Post. σ: σ_β = 0.493 (log scale), σ_γ = 0.493 (log scale)
7 Results
7.1 Convergence of posterior means
The variational means start at \(\hat\beta = 0.5\), \(\hat\gamma = 0.15\) and converge toward the true values as the ELBO accumulates data and physics evidence.
fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
for ax, hist, true_val, label, color in [
(axes[0], mu_b_hist, beta_true, "β (posterior mean, day⁻¹)", C_I),
(axes[1], mu_g_hist, gamma_true, "γ (posterior mean, day⁻¹)", C_R),
]:
ax.plot(hist, color=color, lw=1.5, label="Variational mean")
ax.axhline(true_val, color="black", lw=1.5, ls="--",
label=f"True = {true_val}")
ax.set(xlabel="Epoch", ylabel=label,
title=f"Convergence: posterior mean of {label.split()[0]}")
ax.legend()
plt.tight_layout()
plt.show()7.2 Posterior predictive trajectories
All four panels show the full posterior predictive for observed counts: each trajectory sample has independent observation noise \(\varepsilon \sim \mathcal{N}(0,\,20^2)\) added before computing percentiles, so the bands reflect where future observations would scatter — not just uncertainty in the underlying epidemic curve.
From left to right: (1) mean-field VI (\(M = 500\) samples), (2) full-rank VI with bivariate-normal posterior (\(M = 500\)), (3) HMC (\(n = 500\) draws), (4) ensemble PINN (\(n = 10\) members, resampled \(500\times\)).
M = 500
t_viz_np = np.linspace(0, T_max, 300)
sol_viz = scipy_odeint(sir_rhs, y0, t_viz_np, args=(beta_true, gamma_true, N))
with torch.no_grad():
lb_samp, lg_samp = model.sample_log_params(M)
beta_samp = torch.exp(lb_samp).numpy()
gamma_samp = torch.exp(lg_samp).numpy()
# Solve the ODE for each B-PINN posterior sample
traj = np.zeros((M, len(t_viz_np), 3))
for m, (b, g) in enumerate(zip(beta_samp, gamma_samp)):
traj[m] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))
# Add observation noise to get the full posterior predictive for observed counts
np.random.seed(42)
I_noisy = traj[:, :, 1] + np.random.normal(0, 20, traj[:, :, 1].shape)
I_noisy = np.clip(I_noisy, 0, N)
I_lo = np.percentile(I_noisy, 2.5, axis=0)
I_med = np.percentile(I_noisy, 50, axis=0)
I_hi = np.percentile(I_noisy, 97.5, axis=0)# ── Ensemble PINN: 10 independent point-estimate PINNs ──────────────────────
class PointPINN(nn.Module):
def __init__(self, hidden=64, depth=4):
super().__init__()
layers = [nn.Linear(1, hidden), nn.Tanh()]
for _ in range(depth - 1):
layers += [nn.Linear(hidden, hidden), nn.Tanh()]
layers += [nn.Linear(hidden, 3), nn.Sigmoid()]
self.net = nn.Sequential(*layers)
self.log_beta_s = nn.Parameter(torch.tensor(float(np.log(10.0))))
self.log_gamma_s = nn.Parameter(torch.tensor(float(np.log(3.0))))
@property
def beta_phys(self): return torch.exp(self.log_beta_s) / T_max
@property
def gamma_phys(self): return torch.exp(self.log_gamma_s) / T_max
def forward(self, t_s):
return self.net(t_s.view(-1, 1))
def _ens_loss(m, t_obs_s, I_obs_s, t_col, y0_s):
sir = m(t_obs_s)
L_data = ((sir[:, 1] - I_obs_s) ** 2).mean()
t = t_col.detach().requires_grad_(True)
out = m(t); S, I, R = out[:, 0], out[:, 1], out[:, 2]
ones = torch.ones(len(t))
dS = torch.autograd.grad(S, t, grad_outputs=ones, create_graph=True)[0]
dI = torch.autograd.grad(I, t, grad_outputs=ones, create_graph=True)[0]
dR = torch.autograd.grad(R, t, grad_outputs=ones, create_graph=True)[0]
bs = torch.exp(m.log_beta_s); gs = torch.exp(m.log_gamma_s)
L_phys = ((dS + bs*S*I)**2 + (dI - bs*S*I + gs*I)**2 + (dR - gs*I)**2).mean()
sir0 = m(torch.tensor([0.0]))[0]
L_ic = ((sir0 - y0_s) ** 2).mean()
return L_data + 0.1 * L_phys + 5.0 * L_ic
n_ens = 10
beta_ens, gamma_ens = [], []
for seed in range(n_ens):
torch.manual_seed(seed + 200)
mp = PointPINN()
op = torch.optim.Adam(mp.parameters(), lr=1e-3)
sch = torch.optim.lr_scheduler.StepLR(op, step_size=2000, gamma=0.5)
for _ in range(6_000):
op.zero_grad()
_ens_loss(mp, t_obs_s, I_obs_s, t_col, y0_s).backward()
nn.utils.clip_grad_norm_(mp.parameters(), 1.0)
op.step(); sch.step()
beta_ens.append(mp.beta_phys.item())
gamma_ens.append(mp.gamma_phys.item())
beta_ens = np.array(beta_ens)
gamma_ens = np.array(gamma_ens)
traj_ens = np.zeros((n_ens, len(t_viz_np), 3))
for i, (b, g) in enumerate(zip(beta_ens, gamma_ens)):
traj_ens[i] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))
# Resample ensemble members 500× and add observation noise
np.random.seed(42)
idx_e = np.random.choice(n_ens, 500, replace=True)
I_ens_r = traj_ens[idx_e, :, 1] + np.random.normal(0, 20, (500, len(t_viz_np)))
I_ens_r = np.clip(I_ens_r, 0, N)
I_lo_e = np.percentile(I_ens_r, 2.5, axis=0)
I_med_e = np.percentile(I_ens_r, 50, axis=0)
I_hi_e = np.percentile(I_ens_r, 97.5, axis=0)
print(f"Ensemble β: mean={beta_ens.mean():.3f} std={beta_ens.std():.3f} "
f"(true {beta_true})")
print(f"Ensemble γ: mean={gamma_ens.mean():.3f} std={gamma_ens.std():.3f} "
f"(true {gamma_true})")Ensemble β: mean=0.295 std=0.000 (true 0.3)
Ensemble γ: mean=0.098 std=0.000 (true 0.1)
# ── HMC over (log β, log γ) via differentiable Euler ODE ────────────────────
# Uses a differentiable Euler integrator as the likelihood — no PINN surrogate
# needed, no extra dependencies. Gradients flow directly through the ODE steps.
sigma_obs = torch.tensor(20.0 / N) # observation noise in scaled units
def _sir_euler(lb, lg, n_steps=500):
"""
Euler integration of the SIR model in scaled time tau = t / T_max.
Returns I(t_obs_s) — fully differentiable w.r.t. lb and lg.
"""
beta_s = torch.exp(lb) * T_max
gamma_s = torch.exp(lg) * T_max
dt = 1.0 / n_steps
S, I, R = y0_s[0].clone(), y0_s[1].clone(), y0_s[2].clone()
I_grid = [I]
for _ in range(n_steps):
dS = -beta_s * S * I
dI = beta_s * S * I - gamma_s * I
dR = gamma_s * I
S, I, R = S + dt * dS, I + dt * dI, R + dt * dR
I_grid.append(I)
I_grid = torch.stack(I_grid) # (n_steps+1,)
obs_idx = (t_obs_s * n_steps).long().clamp(0, n_steps)
return I_grid[obs_idx]
def _log_post_hmc(lb, lg):
I_hat = _sir_euler(lb, lg)
log_lik = -0.5 * ((I_hat - I_obs_s)**2 / sigma_obs**2).sum()
log_pri = -0.5 * (((lb - prior_mu_lb)**2 + (lg - prior_mu_lg)**2)
/ prior_sig**2)
return log_lik + log_pri
def _grad_U(lb, lg):
lb_ = lb.detach().requires_grad_(True)
lg_ = lg.detach().requires_grad_(True)
(-_log_post_hmc(lb_, lg_)).backward()
return lb_.grad.detach(), lg_.grad.detach()
def _leapfrog_hmc(lb, lg, p_lb, p_lg, step, L):
lb, lg = lb.detach().clone(), lg.detach().clone()
g_lb, g_lg = _grad_U(lb, lg)
p_lb = p_lb - 0.5 * step * g_lb
p_lg = p_lg - 0.5 * step * g_lg
for i in range(L):
lb = lb + step * p_lb
lg = lg + step * p_lg
g_lb, g_lg = _grad_U(lb, lg)
factor = 0.5 if i == L - 1 else 1.0
p_lb = p_lb - factor * step * g_lb
p_lg = p_lg - factor * step * g_lg
return lb, lg, p_lb, p_lg
torch.manual_seed(0); np.random.seed(0)
lb_c = torch.tensor(float(np.log(0.3)))
lg_c = torch.tensor(float(np.log(0.1)))
curr_lp = _log_post_hmc(lb_c, lg_c).item()
n_warm_h, n_samp_h = 200, 500
step_h, L_h = 0.05, 15
hmc_lb, hmc_lg, n_acc = [], [], 0
for i in range(n_warm_h + n_samp_h):
p_lb = torch.randn(()); p_lg = torch.randn(())
H_c = -curr_lp + 0.5 * float(p_lb**2 + p_lg**2)
lb_p, lg_p, p_lb_p, p_lg_p = _leapfrog_hmc(lb_c, lg_c, p_lb, p_lg, step_h, L_h)
prop_lp = _log_post_hmc(lb_p, lg_p).item()
H_p = -prop_lp + 0.5 * float(p_lb_p**2 + p_lg_p**2)
if np.log(np.random.uniform() + 1e-15) < H_c - H_p:
lb_c, lg_c, curr_lp = lb_p, lg_p, prop_lp
n_acc += 1
if i >= n_warm_h:
hmc_lb.append(lb_c.item()); hmc_lg.append(lg_c.item())
beta_hmc = np.exp(hmc_lb); gamma_hmc = np.exp(hmc_lg)
print(f"HMC acceptance rate: {n_acc / (n_warm_h + n_samp_h):.2f}")
print(f"HMC β: mean={beta_hmc.mean():.3f} std={beta_hmc.std():.3f} (true {beta_true})")
print(f"HMC γ: mean={gamma_hmc.mean():.3f} std={gamma_hmc.std():.3f} (true {gamma_true})")
traj_hmc = np.zeros((n_samp_h, len(t_viz_np), 3))
for i, (b, g) in enumerate(zip(beta_hmc, gamma_hmc)):
traj_hmc[i] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))
np.random.seed(42)
I_hmc_noisy = traj_hmc[:, :, 1] + np.random.normal(0, 20, traj_hmc[:, :, 1].shape)
I_hmc_noisy = np.clip(I_hmc_noisy, 0, N)
I_lo_h = np.percentile(I_hmc_noisy, 2.5, axis=0)
I_med_h = np.percentile(I_hmc_noisy, 50, axis=0)
I_hi_h = np.percentile(I_hmc_noisy, 97.5, axis=0)HMC acceptance rate: 0.00
HMC β: mean=0.300 std=0.000 (true 0.3)
HMC γ: mean=0.100 std=0.000 (true 0.1)
# ── Full-rank VI posterior predictive ────────────────────────────────────────
with torch.no_grad():
fr_lb_samp, fr_lg_samp = fr_model.sample_log_params(M)
fr_beta_samp = torch.exp(fr_lb_samp).numpy()
fr_gamma_samp = torch.exp(fr_lg_samp).numpy()
fr_traj = np.zeros((M, len(t_viz_np), 3))
for m, (b, g) in enumerate(zip(fr_beta_samp, fr_gamma_samp)):
fr_traj[m] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))
np.random.seed(42)
fr_I_noisy = fr_traj[:, :, 1] + np.random.normal(0, 20, fr_traj[:, :, 1].shape)
fr_I_noisy = np.clip(fr_I_noisy, 0, N)
I_lo_fr = np.percentile(fr_I_noisy, 2.5, axis=0)
I_med_fr = np.percentile(fr_I_noisy, 50, axis=0)
I_hi_fr = np.percentile(fr_I_noisy, 97.5, axis=0)fig, axes = plt.subplots(1, 4, figsize=(12, 3.5))
for ax, lo, med, hi, title in [
(axes[0], I_lo, I_med, I_hi, "VI mean-field (M=500)"),
(axes[1], I_lo_fr, I_med_fr, I_hi_fr, "VI full-rank (M=500)"),
(axes[2], I_lo_h, I_med_h, I_hi_h, f"HMC (n={n_samp_h})"),
(axes[3], I_lo_e, I_med_e, I_hi_e, f"Ensemble PINN (n={n_ens})"),
]:
ax.fill_between(t_viz_np, lo, hi, alpha=0.25, color=C_PI, label="95% PI (+ obs noise)")
ax.plot(t_viz_np, med, color=C_POST, lw=1.5, label="Median")
ax.plot(t_viz_np, sol_viz[:, 1], color=C_I, lw=1.5, ls="--",
alpha=0.7, label="True I(t)")
ax.scatter(t_obs, I_obs, color=C_OBS, s=12, zorder=5, label="Observed")
ax.set(xlabel="Day", ylabel="Infectious (I)", title=title)
ax.legend()
plt.tight_layout()
plt.show()Comparing the four methods
Mean-field VI assumes \(q(\beta, \gamma) = q(\beta) \cdot q(\gamma)\) — independent marginals. But \(I(t)\) data constrain \(\beta/\gamma \approx R_0\) more tightly than either parameter alone, creating a strong negative correlation in the joint posterior; mean-field inflates marginal variances to compensate, widening the predictive envelope. Full-rank VI adds one off-diagonal Cholesky parameter (\(L_{21}\)) to capture this correlation directly — its predictive interval narrows substantially and the median tracks the true curve more closely. HMC explores the posterior ridge exactly via gradient-driven leapfrog proposals and serves as the gold-standard reference. The Ensemble PINN (4) bypasses the posterior entirely — each PINN converges to a point near the true parameters, so the spread reflects optimisation variability rather than principled Bayesian uncertainty.
7.3 Marginal posteriors
R0_samp = beta_samp / gamma_samp
fig, axes = plt.subplots(1, 3, figsize=(7, 3))
for ax, samples, true_val, label, color in [
(axes[0], beta_samp, beta_true, r"$\beta$ (day⁻¹)", C_I),
(axes[1], gamma_samp, gamma_true, r"$\gamma$ (day⁻¹)", C_R),
(axes[2], R0_samp, beta_true/gamma_true, r"$R_0 = \beta/\gamma$", C_POST),
]:
counts, _, _ = ax.hist(samples, bins=40, color=color,
alpha=0.7, edgecolor="white")
ax.axvline(true_val, color="black", lw=2, ls="--",
label=f"True = {true_val:.2f}")
ax.axvline(np.mean(samples), color=color, lw=2,
label=f"Post. mean = {np.mean(samples):.2f}")
ax.set(xlabel=label, ylabel="Count")
ax.legend()
axes[0].set_title(r"Posterior: $\beta$")
axes[1].set_title(r"Posterior: $\gamma$")
axes[2].set_title(r"Posterior: $R_0$")
plt.suptitle("Bayesian PINN — marginal posteriors", y=1.02)
plt.tight_layout()
plt.show()
print(f"β — true: {beta_true:.3f} | "
f"post. mean: {beta_samp.mean():.3f} "
f"[{np.percentile(beta_samp,2.5):.3f}, {np.percentile(beta_samp,97.5):.3f}]")
print(f"γ — true: {gamma_true:.3f} | "
f"post. mean: {gamma_samp.mean():.3f} "
f"[{np.percentile(gamma_samp,2.5):.3f}, {np.percentile(gamma_samp,97.5):.3f}]")
print(f"R₀ — true: {beta_true/gamma_true:.2f} | "
f"post. mean: {R0_samp.mean():.2f} "
f"[{np.percentile(R0_samp,2.5):.2f}, {np.percentile(R0_samp,97.5):.2f}]")β — true: 0.300 | post. mean: 0.341 [0.126, 0.795]
γ — true: 0.100 | post. mean: 0.117 [0.034, 0.286]
R₀ — true: 3.00 | post. mean: 3.93 [0.73, 12.46]
8 Connection to policy decisions
The posterior on \(R_0 = \beta / \gamma\) connects directly to the pre-emptive vs. reactive OCV stockpile decision. The critical threshold from the single-subunit model is:
\[ p_\text{crit} = \frac{r}{r + R(\nu - r)} \]
where \(p\) is the per-district outbreak probability, \(r\) is reactive effectiveness, \(\nu\) is pre-emptive effectiveness, and \(R\) is the cost ratio. Pre-emptive vaccination is preferred when \(p > p_\text{crit}\).
The Bayesian PINN contributes the distribution of \(R_0\) inferred from surveillance data. Where the posterior on \(R_0\) is wide — as in early-epidemic or low-surveillance settings — the uncertainty in \(p_\text{crit}\) is large, and the value of additional risk intelligence (a wider or better surveillance system) is correspondingly high.
threshold_R0 = 2.5
p_above = float((R0_samp > threshold_R0).mean())
fig, ax = plt.subplots(figsize=(5, 3.5))
counts, _, _ = ax.hist(R0_samp, bins=40, color=C_POST, alpha=0.7, edgecolor="white")
ylim_top = counts.max() * 1.15
ax.axvline(beta_true / gamma_true, color=C_I, lw=2,
label=f"True R₀ = {beta_true/gamma_true:.1f}")
ax.axvline(threshold_R0, color="black", lw=2, ls="--",
label=f"R₀ threshold = {threshold_R0}")
ax.fill_betweenx([0, ylim_top], threshold_R0, R0_samp.max() + 0.3,
alpha=0.15, color="red",
label=f"P(R₀ > {threshold_R0}) = {p_above:.2f}")
ax.set_ylim(0, ylim_top)
ax.set(xlabel=r"$R_0$", ylabel="Count",
title=r"Posterior on $R_0$ — decision-relevant probability mass")
ax.legend()
plt.tight_layout()
plt.show()The policy payoff
A point-estimate PINN reports: “\(R_0 = 3.0\), therefore pre-emptive is preferred.” A Bayesian PINN reports: “\(P(R_0 > 2.5) = 0.78\), so there is a 22% chance the outbreak is mild enough that reactive vaccination would have sufficed.”
These are qualitatively different inputs for a decision-maker allocating a fixed ICG stockpile across competing requests. The distribution, not the point, is what enters the expected-loss calculation.
9 What the variational posterior does not capture
| Limitation | Impact | Remedy |
|---|---|---|
| Mean-field: \(q(\beta,\gamma) = q(\beta)\cdot q(\gamma)\) | Misses \(\beta\)-\(\gamma\) correlation | Full-rank VI with Cholesky covariance — implemented above |
| Point-estimate network weights \(\theta\) | No uncertainty from network architecture | Full B-PINN (1): weight distributions via BNN |
| Log-normal variational family | Poor fit if posterior is multimodal | Normalizing flows (5) |
| \(K = 2\) MC samples per step | Noisy ELBO gradient | Increase \(K\); use importance-weighted ELBO |
The point-estimate network is the most consequential simplification. In the full B-PINN of Yang et al. (1), the network weights also carry distributions, giving a richer notion of epistemic uncertainty (uncertainty from limited data) on top of the aleatoric uncertainty (irreducible noise) captured here.
10 Extensions
- Normalizing flows (5): a sequence of invertible transforms on a Gaussian base → arbitrarily complex posterior geometry, no mean-field assumption.
- MCMC via Pyro/NumPyro: replace VI with exact HMC/NUTS; use the trained PINN as a differentiable, physics-consistent likelihood. VI initialises the sampler near the posterior mode to accelerate mixing.
- Sequential (online) inference: update \(q(\beta, \gamma)\) as new surveillance data arrive during an ongoing outbreak — the posterior from week \(t\) becomes the prior for week \(t+1\), yielding a real-time Bayesian nowcast.
- SEIR/SEIRS extension: the conditional network architecture generalises directly; add incubation rate \(\sigma\) as a third variational parameter and expand the network input to \((t,\,\log\beta,\,\log\gamma,\,\log\sigma)\).
11 References
Python 3.12.10
torch 2.11.0+cpu
scipy 1.17.1
matplotlib 3.10.8