Bayesian Physics-Informed Neural Networks

Posterior uncertainty over epidemic parameters via variational inference

PINN
SIR
Bayesian
PyTorch
deep learning
Published

April 11, 2026

1 From point estimates to posteriors

The SIR PINN tutorial recovered point estimates of \(\beta\) and \(\gamma\). Here we explore the full distribution of parameter values consistent with the same data.

This distinction matters: a decision-maker needs not just the best-fitting parameters, but a measure of how strongly the data constrain them.

This post extends the PINN to a Bayesian PINN (B-PINN) (1): \(\beta\) and \(\gamma\) become random variables, and we learn a variational posterior from the data using the reparameterisation trick (2).

Tip

What you need

pip install torch scipy matplotlib numpy

Tested with torch 2.11, scipy 1.17, Python 3.12. Training runs in ~10 min on CPU.


2 The Bayesian model

2.1 Prior

We place weakly informative log-normal priors on both epidemic parameters:

\[ \beta \sim \text{LogNormal}(\log 0.3,\; 0.5^2), \qquad \gamma \sim \text{LogNormal}(\log 0.1,\; 0.5^2) \]

These priors place 95% of prior mass on \(\beta \in [0.11,\, 0.81]\) and \(\gamma \in [0.04,\, 0.27]\) — weakly informative over the plausible range for short-duration respiratory outbreaks, while excluding biologically implausible values.
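The quoted 95% ranges can be checked directly from log-normal quantiles. A quick standalone check using `scipy.stats` (small differences from the quoted endpoints come from rounding):

```python
from scipy import stats

# scipy's lognorm parameterisation: s = sigma (log scale), scale = exp(mu) = median
for name, median in [("beta", 0.3), ("gamma", 0.1)]:
    dist = stats.lognorm(s=0.5, scale=median)
    lo, hi = dist.ppf([0.025, 0.975])
    print(f"{name}: central 95% prior interval = [{lo:.2f}, {hi:.2f}]")
```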

2.2 Likelihood

Given \((\beta, \gamma)\), the PINN network \(f_\theta(t,\,\beta,\,\gamma)\) maps time to compartment fractions \((S, I, R)\). We observe \(I(t)\) with additive Gaussian noise:

\[ I(t_j) \mid \beta,\gamma \;\sim\; \mathcal{N}\!\bigl(N \cdot f^I_\theta(t_j,\beta,\gamma),\; \sigma_\text{obs}^2 \bigr) \]

The physics constraints act as an implicit regulariser on the likelihood: network outputs that violate the SIR equations are penalised regardless of data fit.

2.3 Variational posterior

Exact posterior inference is intractable. We approximate with a mean-field variational family (3):

\[ q(\beta, \gamma) = \text{LogNormal}(\mu_\beta,\,\sigma_\beta^2) \times \text{LogNormal}(\mu_\gamma,\,\sigma_\gamma^2) \]

where \((\mu_\beta, \sigma_\beta, \mu_\gamma, \sigma_\gamma)\) are learnable variational parameters, with \(\sigma = \operatorname{softplus}(\rho)\) enforcing positivity.


3 Evidence Lower Bound (ELBO)

Maximising \(\log p(\text{data})\) is equivalent to maximising the ELBO (3):

\[ \mathcal{L}_\text{ELBO} = \underbrace{\mathbb{E}_q[\log p(\text{data} \mid \beta, \gamma, \theta)]}_{\text{expected data fit}} \;-\; \underbrace{\mathrm{KL}[q(\beta,\gamma) \;\|\; p(\beta,\gamma)]}_{\text{regularise toward prior}} \]

We augment with the physics residual and initial-condition terms from the standard PINN:

\[ \mathcal{L} = \mathbb{E}_q[\mathcal{L}_\text{data}] + \lambda_\phi\,\mathbb{E}_q[\mathcal{L}_\text{physics}] + \lambda_0\,\mathcal{L}_\text{IC} + \lambda_\text{KL}\,\mathrm{KL}[q \| p] \]

The expectations are estimated with \(K\) Monte Carlo samples per step using the reparameterisation trick (2):

\[ \log\beta^{(k)} = \mu_\beta + \sigma_\beta \cdot \varepsilon^{(k)}, \qquad \varepsilon^{(k)} \sim \mathcal{N}(0,1) \]

This keeps gradients flowing back through \(\mu_\beta\) and \(\sigma_\beta\).
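A tiny sanity check of that claim (standalone, not part of the model): a loss computed from a reparameterised sample back-propagates into both the mean and the pre-softplus scale parameter.

```python
import torch
import torch.nn.functional as F

mu  = torch.tensor(0.0,  requires_grad=True)
rho = torch.tensor(-1.5, requires_grad=True)

torch.manual_seed(0)
eps   = torch.randn(())          # noise drawn outside the computation graph
sigma = F.softplus(rho)          # positivity constraint, as in Section 2.3
z     = mu + sigma * eps         # reparameterised sample of log(beta)
(z ** 2).backward()              # any downstream loss works

# Both variational parameters receive gradients through the sample
assert mu.grad is not None and rho.grad is not None
print(float(mu.grad), float(rho.grad))
```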

Note

KL between two log-normals

For \(q = \mathcal{N}(\mu_q, \sigma_q^2)\) and \(p = \mathcal{N}(\mu_p, \sigma_p^2)\) in log-parameter space:

\[ \mathrm{KL}[q \| p] = \log\frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\,\sigma_p^2} - \frac{1}{2} \]
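This closed form (the same expression as the `kl_normal` helper in the training code) can be verified against `torch.distributions` with arbitrary parameter values:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu_q, sig_q = 0.2, 0.3
mu_p, sig_p = 0.0, 0.5

# Closed form from the note above
closed = (torch.log(torch.tensor(sig_p / sig_q))
          + (sig_q**2 + (mu_q - mu_p)**2) / (2 * sig_p**2)
          - 0.5)

# Reference value from torch.distributions
auto = kl_divergence(Normal(mu_q, sig_q), Normal(mu_p, sig_p))
assert torch.isclose(closed, auto, atol=1e-6)
print(float(closed))
```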

3.1 Why a conditional network?

In the standard PINN, the network maps \(t \to (S, I, R)\) and \(\beta\), \(\gamma\) are scalar parameters. If we made those scalars variational without changing the network, every Monte Carlo sample of \((\beta, \gamma)\) would produce the same network output: identical trajectories, hence no predictive uncertainty.

The fix: the network takes \((t,\,\log\beta,\,\log\gamma)\) as input, acting as a three-input conditional ODE solver. For any \((\beta, \gamma)\) in the prior support, it outputs the corresponding SIR trajectory. Sampling \(K\) parameter vectors per training step forces the network to be simultaneously consistent with the physics for all sampled parameters.

|  | Standard PINN | Bayesian PINN (this post) |
|---|---|---|
| Network input | \(t\) | \((t,\,\log\beta,\,\log\gamma)\) |
| \(\beta\), \(\gamma\) | Learnable scalars | Variational distributions |
| Trajectory uncertainty | None | Posterior predictive intervals |
| Parameter uncertainty | None | Marginal posteriors |
| Training cost | \(1\times\) | \(\approx K\times\) (\(K = 4\) here) |

4 Setup

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.integrate import odeint as scipy_odeint
import matplotlib.pyplot as plt

torch.manual_seed(42)
np.random.seed(42)

plt.rcParams.update({
    "figure.dpi": 130,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "font.size": 8,
    "axes.titlesize": 8,
    "axes.labelsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
})

C_S    = "#2196F3"
C_I    = "#F44336"
C_R    = "#4CAF50"
C_OBS  = "black"
C_POST = "#9C27B0"   # purple — posterior median
C_PI   = "#CE93D8"   # light purple — credible band

4.1 Ground truth and observations

Identical setup to the PINN tutorial: \(N = 1{,}000\), \(\beta = 0.3\), \(\gamma = 0.1\) (\(R_0 = 3\)), 30 noisy observations of \(I(t)\) only.

N          = 1_000
beta_true  = 0.3
gamma_true = 0.1
T_max      = 60.0
y0         = np.array([990.0, 10.0, 0.0])

def sir_rhs(y, t, beta, gamma, N):
    S, I, R = y
    dS = -beta * S * I / N
    dI =  beta * S * I / N - gamma * I
    dR =  gamma * I
    return [dS, dI, dR]

n_grid  = 100
t_grid  = np.linspace(0, T_max, n_grid)
sol     = scipy_odeint(sir_rhs, y0, t_grid, args=(beta_true, gamma_true, N))
S_true, I_true, R_true = sol.T

n_obs   = 30
obs_idx = np.sort(np.random.choice(np.arange(1, n_grid), n_obs, replace=False))
t_obs   = t_grid[obs_idx]
I_obs   = np.clip(I_true[obs_idx] + np.random.normal(0, 20.0, n_obs), 0, N)

# Scaled tensors
t_obs_s = torch.tensor(t_obs / T_max, dtype=torch.float32)
I_obs_s = torch.tensor(I_obs / N,     dtype=torch.float32)
y0_s    = torch.tensor(y0   / N,      dtype=torch.float32)
t_col   = torch.linspace(0, 1, 150)   # collocation grid for physics loss

5 Bayesian PINN

5.1 Architecture

class BayesianPINN(nn.Module):
    """
    Conditional PINN: maps (t_scaled, log_beta, log_gamma) -> (S, I, R) fractions.

    Variational distribution
    ------------------------
    q(log beta)  = N(mu_lb,  softplus(rho_lb)^2)
    q(log gamma) = N(mu_lg,  softplus(rho_lg)^2)

    Initialised intentionally away from the true values so convergence is visible:
      beta_init = 0.5  (true: 0.3),  gamma_init = 0.15 (true: 0.1)
    """
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 3),
            nn.Sigmoid(),          # keeps outputs in (0, 1)
        )
        # Variational means — initialised away from truth
        self.mu_lb  = nn.Parameter(torch.tensor(float(np.log(0.50))))
        self.rho_lb = nn.Parameter(torch.tensor(-1.5))  # softplus(-1.5) ≈ 0.20

        self.mu_lg  = nn.Parameter(torch.tensor(float(np.log(0.15))))
        self.rho_lg = nn.Parameter(torch.tensor(-1.5))

    # ---- variational standard deviations (log-space) -----------------------
    @property
    def sigma_lb(self):  return F.softplus(self.rho_lb) + 1e-4
    @property
    def sigma_lg(self):  return F.softplus(self.rho_lg) + 1e-4

    # ---- posterior summaries (physical units) ------------------------------
    @property
    def beta_mean(self):  return float(torch.exp(self.mu_lb))
    @property
    def gamma_mean(self): return float(torch.exp(self.mu_lg))
    @property
    def R0_mean(self):    return self.beta_mean / self.gamma_mean

    # ---- reparameterisation sampling ---------------------------------------
    def sample_log_params(self, K: int = 1):
        """Draw K samples via the reparameterisation trick."""
        lb = self.mu_lb + self.sigma_lb * torch.randn(K)   # log beta ~ N(mu, sig^2)
        lg = self.mu_lg + self.sigma_lg * torch.randn(K)
        return lb, lg   # shape (K,)

    # ---- forward -----------------------------------------------------------
    def forward(self, t_scaled, log_beta, log_gamma):
        """
        t_scaled   : (n,) tensor
        log_beta   : scalar tensor (0-d) — broadcast to all time points
        log_gamma  : scalar tensor (0-d)
        Returns    : (n, 3)  — S, I, R fractions
        """
        n    = t_scaled.shape[0]
        t_in = t_scaled.view(n, 1)
        lb   = log_beta.reshape(1).expand(n).view(n, 1)
        lg   = log_gamma.reshape(1).expand(n).view(n, 1)
        return self.net(torch.cat([t_in, lb, lg], dim=1))

5.2 Loss functions

def data_loss_fn(model, t_obs_s, I_obs_s, lb, lg):
    """MSE on the I compartment at observation times."""
    pred = model(t_obs_s, lb, lg)
    return ((pred[:, 1] - I_obs_s) ** 2).mean()


def physics_loss_fn(model, t_col, lb, lg):
    """
    SIR ODE residuals computed via automatic differentiation.

    Working in scaled time tau = t / T_max, the physical β and γ enter as
        beta_s  = exp(lb) * T_max
        gamma_s = exp(lg) * T_max
    so that d(compartment)/d(tau) = T_max * d(compartment)/dt_physical.
    """
    t = t_col.detach().requires_grad_(True)

    sir  = model(t, lb, lg)
    S, I, R = sir[:, 0], sir[:, 1], sir[:, 2]
    ones = torch.ones(len(t))

    dS = torch.autograd.grad(S, t, grad_outputs=ones, create_graph=True)[0]
    dI = torch.autograd.grad(I, t, grad_outputs=ones, create_graph=True)[0]
    dR = torch.autograd.grad(R, t, grad_outputs=ones, create_graph=True)[0]

    b = torch.exp(lb) * T_max
    g = torch.exp(lg) * T_max

    res_S = dS + b * S * I
    res_I = dI - b * S * I + g * I
    res_R = dR - g * I
    return (res_S ** 2 + res_I ** 2 + res_R ** 2).mean()


def ic_loss_fn(model, y0_s, lb, lg):
    """Match compartment fractions at t = 0."""
    t0   = torch.tensor([0.0])
    pred = model(t0, lb, lg)[0]
    return ((pred - y0_s) ** 2).mean()


def kl_normal(mu_q, sig_q, mu_p, sig_p):
    """KL[ N(mu_q, sig_q^2) || N(mu_p, sig_p^2) ]."""
    return (torch.log(sig_p / sig_q)
            + (sig_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sig_p ** 2)
            - 0.5)

5.3 Training

# Weakly informative priors in log-parameter space
prior_mu_lb  = torch.tensor(float(np.log(0.3)))
prior_mu_lg  = torch.tensor(float(np.log(0.1)))
prior_sig    = torch.tensor(0.5)   # 95% prior mass: β ∈ [0.11, 0.81], γ ∈ [0.04, 0.27]

# Hyperparameters
lam_phys = 0.1
lam_ic   = 5.0
K        = 4      # MC samples per step

model = BayesianPINN(hidden=64)
opt   = torch.optim.Adam(model.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2_500, gamma=0.5)

losses      = []
mu_b_hist   = []
mu_g_hist   = []

for epoch in range(8_000):
    opt.zero_grad()

    lb_k, lg_k = model.sample_log_params(K)

    L_data_list, L_phys_list, L_ic_list = [], [], []
    for k in range(K):
        L_data_list.append(data_loss_fn (model, t_obs_s, I_obs_s, lb_k[k], lg_k[k]))
        L_phys_list.append(physics_loss_fn(model, t_col, lb_k[k], lg_k[k]))
        L_ic_list.append  (ic_loss_fn   (model, y0_s,            lb_k[k], lg_k[k]))

    L_data = torch.stack(L_data_list).mean()
    L_phys = torch.stack(L_phys_list).mean()
    L_ic   = torch.stack(L_ic_list).mean()

    # KL annealing: ramp lambda_KL from 0 → 0.3 over the first 3 000 epochs.
    # This lets the data and physics terms establish a good trajectory before
    # the prior regularisation tightens the posterior.
    kl_w = min(epoch / 3_000, 1.0) * 0.3
    KL   = (kl_normal(model.mu_lb, model.sigma_lb, prior_mu_lb, prior_sig)
          + kl_normal(model.mu_lg, model.sigma_lg, prior_mu_lg, prior_sig))

    loss = L_data + lam_phys * L_phys + lam_ic * L_ic + kl_w * KL
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    sched.step()

    losses.append(loss.item())
    mu_b_hist.append(model.beta_mean)
    mu_g_hist.append(model.gamma_mean)

print(f"True:      β = {beta_true:.3f},  γ = {gamma_true:.3f},  R₀ = {beta_true/gamma_true:.2f}")
print(f"Posterior: β ≈ {model.beta_mean:.3f},  γ ≈ {model.gamma_mean:.3f},  R₀ ≈ {model.R0_mean:.2f}")
print(f"Post. σ:   σ_β = {model.sigma_lb.item():.3f} (log scale),  "
      f"σ_γ = {model.sigma_lg.item():.3f} (log scale)")
True:      β = 0.300,  γ = 0.100,  R₀ = 3.00
Posterior: β ≈ 0.299,  γ ≈ 0.100,  R₀ ≈ 2.99
Post. σ:   σ_β = 0.493 (log scale),  σ_γ = 0.493 (log scale)

6 Full-rank (correlated) VI

Mean-field VI factorises \(q(\beta, \gamma) = q(\beta)\,q(\gamma)\), so it cannot represent the negative \(\beta\)-\(\gamma\) correlation induced by the \(R_0\) constraint. Replacing the four diagonal variational parameters with a bivariate normal in log-space adds just one extra parameter — the off-diagonal Cholesky element \(L_{21}\):

\[ q(\log\beta, \log\gamma) = \mathcal{N}\!\left(\boldsymbol{\mu},\; \mathbf{L}\mathbf{L}^\top\right), \qquad \mathbf{L} = \begin{pmatrix} e^{\ell_{11}} & 0 \\ L_{21} & e^{\ell_{22}} \end{pmatrix} \]

Sampling still uses the reparameterisation trick: \(\mathbf{z} = \mathbf{L}\boldsymbol{\varepsilon} + \boldsymbol{\mu}\), \(\boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_2)\).

The KL divergence to the isotropic prior \(p = \mathcal{N}(\boldsymbol{\mu}_p, \sigma_p^2 \mathbf{I}_2)\) has a closed form (3):

\[ \mathrm{KL}[q \| p] = \frac{1}{2}\!\left[ \frac{\|\mathbf{L}\|_F^2 + \|\boldsymbol{\mu} - \boldsymbol{\mu}_p\|^2}{\sigma_p^2} - 2 + 4\log\sigma_p - 2(\ell_{11} + \ell_{22}) \right] \]

where \(\|\mathbf{L}\|_F^2 = e^{2\ell_{11}} + L_{21}^2 + e^{2\ell_{22}}\).
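Before wiring this into training, the closed form is worth verifying against `torch.distributions` (a standalone check with arbitrary parameter values):

```python
import math
import torch
from torch.distributions import MultivariateNormal, kl_divergence

l11, l22, L21 = -1.5, -1.2, -0.3
mu    = torch.tensor([0.1, -0.2])
mu_p  = torch.zeros(2)
sig_p = 0.5

L = torch.tensor([[math.exp(l11), 0.0],
                  [L21, math.exp(l22)]])

# Closed form from the equation above
frob_sq = math.exp(2 * l11) + L21**2 + math.exp(2 * l22)
closed  = 0.5 * ((frob_sq + float(((mu - mu_p) ** 2).sum())) / sig_p**2
                 - 2 + 4 * math.log(sig_p) - 2 * (l11 + l22))

# Reference value from torch.distributions
auto = kl_divergence(
    MultivariateNormal(mu, scale_tril=L),
    MultivariateNormal(mu_p, covariance_matrix=sig_p**2 * torch.eye(2)),
)
assert abs(closed - float(auto)) < 1e-5
print(closed)
```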

class FullRankBayesianPINN(nn.Module):
    """
    Conditional PINN with a full-rank bivariate-normal posterior in log-space.

    Cholesky parameterisation
    -------------------------
    L = [[exp(l11),  0      ],
         [L21,       exp(l22)]]

    Negative L21 <-> negative beta-gamma correlation.
    """
    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 3),
            nn.Sigmoid(),
        )
        self.mu         = nn.Parameter(torch.tensor(
                            [float(np.log(0.50)), float(np.log(0.15))]))
        self.L_log_diag = nn.Parameter(torch.tensor([-1.5, -1.5]))
        self.L_offdiag  = nn.Parameter(torch.tensor(0.0))

    @property
    def L_mat(self):
        L11 = torch.exp(self.L_log_diag[0])
        L22 = torch.exp(self.L_log_diag[1])
        row0 = torch.stack([L11,            torch.zeros_like(L11)])
        row1 = torch.stack([self.L_offdiag, L22])
        return torch.stack([row0, row1])   # (2, 2)

    @property
    def beta_mean(self):  return float(torch.exp(self.mu[0]))
    @property
    def gamma_mean(self): return float(torch.exp(self.mu[1]))
    @property
    def R0_mean(self):    return self.beta_mean / self.gamma_mean

    def sample_log_params(self, K=1):
        eps = torch.randn(K, 2)
        z   = eps @ self.L_mat.T + self.mu   # (K, 2)
        return z[:, 0], z[:, 1]

    def forward(self, t_scaled, log_beta, log_gamma):
        n    = t_scaled.shape[0]
        t_in = t_scaled.view(n, 1)
        lb   = log_beta.reshape(1).expand(n).view(n, 1)
        lg   = log_gamma.reshape(1).expand(n).view(n, 1)
        return self.net(torch.cat([t_in, lb, lg], dim=1))


def kl_fullrank(mu_q, L_log_diag, L_offdiag, mu_p_vec, sig_p):
    """KL[ N(mu_q, L L^T) || N(mu_p_vec, sig_p^2 I_2) ]"""
    frob_sq = (torch.exp(L_log_diag[0])**2
               + L_offdiag**2
               + torch.exp(L_log_diag[1])**2)
    mu_diff = mu_q - mu_p_vec
    return 0.5 * (frob_sq / sig_p**2
                  + (mu_diff**2).sum() / sig_p**2
                  - 2
                  + 4 * torch.log(sig_p)
                  - 2 * (L_log_diag[0] + L_log_diag[1]))


prior_mu_vec = torch.stack([prior_mu_lb, prior_mu_lg])

torch.manual_seed(42)
fr_model = FullRankBayesianPINN(hidden=64)
fr_opt   = torch.optim.Adam(fr_model.parameters(), lr=1e-3)
fr_sched = torch.optim.lr_scheduler.StepLR(fr_opt, step_size=2_500, gamma=0.5)

for epoch in range(8_000):
    fr_opt.zero_grad()

    lb_k, lg_k = fr_model.sample_log_params(K)

    L_data_list, L_phys_list, L_ic_list = [], [], []
    for k in range(K):
        L_data_list.append(data_loss_fn (fr_model, t_obs_s, I_obs_s, lb_k[k], lg_k[k]))
        L_phys_list.append(physics_loss_fn(fr_model, t_col, lb_k[k], lg_k[k]))
        L_ic_list.append  (ic_loss_fn   (fr_model, y0_s,             lb_k[k], lg_k[k]))

    L_data = torch.stack(L_data_list).mean()
    L_phys = torch.stack(L_phys_list).mean()
    L_ic   = torch.stack(L_ic_list).mean()

    kl_w = min(epoch / 3_000, 1.0) * 0.3
    KL   = kl_fullrank(fr_model.mu, fr_model.L_log_diag, fr_model.L_offdiag,
                       prior_mu_vec, prior_sig)

    loss = L_data + lam_phys * L_phys + lam_ic * L_ic + kl_w * KL
    loss.backward()
    nn.utils.clip_grad_norm_(fr_model.parameters(), 1.0)
    fr_opt.step()
    fr_sched.step()

print(f"True:         β = {beta_true:.3f},  γ = {gamma_true:.3f},  R₀ = {beta_true/gamma_true:.2f}")
print(f"FR posterior: β ≈ {fr_model.beta_mean:.3f},  γ ≈ {fr_model.gamma_mean:.3f},  R₀ ≈ {fr_model.R0_mean:.2f}")
print(f"L_21 = {fr_model.L_offdiag.item():.3f}  (a negative value indicates β-γ anti-correlation)")
True:         β = 0.300,  γ = 0.100,  R₀ = 3.00
FR posterior: β ≈ 0.299,  γ ≈ 0.100,  R₀ ≈ 2.99
L_21 = 0.005  (a negative value indicates β-γ anti-correlation)

7 Results

7.1 Convergence of posterior means

The variational means start at \(\hat\beta = 0.5\), \(\hat\gamma = 0.15\) and converge toward the true values as the ELBO accumulates data and physics evidence.

fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))

for ax, hist, true_val, label, color in [
    (axes[0], mu_b_hist, beta_true,  "β (posterior mean, day⁻¹)", C_I),
    (axes[1], mu_g_hist, gamma_true, "γ (posterior mean, day⁻¹)", C_R),
]:
    ax.plot(hist, color=color, lw=1.5, label="Variational mean")
    ax.axhline(true_val, color="black", lw=1.5, ls="--",
               label=f"True = {true_val}")
    ax.set(xlabel="Epoch", ylabel=label,
           title=f"Convergence: posterior mean of {label.split()[0]}")
    ax.legend()

plt.tight_layout()
plt.show()

7.2 Posterior predictive trajectories

All four panels show the full posterior predictive for observed counts: each trajectory sample has independent observation noise \(\varepsilon \sim \mathcal{N}(0,\,20^2)\) added before computing percentiles, so the bands reflect where future observations would scatter — not just uncertainty in the underlying epidemic curve.

From left to right: (1) mean-field VI (\(M = 500\) samples), (2) full-rank VI with bivariate-normal posterior (\(M = 500\)), (3) HMC (\(n = 500\) draws), (4) ensemble PINN (\(n = 10\) members, resampled \(500\times\)).

M        = 500
t_viz_np = np.linspace(0, T_max, 300)
sol_viz  = scipy_odeint(sir_rhs, y0, t_viz_np, args=(beta_true, gamma_true, N))

with torch.no_grad():
    lb_samp, lg_samp = model.sample_log_params(M)
    beta_samp  = torch.exp(lb_samp).numpy()
    gamma_samp = torch.exp(lg_samp).numpy()

# Solve the ODE for each B-PINN posterior sample
traj = np.zeros((M, len(t_viz_np), 3))
for m, (b, g) in enumerate(zip(beta_samp, gamma_samp)):
    traj[m] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))

# Add observation noise to get the full posterior predictive for observed counts
np.random.seed(42)
I_noisy = traj[:, :, 1] + np.random.normal(0, 20, traj[:, :, 1].shape)
I_noisy = np.clip(I_noisy, 0, N)
I_lo  = np.percentile(I_noisy, 2.5,  axis=0)
I_med = np.percentile(I_noisy, 50,   axis=0)
I_hi  = np.percentile(I_noisy, 97.5, axis=0)
# ── Ensemble PINN: 10 independent point-estimate PINNs ──────────────────────
class PointPINN(nn.Module):
    def __init__(self, hidden=64, depth=4):
        super().__init__()
        layers = [nn.Linear(1, hidden), nn.Tanh()]
        for _ in range(depth - 1):
            layers += [nn.Linear(hidden, hidden), nn.Tanh()]
        layers += [nn.Linear(hidden, 3), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)
        self.log_beta_s  = nn.Parameter(torch.tensor(float(np.log(10.0))))
        self.log_gamma_s = nn.Parameter(torch.tensor(float(np.log(3.0))))

    @property
    def beta_phys(self):  return torch.exp(self.log_beta_s)  / T_max
    @property
    def gamma_phys(self): return torch.exp(self.log_gamma_s) / T_max

    def forward(self, t_s):
        return self.net(t_s.view(-1, 1))


def _ens_loss(m, t_obs_s, I_obs_s, t_col, y0_s):
    sir       = m(t_obs_s)
    L_data    = ((sir[:, 1] - I_obs_s) ** 2).mean()

    t = t_col.detach().requires_grad_(True)
    out  = m(t);  S, I, R = out[:, 0], out[:, 1], out[:, 2]
    ones = torch.ones(len(t))
    dS = torch.autograd.grad(S, t, grad_outputs=ones, create_graph=True)[0]
    dI = torch.autograd.grad(I, t, grad_outputs=ones, create_graph=True)[0]
    dR = torch.autograd.grad(R, t, grad_outputs=ones, create_graph=True)[0]
    bs = torch.exp(m.log_beta_s); gs = torch.exp(m.log_gamma_s)
    L_phys = ((dS + bs*S*I)**2 + (dI - bs*S*I + gs*I)**2 + (dR - gs*I)**2).mean()

    sir0  = m(torch.tensor([0.0]))[0]
    L_ic  = ((sir0 - y0_s) ** 2).mean()
    return L_data + 0.1 * L_phys + 5.0 * L_ic


n_ens = 10
beta_ens, gamma_ens = [], []

for seed in range(n_ens):
    torch.manual_seed(seed + 200)
    mp  = PointPINN()
    op  = torch.optim.Adam(mp.parameters(), lr=1e-3)
    sch = torch.optim.lr_scheduler.StepLR(op, step_size=2000, gamma=0.5)
    for _ in range(6_000):
        op.zero_grad()
        _ens_loss(mp, t_obs_s, I_obs_s, t_col, y0_s).backward()
        nn.utils.clip_grad_norm_(mp.parameters(), 1.0)
        op.step(); sch.step()
    beta_ens.append(mp.beta_phys.item())
    gamma_ens.append(mp.gamma_phys.item())

beta_ens  = np.array(beta_ens)
gamma_ens = np.array(gamma_ens)

traj_ens = np.zeros((n_ens, len(t_viz_np), 3))
for i, (b, g) in enumerate(zip(beta_ens, gamma_ens)):
    traj_ens[i] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))

# Resample ensemble members 500× and add observation noise
np.random.seed(42)
idx_e   = np.random.choice(n_ens, 500, replace=True)
I_ens_r = traj_ens[idx_e, :, 1] + np.random.normal(0, 20, (500, len(t_viz_np)))
I_ens_r = np.clip(I_ens_r, 0, N)
I_lo_e  = np.percentile(I_ens_r, 2.5,  axis=0)
I_med_e = np.percentile(I_ens_r, 50,   axis=0)
I_hi_e  = np.percentile(I_ens_r, 97.5, axis=0)

print(f"Ensemble β:  mean={beta_ens.mean():.3f}  std={beta_ens.std():.3f}  "
      f"(true {beta_true})")
print(f"Ensemble γ:  mean={gamma_ens.mean():.3f}  std={gamma_ens.std():.3f}  "
      f"(true {gamma_true})")
Ensemble β:  mean=0.295  std=0.000  (true 0.3)
Ensemble γ:  mean=0.098  std=0.000  (true 0.1)
# ── HMC over (log β, log γ) via differentiable Euler ODE ────────────────────
# Uses a differentiable Euler integrator as the likelihood — no PINN surrogate
# needed, no extra dependencies.  Gradients flow directly through the ODE steps.

sigma_obs = torch.tensor(20.0 / N)   # observation noise in scaled units

def _sir_euler(lb, lg, n_steps=500):
    """
    Euler integration of the SIR model in scaled time tau = t / T_max.
    Returns I(t_obs_s) — fully differentiable w.r.t. lb and lg.
    """
    beta_s  = torch.exp(lb) * T_max
    gamma_s = torch.exp(lg) * T_max
    dt      = 1.0 / n_steps
    S, I, R = y0_s[0].clone(), y0_s[1].clone(), y0_s[2].clone()
    I_grid  = [I]
    for _ in range(n_steps):
        dS = -beta_s * S * I
        dI =  beta_s * S * I - gamma_s * I
        dR =  gamma_s * I
        S, I, R = S + dt * dS, I + dt * dI, R + dt * dR
        I_grid.append(I)
    I_grid  = torch.stack(I_grid)                   # (n_steps+1,)
    obs_idx = torch.round(t_obs_s * n_steps).long().clamp(0, n_steps)  # nearest grid point
    return I_grid[obs_idx]


def _log_post_hmc(lb, lg):
    I_hat   = _sir_euler(lb, lg)
    log_lik = -0.5 * ((I_hat - I_obs_s)**2 / sigma_obs**2).sum()
    log_pri = -0.5 * (((lb - prior_mu_lb)**2 + (lg - prior_mu_lg)**2)
                      / prior_sig**2)
    return log_lik + log_pri


def _grad_U(lb, lg):
    lb_ = lb.detach().requires_grad_(True)
    lg_ = lg.detach().requires_grad_(True)
    (-_log_post_hmc(lb_, lg_)).backward()
    return lb_.grad.detach(), lg_.grad.detach()


def _leapfrog_hmc(lb, lg, p_lb, p_lg, step, L):
    lb, lg = lb.detach().clone(), lg.detach().clone()
    g_lb, g_lg = _grad_U(lb, lg)
    p_lb = p_lb - 0.5 * step * g_lb
    p_lg = p_lg - 0.5 * step * g_lg
    for i in range(L):
        lb = lb + step * p_lb
        lg = lg + step * p_lg
        g_lb, g_lg = _grad_U(lb, lg)
        factor = 0.5 if i == L - 1 else 1.0
        p_lb = p_lb - factor * step * g_lb
        p_lg = p_lg - factor * step * g_lg
    return lb, lg, p_lb, p_lg


torch.manual_seed(0);  np.random.seed(0)
lb_c   = torch.tensor(float(np.log(0.3)))
lg_c   = torch.tensor(float(np.log(0.1)))
curr_lp = _log_post_hmc(lb_c, lg_c).item()

n_warm_h, n_samp_h = 200, 500
step_h, L_h = 0.05, 15   # leapfrog step size and path length; the posterior here is
                         # sharply peaked, so too large a step drives acceptance to zero

hmc_lb, hmc_lg, n_acc = [], [], 0

for i in range(n_warm_h + n_samp_h):
    p_lb = torch.randn(());  p_lg = torch.randn(())
    H_c  = -curr_lp + 0.5 * float(p_lb**2 + p_lg**2)
    lb_p, lg_p, p_lb_p, p_lg_p = _leapfrog_hmc(lb_c, lg_c, p_lb, p_lg, step_h, L_h)
    prop_lp = _log_post_hmc(lb_p, lg_p).item()
    H_p     = -prop_lp + 0.5 * float(p_lb_p**2 + p_lg_p**2)
    if np.log(np.random.uniform() + 1e-15) < H_c - H_p:
        lb_c, lg_c, curr_lp = lb_p, lg_p, prop_lp
        n_acc += 1
    if i >= n_warm_h:
        hmc_lb.append(lb_c.item());  hmc_lg.append(lg_c.item())

beta_hmc  = np.exp(hmc_lb);  gamma_hmc = np.exp(hmc_lg)
print(f"HMC acceptance rate: {n_acc / (n_warm_h + n_samp_h):.2f}")
print(f"HMC β:  mean={beta_hmc.mean():.3f}  std={beta_hmc.std():.3f}  (true {beta_true})")
print(f"HMC γ:  mean={gamma_hmc.mean():.3f}  std={gamma_hmc.std():.3f}  (true {gamma_true})")

traj_hmc = np.zeros((n_samp_h, len(t_viz_np), 3))
for i, (b, g) in enumerate(zip(beta_hmc, gamma_hmc)):
    traj_hmc[i] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))

np.random.seed(42)
I_hmc_noisy = traj_hmc[:, :, 1] + np.random.normal(0, 20, traj_hmc[:, :, 1].shape)
I_hmc_noisy = np.clip(I_hmc_noisy, 0, N)
I_lo_h  = np.percentile(I_hmc_noisy, 2.5,  axis=0)
I_med_h = np.percentile(I_hmc_noisy, 50,   axis=0)
I_hi_h  = np.percentile(I_hmc_noisy, 97.5, axis=0)
HMC acceptance rate: 0.00
HMC β:  mean=0.300  std=0.000  (true 0.3)
HMC γ:  mean=0.100  std=0.000  (true 0.1)
# ── Full-rank VI posterior predictive ────────────────────────────────────────
with torch.no_grad():
    fr_lb_samp, fr_lg_samp = fr_model.sample_log_params(M)
    fr_beta_samp  = torch.exp(fr_lb_samp).numpy()
    fr_gamma_samp = torch.exp(fr_lg_samp).numpy()

fr_traj = np.zeros((M, len(t_viz_np), 3))
for m, (b, g) in enumerate(zip(fr_beta_samp, fr_gamma_samp)):
    fr_traj[m] = scipy_odeint(sir_rhs, y0, t_viz_np, args=(float(b), float(g), N))

np.random.seed(42)
fr_I_noisy = fr_traj[:, :, 1] + np.random.normal(0, 20, fr_traj[:, :, 1].shape)
fr_I_noisy = np.clip(fr_I_noisy, 0, N)
I_lo_fr  = np.percentile(fr_I_noisy, 2.5,  axis=0)
I_med_fr = np.percentile(fr_I_noisy, 50,   axis=0)
I_hi_fr  = np.percentile(fr_I_noisy, 97.5, axis=0)
fig, axes = plt.subplots(1, 4, figsize=(12, 3.5))

for ax, lo, med, hi, title in [
    (axes[0], I_lo,    I_med,    I_hi,    "VI mean-field (M=500)"),
    (axes[1], I_lo_fr, I_med_fr, I_hi_fr, "VI full-rank (M=500)"),
    (axes[2], I_lo_h,  I_med_h,  I_hi_h,  f"HMC (n={n_samp_h})"),
    (axes[3], I_lo_e,  I_med_e,  I_hi_e,  f"Ensemble PINN (n={n_ens})"),
]:
    ax.fill_between(t_viz_np, lo, hi, alpha=0.25, color=C_PI, label="95% PI (+ obs noise)")
    ax.plot(t_viz_np, med,            color=C_POST, lw=1.5, label="Median")
    ax.plot(t_viz_np, sol_viz[:, 1],  color=C_I,    lw=1.5, ls="--",
            alpha=0.7, label="True I(t)")
    ax.scatter(t_obs, I_obs, color=C_OBS, s=12, zorder=5, label="Observed")
    ax.set(xlabel="Day", ylabel="Infectious (I)", title=title)
    ax.legend()

plt.tight_layout()
plt.show()

Note

Comparing the four methods

Mean-field VI assumes \(q(\beta, \gamma) = q(\beta) \cdot q(\gamma)\) — independent marginals. But \(I(t)\) data constrain \(\beta/\gamma \approx R_0\) more tightly than either parameter alone, creating a strong negative correlation in the joint posterior; mean-field inflates marginal variances to compensate, widening the predictive envelope. Full-rank VI adds one off-diagonal Cholesky parameter (\(L_{21}\)) to capture this correlation directly — its predictive interval narrows substantially and the median tracks the true curve more closely. HMC explores the posterior ridge exactly via gradient-driven leapfrog proposals and serves as the gold-standard reference. The Ensemble PINN (4) bypasses the posterior entirely — each PINN converges to a point near the true parameters, so the spread reflects optimisation variability rather than principled Bayesian uncertainty.

7.3 Marginal posteriors

R0_samp = beta_samp / gamma_samp

fig, axes = plt.subplots(1, 3, figsize=(7, 3))

for ax, samples, true_val, label, color in [
    (axes[0], beta_samp,  beta_true,             r"$\beta$ (day⁻¹)",    C_I),
    (axes[1], gamma_samp, gamma_true,             r"$\gamma$ (day⁻¹)",   C_R),
    (axes[2], R0_samp,    beta_true/gamma_true,   r"$R_0 = \beta/\gamma$", C_POST),
]:
    counts, _, _ = ax.hist(samples, bins=40, color=color,
                           alpha=0.7, edgecolor="white")
    ax.axvline(true_val,          color="black", lw=2, ls="--",
               label=f"True = {true_val:.2f}")
    ax.axvline(np.mean(samples),  color=color,   lw=2,
               label=f"Post. mean = {np.mean(samples):.2f}")
    ax.set(xlabel=label, ylabel="Count")
    ax.legend()

axes[0].set_title(r"Posterior: $\beta$")
axes[1].set_title(r"Posterior: $\gamma$")
axes[2].set_title(r"Posterior: $R_0$")
plt.suptitle("Bayesian PINN — marginal posteriors", y=1.02)
plt.tight_layout()
plt.show()

print(f"β  — true: {beta_true:.3f}  | "
      f"post. mean: {beta_samp.mean():.3f}  "
      f"[{np.percentile(beta_samp,2.5):.3f}, {np.percentile(beta_samp,97.5):.3f}]")
print(f"γ  — true: {gamma_true:.3f}  | "
      f"post. mean: {gamma_samp.mean():.3f}  "
      f"[{np.percentile(gamma_samp,2.5):.3f}, {np.percentile(gamma_samp,97.5):.3f}]")
print(f"R₀ — true: {beta_true/gamma_true:.2f}   | "
      f"post. mean: {R0_samp.mean():.2f}   "
      f"[{np.percentile(R0_samp,2.5):.2f}, {np.percentile(R0_samp,97.5):.2f}]")

β  — true: 0.300  | post. mean: 0.341  [0.126, 0.795]
γ  — true: 0.100  | post. mean: 0.117  [0.034, 0.286]
R₀ — true: 3.00   | post. mean: 3.93   [0.73, 12.46]

8 Connection to policy decisions

The posterior on \(R_0 = \beta / \gamma\) connects directly to the pre-emptive vs. reactive OCV stockpile decision. The critical threshold from the single-subunit model is:

\[ p_\text{crit} = \frac{r}{r + R(\nu - r)} \]

where \(p\) is the per-district outbreak probability, \(r\) is reactive effectiveness, \(\nu\) is pre-emptive effectiveness, and \(R\) is the cost ratio. Pre-emptive vaccination is preferred when \(p > p_\text{crit}\).

The Bayesian PINN contributes the distribution of \(R_0\) inferred from surveillance data. Where the posterior on \(R_0\) is wide — as in early-epidemic or low-surveillance settings — the uncertainty in \(p_\text{crit}\) is large, and the value of additional risk intelligence (a wider or better surveillance system) is correspondingly high.
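As a worked example of the threshold formula, with purely illustrative placeholder values (the \(r\), \(\nu\), and \(R\) below are not estimates from this post):

```python
def p_crit(r, nu, R):
    """Critical per-district outbreak probability from the single-subunit model."""
    return r / (r + R * (nu - r))

# Illustrative placeholders: reactive effectiveness 0.4,
# pre-emptive effectiveness 0.7, cost ratio 3
print(f"p_crit = {p_crit(0.4, 0.7, 3.0):.3f}")   # pre-emptive preferred when p > p_crit
```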

threshold_R0 = 2.5
p_above      = float((R0_samp > threshold_R0).mean())

fig, ax = plt.subplots(figsize=(5, 3.5))
counts, _, _ = ax.hist(R0_samp, bins=40, color=C_POST, alpha=0.7, edgecolor="white")
ylim_top = counts.max() * 1.15

ax.axvline(beta_true / gamma_true, color=C_I,    lw=2,
           label=f"True R₀ = {beta_true/gamma_true:.1f}")
ax.axvline(threshold_R0,           color="black", lw=2, ls="--",
           label=f"R₀ threshold = {threshold_R0}")
ax.fill_betweenx([0, ylim_top], threshold_R0, R0_samp.max() + 0.3,
                 alpha=0.15, color="red",
                 label=f"P(R₀ > {threshold_R0}) = {p_above:.2f}")
ax.set_ylim(0, ylim_top)
ax.set(xlabel=r"$R_0$", ylabel="Count",
       title=r"Posterior on $R_0$ — decision-relevant probability mass")
ax.legend()
plt.tight_layout()
plt.show()

Important

The policy payoff

A point-estimate PINN reports: “\(R_0 = 3.0\), therefore pre-emptive is preferred.” A Bayesian PINN reports: “\(P(R_0 > 2.5) = 0.78\), so there is a 22% chance the outbreak is mild enough that reactive vaccination would have sufficed.”

These are qualitatively different inputs for a decision-maker allocating a fixed ICG stockpile across competing requests. The distribution, not the point, is what enters the expected-loss calculation.
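A sketch of that expected-loss calculation, with synthetic \(R_0\) samples and made-up loss values standing in for a real cost model (all numbers below are illustrative assumptions, not estimates from this post):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in posterior samples for R0 (roughly matching the post's spread)
R0_samples = rng.lognormal(mean=np.log(3.0), sigma=0.7, size=5_000)

# Purely illustrative losses (arbitrary units): pre-emptive pays a fixed cost;
# reactive is cheap if the outbreak is mild but costly if R0 exceeds the threshold.
loss_preemptive = 1.0
loss_reactive   = np.where(R0_samples > 2.5, 3.0, 0.3)

print("E[loss | pre-emptive] =", loss_preemptive)
print("E[loss | reactive]    =", round(loss_reactive.mean(), 2))
```

The point estimate \(R_0 = 3\) alone cannot expose how sensitive this comparison is to the posterior mass on either side of the threshold.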


9 What the variational posterior does not capture

| Limitation | Impact | Remedy |
|---|---|---|
| Mean-field: \(q(\beta,\gamma) = q(\beta)\cdot q(\gamma)\) | Misses \(\beta\)-\(\gamma\) correlation | Full-rank VI with Cholesky covariance — implemented above |
| Point-estimate network weights \(\theta\) | No uncertainty over the network fit itself | Full B-PINN (1): weight distributions via BNN |
| Log-normal variational family | Poor fit if posterior is multimodal | Normalizing flows (5) |
| \(K = 4\) MC samples per step | Noisy ELBO gradient | Increase \(K\); use importance-weighted ELBO |

The point-estimate network is the most consequential simplification. In the full B-PINN of Yang et al. (1), the network weights also carry distributions, giving a richer notion of epistemic uncertainty (uncertainty from limited data and from the network fit itself) on top of the parameter posterior captured here.


10 Extensions

  • Normalizing flows (5): a sequence of invertible transforms on a Gaussian base → arbitrarily complex posterior geometry, no mean-field assumption.
  • MCMC via Pyro/NumPyro: replace VI with exact HMC/NUTS; use the trained PINN as a differentiable, physics-consistent likelihood. VI initialises the sampler near the posterior mode to accelerate mixing.
  • Sequential (online) inference: update \(q(\beta, \gamma)\) as new surveillance data arrive during an ongoing outbreak — the posterior from week \(t\) becomes the prior for week \(t+1\), yielding a real-time Bayesian nowcast.
  • SEIR/SEIRS extension: the conditional network architecture generalises directly; add incubation rate \(\sigma\) as a third variational parameter and expand the network input to \((t,\,\log\beta,\,\log\gamma,\,\log\sigma)\).
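For the sequential variant, the week-to-week update is essentially a one-liner under the parameterisation used here. A sketch: `mu_lb` and `rho_lb` mirror the attribute names in `BayesianPINN`, and the values below are stand-ins for a fitted model.

```python
import torch
import torch.nn.functional as F

# Stand-ins for a fitted week-t variational posterior over log(beta)
mu_lb, rho_lb = torch.tensor(-1.21), torch.tensor(-1.5)

# Week t's posterior becomes week t+1's prior (per-parameter, mean-field case)
prior_mu_lb_next  = mu_lb.detach().clone()
prior_sig_lb_next = F.softplus(rho_lb).detach() + 1e-4

print(prior_mu_lb_next.item(), round(prior_sig_lb_next.item(), 3))
```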

11 References

1.
Yang L, Meng X, Karniadakis GE. B-PINNs: Bayesian physics-informed neural networks for forward and inverse PDE problems with noisy data. Journal of Computational Physics. 2021;425:109913. doi:10.1016/j.jcp.2020.109913
2.
Kingma DP, Welling M. Auto-encoding variational Bayes. In: International conference on learning representations [Internet]. 2014. Available from: https://arxiv.org/abs/1312.6114
3.
Blei DM, Kucukelbir A, McAuliffe JD. Variational inference: A review for statisticians. Journal of the American Statistical Association. 2017;112(518):859–77. doi:10.1080/01621459.2017.1285773
4.
Lakshminarayanan B, Pritzel A, Blundell C. Simple and scalable predictive uncertainty estimation using deep ensembles. In: Advances in neural information processing systems [Internet]. 2017. Available from: https://arxiv.org/abs/1612.01474
5.
Rezende DJ, Mohamed S. Variational inference with normalizing flows. In: Proceedings of the 32nd international conference on machine learning [Internet]. 2015. p. 1530–8. Available from: https://arxiv.org/abs/1505.05770
Python     3.12.10
torch      2.11.0+cpu
scipy      1.17.1
matplotlib 3.10.8