Outbreak Simulation

Author

Jong-Hoon Kim

Published

October 13, 2025

1 Setup

Code
library(tidyverse)
library(rstan)
library(bayesplot)
library(GGally)
library(splines)

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

theme_set(bayesplot::theme_default())
set.seed(2025)

rmse <- function(x,y) sqrt(mean((x-y)^2))
mae  <- function(x,y) mean(abs(x-y))

# NegBin2 sampler with mean mu and overdispersion phi (size = 1/phi)
rnbinom_mu_phi <- function(n, mu, phi) {
  size <- 1/phi
  rnbinom(n, mu = mu, size = size)
}
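
As a quick, illustrative check of this parameterization (a minimal sketch; the numbers below are made up), the draws should have variance close to \(\mu + \phi\,\mu^2\):

Code
# illustrative check of the NB2 mean/variance relation: Var = mu + phi * mu^2
x <- rnbinom_mu_phi(1e5, mu = 20, phi = 0.35)
c(mean = mean(x), var = var(x), var_theory = 20 + 0.35 * 20^2)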

2 Generate Fake Historical Outbreaks

Code
R <- 10        # number of regions
C <- 30        # total number of countries (across all regions)

# OPTION A: Balanced (≈ C/R countries per region)
build_map_balanced <- function(R, C) {
  regs <- paste0("R", seq_len(R))
  # split C countries across R regions as evenly as possible
  sizes <- rep(floor(C / R), R)
  sizes[seq_len(C %% R)] <- sizes[seq_len(C %% R)] + 1
  tibble(
    country = paste0("C", seq_len(C)),
    region  = rep(regs, times = sizes)
  )
}

# OPTION B: Unbalanced (random allocation using Dirichlet weights)
build_map_unbalanced <- function(R, C, conc = 1) {
  regs <- paste0("R", seq_len(R))
  w    <- rgamma(R, shape = conc, rate = 1) 
  w <- w / sum(w)
  sizes <- as.integer(round(w * C))
  # fix rounding so sizes sum to C
  while(sum(sizes) != C) {
    i <- sample(seq_len(R), 1)
    sizes[i] <- sizes[i] + sign(C - sum(sizes))
  }
  tibble(
    country = paste0("C", seq_len(C)),
    region  = rep(regs, times = sizes)
  )
}

# Choose one:
# country_region_map <- build_map_unbalanced(R, C, conc = 0.7)
country_region_map <- build_map_balanced(R, C)
# sanity checks
stopifnot(
  nrow(country_region_map) == C,
  all(!duplicated(country_region_map$country))
)
count(country_region_map, region, name = "n_countries")
# A tibble: 10 × 2
   region n_countries
   <chr>        <int>
 1 R1               3
 2 R10              3
 3 R2               3
 4 R3               3
 5 R4               3
 6 R5               3
 7 R6               3
 8 R7               3
 9 R8               3
10 R9               3

Observed meta-parameters per outbreak \(n = 1, \dots, N\)

\(N_n\): population size
\(t_n\): time to peak (weeks)
\(D_n\): duration (weeks)
\(i_n\): peak weekly incidence per person, in \((0,1)\)
\(y_n\): year
\(r_n \in \{1, \dots, R\}\): region
\(c_n \in \{1, \dots, C\}\): country (nested in region)

Modeling scales (logs and logit)

\[ \begin{aligned} \tilde{N}_n &= \log N_n,\\ \tilde{t}_n &= \log t_n,\\ \tilde{D}_n &= \log D_n,\\ \tilde{i}_n &= \operatorname{logit}(i_n) \;=\; \log\!\frac{i_n}{1-i_n}. \end{aligned} \]
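
As a minimal illustration of these transforms and their inverses (the values below are made up, not drawn from the data):

Code
# illustrative values only (not from the data)
N_pop <- 150000; t_peak <- 6; duration <- 14; i_peak <- 0.012

# forward transforms (modeling scale)
tilde <- c(logN = log(N_pop), logt = log(t_peak),
           logD = log(duration), logit_i = qlogis(i_peak))
tilde

# inverse transforms recover the natural-scale values
c(N_pop = exp(tilde[["logN"]]), t_peak = exp(tilde[["logt"]]),
  duration = exp(tilde[["logD"]]), i_peak = plogis(tilde[["logit_i"]]))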

Code
set.seed(2025)

n_outbreaks <- 1000
years <- 2000:2020

# sample countries for each outbreak, then derive region from the map
country <- factor(
  sample(country_region_map$country, n_outbreaks, replace = TRUE),
  levels = country_region_map$country
)
# look up region by country (nested)
region <- factor(
  country_region_map$region[ match(country, country_region_map$country) ],
  levels = unique(country_region_map$region)
)

year <- sample(years, n_outbreaks, replace = TRUE)

# ----- Shared RE across outcomes (logN, logt, logD, logit_i) -----
Sigma_re <- matrix(c(
  0.20, 0.05, 0.04, 0.03,
  0.05, 0.25, 0.06, 0.04,
  0.04, 0.06, 0.20, 0.05,
  0.03, 0.04, 0.05, 0.15
), 4, 4)
stopifnot(all(eigen(Sigma_re, symmetric = TRUE, only.values = TRUE)$values > 0))
Lre <- chol(Sigma_re)

R <- nlevels(region)
C <- nlevels(country)
re_region  <- matrix(NA_real_, 4, R)
re_country <- matrix(NA_real_, 4, C)
for (r in 1:R) re_region[,  r] <- drop(t(Lre) %*% rnorm(4))
for (c in 1:C) re_country[, c] <- drop(t(Lre) %*% rnorm(4))

f_year <- function(y) 0.2 * sin((y - 2000) / 20 * 2*pi)

# integer indices for fast lookup
r_id <- as.integer(region)
c_id <- as.integer(country)

# ----- Ordered generative mechanism -----
logN_mu <- 12 + 0.8 * f_year(year)
logN <- logN_mu + re_region[1, r_id] + re_country[1, c_id] + rnorm(n_outbreaks, 0, 0.35)

logt_mu <- 1.7 + 0.15 * (logN - mean(logN)) + 0.6 * f_year(year)
logt <- logt_mu + re_region[2, r_id] + re_country[2, c_id] + rnorm(n_outbreaks, 0, 0.25)

logD_mu <- 2.0 + 0.10 * (logN - mean(logN)) + 0.40 * (logt - mean(logt)) + 0.5 * f_year(year)
logD <- logD_mu + re_region[3, r_id] + re_country[3, c_id] + rnorm(n_outbreaks, 0, 0.25)

logit_i_mu <- -6.0 -
  0.20 * (logN - mean(logN)) +
  0.25 * (logt - mean(logt)) -
  0.10 * (logD - mean(logD)) +
  0.4  * f_year(year) +
  0.12 * (logt - mean(logt)) * (logD - mean(logD))
logit_i <- logit_i_mu + re_region[4, r_id] + re_country[4, c_id] + 
  rnorm(n_outbreaks, 0, 0.35)

# ----- Back-transform -----
N_pop             <- pmax(5e3, round(exp(logN)))
t_peak_weeks      <- pmax(1, round(exp(logt)))
duration_weeks    <- pmax(t_peak_weeks + 1, round(exp(logD)))
i_peak_per_person <- plogis(logit_i)

outbk <- tibble(year, region, country, N_pop, t_peak_weeks, duration_weeks, i_peak_per_person)

# quick checks (optional)
outbk %>%
  pivot_longer(c(N_pop, t_peak_weeks, duration_weeks, i_peak_per_person)) %>%
  ggplot(aes(value)) + geom_histogram(bins = 30, fill = "grey70") +
  facet_wrap(~name, scales = "free")

Code
# verify nesting: every country maps to exactly one region
outbk %>% distinct(country, region) %>% count(country) %>% summarise(all(n == 1))
# A tibble: 1 × 1
  `all(n == 1)`
  <lgl>        
1 TRUE         

Quick sanity checks

Code
# 3.1 basic marginals
outbk |>
  pivot_longer(c(N_pop, t_peak_weeks, duration_weeks, i_peak_per_person)) |>
  ggplot(aes(value)) + 
  geom_histogram(bins = 30, fill = "grey70") +
  facet_wrap(~name, scales = "free") + 
  labs(title = "Marginals (nested generator)")

Code
# 3.2 transformed pairwise structure (clamp i ∈ (0,1) before qlogis)
GGally::ggpairs(
  outbk |>
    transmute(
      logN   = log(N_pop),
      logt   = log(t_peak_weeks),
      logD   = log(duration_weeks),
      logit_i = qlogis(pmin(pmax(i_peak_per_person, 1e-6), 1 - 1e-6))
    ),
  title = "Transformed pairwise (nested)"
)

Code
# 3.3 expected directionality (quick numerical checks)
with(outbk, {
  cat("cor(logN, logt)       =", 
      cor(log(N_pop), log(t_peak_weeks)), "\n")
  cat("cor(logN, logD)       =", 
      cor(log(N_pop), log(duration_weeks)), "\n")
  cat("cor(logN, logit(i))   =", 
      cor(log(N_pop), qlogis(pmin(pmax(i_peak_per_person,1e-6),1-1e-6))), "\n")
  cat("cor(logt, logD)       =", 
      cor(log(t_peak_weeks), log(duration_weeks)), "\n")
})
cor(logN, logt)       = 0.3450293 
cor(logN, logD)       = 0.5685984 
cor(logN, logit(i))   = 0.07914388 
cor(logt, logD)       = 0.7627904 

3 Design Matrices (splines + interactions)

Code
## ---------- Safe helpers ----------
# pick df safely given the number of unique x values
safe_df <- function(x, target_df, min_df = 2) {
  nu <- length(unique(x))
  max(min_df, min(target_df, nu - 1))
}

# B-spline basis WITHOUT intercept, df chosen safely
bs_noi_safe <- function(x, target_df) {
  df_use <- safe_df(x, target_df)
  as.matrix(bs(x, df = df_use, intercept = FALSE))
}

# Safe column-wise standardization: never returns NaN (0-sd -> 1)
safe_scale <- function(M) {
  M <- as.matrix(M)
  mu  <- suppressWarnings(colMeans(M))
  sdv <- suppressWarnings(apply(M, 2, sd))
  sdv[!is.finite(sdv) | sdv == 0] <- 1
  sweep(sweep(M, 2, mu, FUN = "-"), 2, sdv, FUN = "/")
}

# Khatri–Rao (column-wise Kronecker) tensor product of two bases
tensor_kr <- function(B1, B2) {
  B1 <- as.matrix(B1); B2 <- as.matrix(B2)
  out <- vector("list", ncol(B1))
  for (i in seq_len(ncol(B1))) out[[i]] <- B2 * B1[, i]
  do.call(cbind, out)
}
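
# illustration (hypothetical dimensions): if B1 has 3 columns and B2 has 4 columns,
# tensor_kr(B1, B2) has 3 * 4 = 12 columns, ordered
#   B1[,1]*B2[,1], ..., B1[,1]*B2[,4], B1[,2]*B2[,1], ..., B1[,3]*B2[,4]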

# Drop columns that are constant or have any non-finite values
drop_bad_cols <- function(M) {
  M <- as.matrix(M)
  keep <- apply(M, 2, function(col) {
    ok <- all(is.finite(col))
    v  <- var(col)
    ok && is.finite(v) && v > 0
  })
  if (!any(keep)) 
    stop("All columns were dropped as 'bad' in a design matrix. Reduce df or check inputs.")
  M[, keep, drop = FALSE]
}

# Zero-variance checker that treats NA as 0-variance (to be safe)
has_zero_var <- function(M) {
  s <- suppressWarnings(apply(as.matrix(M), 2, sd))
  s[is.na(s)] <- 0
  any(s == 0)
}

df <- outbk %>% 
   transmute(
      year    = year,
      region  = region, 
      country = country,
      logN    = log(N_pop),
      logt    = log(t_peak_weeks),
      logD    = log(duration_weeks),
      logit_i = qlogis(pmin(pmax(i_peak_per_person, 1e-6), 1 - 1e-6))
  )
 ## ---------- Build design matrices ----------
# Expect df tibble is already defined from your fake data step:
# df has: year (numeric), region (factor), country (factor),
#         logN, logt, logD, logit_i (numeric)
stopifnot(all(c("year","region","country","logN","logt","logD","logit_i") 
              %in% names(df)))

# Center/scale the continuous predictors used as raw columns
x_logN <- as.numeric(scale(df$logN))
x_logt <- as.numeric(scale(df$logt))
x_logD <- as.numeric(scale(df$logD))

# --- p(logN) : splines of year only ---
B_year_N <- bs_noi_safe(df$year, target_df = 8)
B_year_N <- drop_bad_cols(B_year_N)
XN <- cbind(1, safe_scale(B_year_N))

# --- p(logt | logN, year) ---
B_year_t <- bs_noi_safe(df$year, target_df = 8)
B_year_t <- drop_bad_cols(B_year_t)
Xt <- cbind(1, safe_scale(B_year_t), x_logN)

# --- p(logD | logN, logt, year) ---
B_year_D <- bs_noi_safe(df$year, target_df = 8)
B_year_D <- drop_bad_cols(B_year_D)
XD <- cbind(1, safe_scale(B_year_D), x_logN, x_logt)

# --- p(logit(i) | logN, logt, logD, year) + tensor(logt, logD) ---
B_year_I <- bs_noi_safe(df$year, target_df = 8)
B_logt   <- bs_noi_safe(df$logt, target_df = 6)
B_logD   <- bs_noi_safe(df$logD, target_df = 6)

# Clean bases before tensor/scaling
B_year_I <- drop_bad_cols(B_year_I)
B_logt   <- drop_bad_cols(B_logt)
B_logD   <- drop_bad_cols(B_logD)

tensor_tD <- tensor_kr(B_logt, B_logD)
tensor_tD <- drop_bad_cols(tensor_tD)

XI <- cbind(
  1,
  safe_scale(B_year_I),
  x_logN,
  x_logt,
  x_logD,                 # vector (scaled logD)
  safe_scale(tensor_tD)
)

## ---------- Diagnostics before Stan ----------
cat("\nDesign matrix dimensions:\n")

Design matrix dimensions:
Code
cat("  XN:", dim(XN), "\n")
  XN: 1000 9 
Code
cat("  Xt:", dim(Xt), "\n")
  Xt: 1000 10 
Code
cat("  XD:", dim(XD), "\n")
  XD: 1000 11 
Code
cat("  XI:", dim(XI), "\n")
  XI: 1000 47 
Code
stopifnot(
  all(is.finite(XN)), all(is.finite(Xt)), all(is.finite(XD)), all(is.finite(XI)),
  nrow(XN) == nrow(df), nrow(Xt) == nrow(df), nrow(XD) == nrow(df), nrow(XI) == nrow(df)
)


# Replace the old helper with this:
has_zero_var <- function(M, intercept = TRUE) {
  M <- as.matrix(M)
  if (ncol(M) == 0) return(FALSE)
  cols <- if (intercept && ncol(M) > 1) 2:ncol(M) else 1:ncol(M)
  s <- suppressWarnings(apply(M[, cols, drop = FALSE], 2, sd))
  s[is.na(s)] <- 0
  any(s == 0)
}

# Optional: a verbose variant to see which columns (excluding intercept) are constant
report_zero_var <- function(M, name = "X", intercept = TRUE) {
  M <- as.matrix(M)
  cols <- if (intercept && ncol(M) > 1) 2:ncol(M) else 1:ncol(M)
  s <- suppressWarnings(apply(M[, cols, drop = FALSE], 2, sd))
  s[is.na(s)] <- 0
  bad <- which(s == 0)
  if (length(bad)) {
    message(sprintf("%s zero-variance (excluding intercept): %s", name,
                    paste(cols[bad], collapse = ", ")))
  } else {
    message(sprintf("%s: no zero-variance columns (excluding intercept).", name))
  }
}

# Re-run the checks (these should pass now if only the intercept was constant)
if (has_zero_var(XN, intercept = TRUE)) 
  stop("XN has zero-variance columns (excluding intercept).")
if (has_zero_var(Xt, intercept = TRUE)) 
  stop("Xt has zero-variance columns (excluding intercept).")
if (has_zero_var(XD, intercept = TRUE)) 
  stop("XD has zero-variance columns (excluding intercept).")
if (has_zero_var(XI, intercept = TRUE)) 
  stop("XI has zero-variance columns (excluding intercept).")

# (Optional) print details
report_zero_var(XN, "XN", TRUE)
report_zero_var(Xt, "Xt", TRUE)
report_zero_var(XD, "XD", TRUE)
report_zero_var(XI, "XI", TRUE)

# IDs & sizes
N <- nrow(df)
R <- nlevels(df$region)
C <- nlevels(df$country)
region_id  <- as.integer(df$region)
country_id <- as.integer(df$country)

pN <- ncol(XN); pt <- ncol(Xt); pD <- ncol(XD); pI <- ncol(XI)

stan_data <- list(
  N = N, R = R, C = C,
  region_id = region_id, country_id = country_id,
  y_logN = df$logN, y_logt = df$logt, y_logD = df$logD, y_logiti = df$logit_i,
  pN = pN, XN = XN,
  pt = pt, Xt = Xt,
  pD = pD, XD = XD,
  pI = pI, XI = XI
)

stopifnot(all(is.finite(as.matrix(XN))),
          all(is.finite(as.matrix(Xt))),
          all(is.finite(as.matrix(XD))),
          all(is.finite(as.matrix(XI))))
stopifnot(!any(is.na(df$logN)), !any(is.na(df$logt)), 
          !any(is.na(df$logD)), !any(is.na(df$logit_i)))
# raw zero-variance check: TRUE here only flags the constant intercept column,
# which is expected; the intercept-aware has_zero_var() checks above are the ones that matter
sapply(list(XN=XN,Xt=Xt,XD=XD,XI=XI), \(M) any(apply(M,2,sd)==0))
  XN   Xt   XD   XI 
TRUE TRUE TRUE TRUE 

4 Stan model and fit
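
Each transformed outcome gets its own linear predictor plus region and country random effects that are correlated across the four outcomes; the Stan code below uses a non-centered parameterization with LKJ priors on the \(4 \times 4\) correlation matrices:

\[ \begin{aligned} y^{(k)}_n &\sim \mathcal{N}\!\left(X^{(k)}_{n}\,\beta^{(k)} + u^{\mathrm{reg}}_{k,\,r_n} + u^{\mathrm{cty}}_{k,\,c_n},\; \sigma_k\right), \qquad k \in \{N, t, D, i\},\\ u^{\mathrm{reg}} &= \operatorname{diag}(\tau^{\mathrm{reg}})\, L^{\mathrm{reg}}\, z^{\mathrm{reg}}, \qquad z^{\mathrm{reg}}_{k,r} \sim \mathcal{N}(0,1), \quad L^{\mathrm{reg}} \sim \text{LKJ-Cholesky}(2), \end{aligned} \]

with the country effects \(u^{\mathrm{cty}}\) defined analogously.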

Code
## ---------- Fit the Stan model with correlated (LKJ) region/country random effects ----------
stan_code <- "
data {
  int<lower=1> N;
  int<lower=1> R;
  int<lower=1> C;
  int<lower=1, upper=R> region_id[N];
  int<lower=1, upper=C> country_id[N];

  vector[N] y_logN;
  vector[N] y_logt;
  vector[N] y_logD;
  vector[N] y_logiti;

  int<lower=1> pN; matrix[N, pN] XN;
  int<lower=1> pt; matrix[N, pt] Xt;
  int<lower=1> pD; matrix[N, pD] XD;
  int<lower=1> pI; matrix[N, pI] XI;
}
parameters {
  vector[pN] beta_N;
  vector[pt] beta_t;
  vector[pD] beta_D;
  vector[pI] beta_I;

  real<lower=0> sigma_N;
  real<lower=0> sigma_t;
  real<lower=0> sigma_D;
  real<lower=0> sigma_I;

  vector<lower=0>[4] tau_reg;
  cholesky_factor_corr[4] Lcorr_reg;
  matrix[4, R] z_reg;

  vector<lower=0>[4] tau_cty;
  cholesky_factor_corr[4] Lcorr_cty;
  matrix[4, C] z_cty;
}
transformed parameters {
  matrix[4, R] u_reg;
  matrix[4, C] u_cty;

  matrix[4,4] L_reg = diag_pre_multiply(tau_reg, Lcorr_reg);
  matrix[4,4] L_cty = diag_pre_multiply(tau_cty, Lcorr_cty);

  u_reg = L_reg * z_reg;
  u_cty = L_cty * z_cty;
}
model {
  beta_N ~ normal(0, 1.5);
  beta_t ~ normal(0, 1.5);
  beta_D ~ normal(0, 1.5);
  beta_I ~ normal(0, 1.5);

  sigma_N ~ normal(0, 1);
  sigma_t ~ normal(0, 1);
  sigma_D ~ normal(0, 1);
  sigma_I ~ normal(0, 1);

  tau_reg ~ normal(0, 1);
  tau_cty ~ normal(0, 1);
  Lcorr_reg ~ lkj_corr_cholesky(2);
  Lcorr_cty ~ lkj_corr_cholesky(2);

  to_vector(z_reg) ~ normal(0, 1);
  to_vector(z_cty) ~ normal(0, 1);

  for (n in 1:N) {
    int r = region_id[n];
    int c = country_id[n];
    real aN = u_reg[1, r] + u_cty[1, c];
    real at = u_reg[2, r] + u_cty[2, c];
    real aD = u_reg[3, r] + u_cty[3, c];
    real aI = u_reg[4, r] + u_cty[4, c];

    y_logN[n]   ~ normal( XN[n] * beta_N + aN, sigma_N );
    y_logt[n]   ~ normal( Xt[n] * beta_t + at, sigma_t );
    y_logD[n]   ~ normal( XD[n] * beta_D + aD, sigma_D );
    y_logiti[n] ~ normal( XI[n] * beta_I + aI, sigma_I );
  }
}
generated quantities {
  vector[N] y_logN_rep;
  vector[N] y_logt_rep;
  vector[N] y_logD_rep;
  vector[N] y_logiti_rep;

  corr_matrix[4] Omega_reg = multiply_lower_tri_self_transpose(Lcorr_reg);
  corr_matrix[4] Omega_cty = multiply_lower_tri_self_transpose(Lcorr_cty);

  for (n in 1:N) {
    int r = region_id[n];
    int c = country_id[n];
    real aN = u_reg[1, r] + u_cty[1, c];
    real at = u_reg[2, r] + u_cty[2, c];
    real aD = u_reg[3, r] + u_cty[3, c];
    real aI = u_reg[4, r] + u_cty[4, c];

    y_logN_rep[n]   = normal_rng( XN[n] * beta_N + aN, sigma_N );
    y_logt_rep[n]   = normal_rng( Xt[n] * beta_t + at, sigma_t );
    y_logD_rep[n]   = normal_rng( XD[n] * beta_D + aD, sigma_D );
    y_logiti_rep[n] = normal_rng( XI[n] * beta_I + aI, sigma_I );
  }
}
"
writeLines(stan_code, "ordered_conditional_plain_diagRE.stan")

stopifnot(file.exists("ordered_conditional_plain_diagRE.stan"))

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

set.seed(2025)
mod_diag <- stan_model("ordered_conditional_plain_diagRE.stan")
fit <- sampling(
  mod_diag, data = stan_data,
  chains = 4, iter = 3000, warmup = 1500, thin = 1, init = 0,
  control = list(adapt_delta = 0.99, max_treedepth = 14, stepsize = 0.5),
  refresh = 200
)

print(fit, pars = c("sigma_N","sigma_t","sigma_D","sigma_I","tau_reg","tau_cty"))
rstan::check_hmc_diagnostics(fit)
saveRDS(fit, "fit_20251013.rds")

5 Posterior Predictive Checks

Code
# --- VISUALIZING MODEL PREDICTED VS DATA (PPCs & CALIBRATION) ---
fit <- read_rds("fit_20251013.rds")  # saved by the fitting step above

# 0) sanity: fit exists and has samples?
stopifnot(inherits(fit, "stanfit"))
draws_mat <- as.matrix(fit)
stopifnot(nrow(draws_mat) > 0)

# 1) robust y_rep extractors (works whether rstan::extract or as.matrix route is needed)
grab_yrep <- function(fit, base, Nobs) {
  # try extract first
  ex <- try(rstan::extract(fit, pars = base, permuted = TRUE), silent = TRUE)
  if (!inherits(ex, "try-error") && is.list(ex) && length(ex) == 1L) {
    Y <- ex[[1]]
    if (is.matrix(Y)) return(Y)
  }
  # fallback to as.matrix() column pattern
  dm <- as.matrix(fit)
  cols <- grep(paste0("^", base, "\\["), colnames(dm), value = TRUE)
  stopifnot(length(cols) == Nobs)
  dm[, cols, drop = FALSE]
}

Nobs <- nrow(df)
yrep_logN <- grab_yrep(fit, "y_logN_rep",   Nobs)
yrep_logt <- grab_yrep(fit, "y_logt_rep",   Nobs)
yrep_logD <- grab_yrep(fit, "y_logD_rep",   Nobs)
yrep_logi <- grab_yrep(fit, "y_logiti_rep", Nobs)

# thin draws for plots (keep matrix shape)
thin_idx <- seq_len(min(200, nrow(yrep_logN)))
Yreps <- list(
  logN   = yrep_logN[thin_idx, , drop=FALSE],
  logt   = yrep_logt[thin_idx, , drop=FALSE],
  logD   = yrep_logD[thin_idx, , drop=FALSE],
  logiti = yrep_logi[thin_idx, , drop=FALSE]
)

Yobs <- list(
  logN   = df$logN,
  logt   = df$logt,
  logD   = df$logD,
  logiti = df$logit_i
)

# 2) classic PPC overlays (model scales)
ppc_dens_overlay(y = Yobs$logN,   yrep = Yreps$logN)   + 
  ggtitle("PPC: logN (density)")

Code
ppc_dens_overlay(y = Yobs$logt,   yrep = Yreps$logt)   + 
  ggtitle("PPC: log t_peak (density)")

Code
ppc_dens_overlay(y = Yobs$logD,   yrep = Yreps$logD)   + 
  ggtitle("PPC: log duration (density)")

Code
ppc_dens_overlay(y = Yobs$logiti, yrep = Yreps$logiti) + 
  ggtitle("PPC: logit i_peak (density)")

Code
ppc_ecdf_overlay(y = Yobs$logN,   yrep = Yreps$logN)   + 
  ggtitle("PPC: logN (ECDF)")

Code
ppc_ecdf_overlay(y = Yobs$logt,   yrep = Yreps$logt)   + 
  ggtitle("PPC: log t_peak (ECDF)")

Code
ppc_ecdf_overlay(y = Yobs$logD,   yrep = Yreps$logD)   + 
  ggtitle("PPC: log duration (ECDF)")

Code
ppc_ecdf_overlay(y = Yobs$logiti, yrep = Yreps$logiti) + 
  ggtitle("PPC: logit i_peak (ECDF)")

Code
# 3) observed vs posterior predictive mean + 95% PI (model scales)
summ_one <- function(yrep, yobs, label) {
  tibble(
    obs = as.numeric(yobs),
    mean = colMeans(yrep),
    lo = apply(yrep, 2, quantile, 0.025),
    hi = apply(yrep, 2, quantile, 0.975),
    var = label,
    idx = seq_along(yobs)
  )
}

summ_df <- bind_rows(
  summ_one(Yreps$logN,   Yobs$logN,   "logN"),
  summ_one(Yreps$logt,   Yobs$logt,   "logt"),
  summ_one(Yreps$logD,   Yobs$logD,   "logD"),
  summ_one(Yreps$logiti, Yobs$logiti, "logit_i")
)

ggplot(summ_df, aes(x = obs, y = mean)) +
  geom_abline(slope = 1, intercept = 0, linetype = 2) +
  geom_errorbar(aes(ymin = lo, ymax = hi), alpha = 0.4) +
  geom_point(size = 0.7, alpha = 0.6) +
  facet_wrap(~var, scales = "free") +
  labs(title = "Observed vs Posterior Predictive Mean (95% PI) — model scales",
       x = "Observed", y = "Posterior predictive mean") +
  theme_minimal()

Code
# 4) coverage / calibration (model scales)
coverage_one <- function(yrep, yobs, probs = c(0.1, 0.5, 0.9)) {
  qs <- t(apply(yrep, 2, quantile, probs = probs))
  tibble(
    p10 = qs[,1], p50 = qs[,2], p90 = qs[,3],
    obs = yobs
  ) %>%
    summarise(
      cover80 = mean(obs >= p10 & obs <= p90),
      rmse = sqrt(mean((obs - p50)^2)),
      mae  = mean(abs(obs - p50))
    )
}
calib_tbl <- bind_rows(
  coverage_one(Yreps$logN,   Yobs$logN)   %>% mutate(var="logN"),
  coverage_one(Yreps$logt,   Yobs$logt)   %>% mutate(var="logt"),
  coverage_one(Yreps$logD,   Yobs$logD)   %>% mutate(var="logD"),
  coverage_one(Yreps$logiti, Yobs$logiti) %>% mutate(var="logit_i")
)
print(calib_tbl)
# A tibble: 4 × 4
  cover80  rmse   mae var    
    <dbl> <dbl> <dbl> <chr>  
1   0.818 0.332 0.262 logN   
2   0.809 0.256 0.203 logt   
3   0.812 0.197 0.149 logD   
4   0.832 0.355 0.284 logit_i
Code
# expect cover80 ≈ 0.8 if intervals are well-calibrated

# 5) natural-scale comparisons (histograms/densities)
# transform reps to natural scales
inv_logit <- function(x) 1/(1+exp(-x))

# pick a small set of draws to avoid memory blow-ups when back-transforming
bt_idx <- seq_len(min(100, nrow(yrep_logN)))

bt <- list(
  N_pop   = exp(yrep_logN[bt_idx, , drop=FALSE]),
  t_weeks = exp(yrep_logt[bt_idx, , drop=FALSE]),
  D_weeks = exp(yrep_logD[bt_idx, , drop=FALSE]),
  i_peak  = inv_logit(yrep_logi[bt_idx, , drop=FALSE])
)

obs_nat <- list(
  N_pop   = exp(df$logN),
  t_weeks = exp(df$logt),
  D_weeks = exp(df$logD),
  i_peak  = inv_logit(df$logit_i)
)

dens_long <- function(mat, label) {
  tibble(value = as.vector(t(mat)), draw = rep(seq_len(nrow(mat)), each = ncol(mat))) %>%
    mutate(source = "simulated", var = label)
}

obs_long  <- function(vec, label) tibble(value = vec, source = "observed", var = label)

nat_df <- bind_rows(
  dens_long(bt$N_pop,   "N_pop"),
  dens_long(bt$t_weeks, "t_peak_weeks"),
  dens_long(bt$D_weeks, "duration_weeks"),
  dens_long(bt$i_peak,  "i_peak_per_person"),
  obs_long(obs_nat$N_pop,   "N_pop"),
  obs_long(obs_nat$t_weeks, "t_peak_weeks"),
  obs_long(obs_nat$D_weeks, "duration_weeks"),
  obs_long(obs_nat$i_peak,  "i_peak_per_person")
)

ggplot(nat_df, aes(value, fill = source)) +
  geom_density(alpha = 0.35) +
  facet_wrap(~var, scales = "free") +
  labs(title = "Observed vs Posterior Predictive — natural scales", x = NULL, y = "density", fill = NULL) +
  theme_minimal()

Code
# 6) grouped PPCs (optional): by region or by coarse year bins
# useful to spot misfit in subgroups
df$year_bin <- cut(df$year, breaks = c(-Inf, 2005, 2010, 2015, Inf),
                   labels = c("≤2005","2006–2010","2011–2015","≥2016"), right = TRUE)

ppc_by_group <- function(yrep, yobs, group, title) {
  # compute groupwise observed vs predictive mean with intervals
  pred_mean <- colMeans(yrep)
  pred_lo   <- apply(yrep, 2, quantile, 0.1)
  pred_hi   <- apply(yrep, 2, quantile, 0.9)
  tibble(obs = yobs, mean = pred_mean, lo = pred_lo, hi = pred_hi, grp = group) %>%
    group_by(grp) %>%
    summarise(
      obs_mean  = mean(obs),
      pred_mean = mean(mean),
      pred_lo   = mean(lo),
      pred_hi   = mean(hi),
      .groups = "drop"
    ) %>%
    ggplot(aes(x = obs_mean, y = pred_mean)) +
    geom_abline(slope = 1, intercept = 0, linetype = 2) +
    geom_errorbar(aes(ymin = pred_lo, ymax = pred_hi), width = 0) +
    geom_point(size = 2) +
    labs(title = title, x = "Observed group mean", y = "Predicted group mean") +
    theme_minimal()
}

ppc_by_group(Yreps$logN,   Yobs$logN,   
             df$region,   "Group PPC: logN by region")

Code
ppc_by_group(Yreps$logt,   Yobs$logt,   
             df$region,   "Group PPC: log t by region")

Code
ppc_by_group(Yreps$logD,   Yobs$logD,   
             df$region,   "Group PPC: log duration by region")

Code
ppc_by_group(Yreps$logiti, Yobs$logiti, 
             df$region,   "Group PPC: logit i by region")

Code
ppc_by_group(Yreps$logN,   Yobs$logN,   
             df$year_bin, "Group PPC: logN by year bin")

Code
ppc_by_group(Yreps$logt,   Yobs$logt,   
             df$year_bin, "Group PPC: log t by year bin")

Code
ppc_by_group(Yreps$logD,   Yobs$logD,   
             df$year_bin, "Group PPC: log duration by year bin")

Code
ppc_by_group(Yreps$logiti, Yobs$logiti, 
             df$year_bin, "Group PPC: logit i by year bin")

6 Forward Simulation in Your Modeling Order
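
simulate_forward() below draws each outbreak's meta-parameters in the same conditional order as the generative mechanism and the fitted model, so every simulated quantity feeds into the design row of the next equation:

\[ p(\tilde{N}, \tilde{t}, \tilde{D}, \tilde{i} \mid \text{year}, r, c) \;=\; p(\tilde{N} \mid \text{year}, r, c)\; p(\tilde{t} \mid \tilde{N}, \cdot)\; p(\tilde{D} \mid \tilde{N}, \tilde{t}, \cdot)\; p(\tilde{i} \mid \tilde{N}, \tilde{t}, \tilde{D}, \cdot). \]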

Code
# ---- store training specs to reproduce design rows at prediction time ----
# helper: capture basis attributes + column means/sds used for scaling
capture_bs_spec <- function(B_train) {
  list(
    degree         = attr(B_train, "degree"),
    knots          = attr(B_train, "knots"),
    Boundary.knots = attr(B_train, "Boundary.knots"),
    mu             = colMeans(as.matrix(B_train)),
    sd             = apply(as.matrix(B_train), 2, sd)
  )
}
# safe scaling using training mu/sd (0-sd -> 1)
scale_with <- function(M, mu, sd) {
  sd2 <- sd; sd2[!is.finite(sd2) | sd2 == 0] <- 1
  sweep(sweep(as.matrix(M), 2, mu, "-"), 2, sd2, "/")
}
# rebuild a bs basis using training attributes
bs_from_spec <- function(x_new, spec) {
  as.matrix(splines::bs(x_new,
                        degree = spec$degree,
                        knots = spec$knots,
                        Boundary.knots = spec$Boundary.knots,
                        intercept = FALSE))
}

# capture specs for the bases used in each equation (POST-cleaning, PRE-scaling)
# NOTE: drop_bad_cols() returns a plain matrix, so the bs() attributes (degree, knots,
# Boundary.knots) are no longer attached here; the specs are re-captured from the raw
# bases further below before they are used for prediction.
spec_year_N <- capture_bs_spec(B_year_N)   # used in XN
spec_year_t <- capture_bs_spec(B_year_t)   # used in Xt
spec_year_D <- capture_bs_spec(B_year_D)   # used in XD
spec_year_I <- capture_bs_spec(B_year_I)   # used in XI (year block)

spec_logt   <- capture_bs_spec(B_logt)     # used in XI tensor
spec_logD   <- capture_bs_spec(B_logD)     # used in XI tensor

# tensor spec: store column means/sds AFTER building tensor on training
# (we’ll rebuild the tensor from the two bases at prediction time, then scale with these)
tensor_train   <- tensor_kr(B_logt, B_logD)
spec_tensor_tD <- list(
  mu = colMeans(tensor_train),
  sd = apply(tensor_train, 2, sd)
)

# training means/sds for the centered “raw” covariates you used (x_logN, x_logt, x_logD)
center_sd_cont <- list(
  logN = c(mu = mean(df$logN), sd = sd(df$logN)),
  logt = c(mu = mean(df$logt), sd = sd(df$logt)),
  logD = c(mu = mean(df$logD), sd = sd(df$logD))
)
Code
# ---- extract posterior draws (arrays) ----
ex <- rstan::extract(fit, permuted = TRUE)

# fixed effects
beta_N <- ex$beta_N  # draws × pN
beta_t <- ex$beta_t  # draws × pt
beta_D <- ex$beta_D  # draws × pD
beta_I <- ex$beta_I  # draws × pI

# residual SDs
sigma_N <- ex$sigma_N  # draws
sigma_t <- ex$sigma_t
sigma_D <- ex$sigma_D
sigma_I <- ex$sigma_I

# random effects (from transformed parameters)
# u_reg: draws × 4 × R; u_cty: draws × 4 × C
u_reg <- ex$u_reg
u_cty <- ex$u_cty

n_draws <- nrow(beta_N)
stopifnot(n_draws > 0)
Code
# --- utilities for capturing & rebuilding spline/tensor blocks ---

# which columns to keep (finite & non-constant)
keep_idx <- function(M) {
  M <- as.matrix(M)
  apply(M, 2, function(col) {
    ok <- all(is.finite(col)); v <- var(col)
    ok && is.finite(v) && v > 0
  })
}

# build raw bs (with attributes intact) using *safe* df
safe_df <- function(x, target_df, min_df = 2) {
  nu <- length(unique(x))
  max(min_df, min(target_df, nu - 1))
}

build_bs_raw <- function(x, target_df) {
  df_use <- safe_df(x, target_df)
  splines::bs(x, df = df_use, intercept = FALSE)  # returns a "bs" object with attrs
}

# capture *full* spec: attrs + kept columns + scaling stats
capture_bs_spec <- function(bs_obj) {
  M <- as.matrix(bs_obj)
  keep <- keep_idx(M)
  list(
    degree         = attr(bs_obj, "degree"),
    knots          = attr(bs_obj, "knots"),
    Boundary.knots = attr(bs_obj, "Boundary.knots"),
    keep           = keep,
    mu             = colMeans(M[, keep, drop = FALSE]),
    sd             = apply(M[, keep, drop = FALSE], 2, sd)
  )
}

# rebuild a *raw* kept-column bs matrix for new x using the stored attrs
bs_raw_from_spec <- function(x_new, spec) {
  B <- splines::bs(x_new,
                   degree         = spec$degree,
                   knots          = spec$knots,
                   Boundary.knots = spec$Boundary.knots,
                   intercept = FALSE)
  as.matrix(B)[, spec$keep, drop = FALSE]
}

# scale with stored training mu/sd (guard sd=0)
scale_with <- function(M, mu, sd) {
  sd2 <- sd; sd2[!is.finite(sd2) | sd2 == 0] <- 1
  sweep(sweep(as.matrix(M), 2, mu, "-"), 2, sd2, "/")
}

# column-wise Khatri–Rao product
tensor_kr <- function(B1, B2) {
  out <- vector("list", ncol(B1))
  for (i in seq_len(ncol(B1))) out[[i]] <- B2 * B1[, i]
  do.call(cbind, out)
}

# capture tensor spec from two *raw* kept bases (unscaled)
capture_tensor_spec <- function(B1_raw, B2_raw) {
  Traw <- tensor_kr(B1_raw, B2_raw)
  keep <- keep_idx(Traw)
  list(
    keep = keep,
    mu   = colMeans(Traw[, keep, drop = FALSE]),
    sd   = apply(Traw[, keep, drop = FALSE], 2, sd)
  )
}

# Rebuild & scale tensor for new (B1_raw_new, B2_raw_new)
tensor_from_spec <- function(B1_raw_new, B2_raw_new, tens_spec) {
  Traw <- tensor_kr(B1_raw_new, B2_raw_new)
  Traw <- Traw[, tens_spec$keep, drop = FALSE]
  scale_with(Traw, tens_spec$mu, tens_spec$sd)
}

# ---- (RE)CAPTURE SPECS FROM YOUR TRAINING DATA ----
# Build RAW bases (with attrs) *before* dropping cols/scaling
B_year_N_raw <- build_bs_raw(df$year, 8)
B_year_t_raw <- build_bs_raw(df$year, 8)
B_year_D_raw <- build_bs_raw(df$year, 8)
B_year_I_raw <- build_bs_raw(df$year, 8)
B_logt_raw   <- build_bs_raw(df$logt, 6)
B_logD_raw   <- build_bs_raw(df$logD, 6)

# capture bs specs
spec_year_N <- capture_bs_spec(B_year_N_raw)
spec_year_t <- capture_bs_spec(B_year_t_raw)
spec_year_D <- capture_bs_spec(B_year_D_raw)
spec_year_I <- capture_bs_spec(B_year_I_raw)
spec_logt   <- capture_bs_spec(B_logt_raw)
spec_logD   <- capture_bs_spec(B_logD_raw)

# capture tensor spec (built from *raw kept* bases)
B_logt_keep_raw <- as.matrix(B_logt_raw)[, spec_logt$keep, drop = FALSE]
B_logD_keep_raw <- as.matrix(B_logD_raw)[, spec_logD$keep, drop = FALSE]
spec_tensor_tD  <- capture_tensor_spec(B_logt_keep_raw, B_logD_keep_raw)

# store training means/sds for raw covariates used as standardized scalars
center_sd_cont <- list(
  logN = c(mu = mean(df$logN), sd = sd(df$logN)),
  logt = c(mu = mean(df$logt), sd = sd(df$logt)),
  logD = c(mu = mean(df$logD), sd = sd(df$logD))
)
Code
simulate_forward <- function(n_sims = 400,
                             year,
                             region,
                             country,
                             use_same_draw_per_sim = TRUE) {

  if (length(year)   == 1) year   <- rep(year,   n_sims)
  if (length(region) == 1) region <- rep(region, n_sims)
  if (length(country)== 1) country<- rep(country,n_sims)
  stopifnot(length(year) == n_sims, length(region) == n_sims,
            length(country) == n_sims)

  region  <- factor(region,  levels = levels(df$region))
  country <- factor(country, levels = levels(df$country))
  r_id <- as.integer(region); c_id <- as.integer(country)

  # posterior draws (already extracted earlier)
  n_draws <- nrow(beta_N); stopifnot(n_draws > 0)

  if (use_same_draw_per_sim) {
    draw_id <- sample(seq_len(n_draws), n_sims, replace = TRUE)
  } else {
    draw_id <- cbind(
      N  = sample(seq_len(n_draws), n_sims, replace = TRUE),
      t  = sample(seq_len(n_draws), n_sims, replace = TRUE),
      D  = sample(seq_len(n_draws), n_sims, replace = TRUE),
      Ip = sample(seq_len(n_draws), n_sims, replace = TRUE)
    )
  }

  logN <- numeric(n_sims); logt <- numeric(n_sims)
  logD <- numeric(n_sims); logi <- numeric(n_sims)

  for (i in seq_len(n_sims)) {
    di_N  <- if (is.matrix(draw_id)) draw_id[i, "N"]  else draw_id[i]
    di_t  <- if (is.matrix(draw_id)) draw_id[i, "t"]  else draw_id[i]
    di_D  <- if (is.matrix(draw_id)) draw_id[i, "D"]  else draw_id[i]
    di_I  <- if (is.matrix(draw_id)) draw_id[i, "Ip"] else draw_id[i]

    aN <- u_reg[di_N, 1, r_id[i]] + u_cty[di_N, 1, c_id[i]]
    at <- u_reg[di_t, 2, r_id[i]] + u_cty[di_t, 2, c_id[i]]
    aD <- u_reg[di_D, 3, r_id[i]] + u_cty[di_D, 3, c_id[i]]
    aI <- u_reg[di_I, 4, r_id[i]] + u_cty[di_I, 4, c_id[i]]

    # ---- X_N(row): year spline only ----
    B_yN_raw <- bs_raw_from_spec(year[i], spec_year_N)
    XN_row <- c(1, scale_with(B_yN_raw, spec_year_N$mu, spec_year_N$sd))
    etaN <- sum(XN_row * beta_N[di_N, ]) + aN
    logN[i] <- rnorm(1, etaN, sigma_N[di_N])

    # ---- X_t(row): year spline + std logN ----
    B_yt_raw <- bs_raw_from_spec(year[i], spec_year_t)
    Xt_row <- c(1,
                scale_with(B_yt_raw, spec_year_t$mu, spec_year_t$sd),
                (logN[i] - center_sd_cont$logN["mu"]) / center_sd_cont$logN["sd"])
    etat <- sum(Xt_row * beta_t[di_t, ]) + at
    logt[i] <- rnorm(1, etat, sigma_t[di_t])

    # ---- X_D(row): year spline + std logN, logt ----
    B_yD_raw <- bs_raw_from_spec(year[i], spec_year_D)
    XD_row <- c(1,
                scale_with(B_yD_raw, spec_year_D$mu, spec_year_D$sd),
                (logN[i] - center_sd_cont$logN["mu"]) / center_sd_cont$logN["sd"],
                (logt[i] - center_sd_cont$logt["mu"]) / center_sd_cont$logt["sd"])
    etaD <- sum(XD_row * beta_D[di_D, ]) + aD
    logD[i] <- rnorm(1, etaD, sigma_D[di_D])

    # ---- X_I(row): year spline + std logN,logt,logD + tensor(logt,logD) ----
    B_yI_raw   <- bs_raw_from_spec(year[i], spec_year_I)
    B_logt_raw_new <- bs_raw_from_spec(logt[i], spec_logt)
    B_logD_raw_new <- bs_raw_from_spec(logD[i], spec_logD)
    T_scaled <- tensor_from_spec(B_logt_raw_new, B_logD_raw_new, spec_tensor_tD)

    XI_row <- c(1,
                scale_with(B_yI_raw, spec_year_I$mu, spec_year_I$sd),
                (logN[i] - center_sd_cont$logN["mu"]) / center_sd_cont$logN["sd"],
                (logt[i] - center_sd_cont$logt["mu"]) / center_sd_cont$logt["sd"],
                (logD[i] - center_sd_cont$logD["mu"]) / center_sd_cont$logD["sd"],
                T_scaled)
    etai <- sum(XI_row * beta_I[di_I, ]) + aI
    logi[i] <- rnorm(1, etai, sigma_I[di_I])
  }

  tibble(
    sim_id  = seq_len(n_sims),
    year    = year,
    region  = as.character(region),
    country = as.character(country),
    N_pop             = pmax(1e3, exp(logN)),
    t_peak_weeks      = pmax(1,   exp(logt)),
    duration_weeks    = pmax(2,   exp(logD)),
    i_peak_per_person = pmin(0.95, pmax(1e-6, plogis(logi)))
  ) |>
    mutate(duration_weeks = pmax(duration_weeks, t_peak_weeks + 1))
}
Code
# 400 simulations for one stratum
sim1 <- simulate_forward(n_sims = 400, year = 2015, region = "R2", country = "C5")

# simulate with the empirical mix of (year, region, country) from your data
keys <- df %>% select(year, region, country) %>% slice_sample(n = 1000, replace = TRUE)
sim_meta <- simulate_forward(n_sims = nrow(keys),
                            year = keys$year,
                            region = as.character(keys$region),
                            country = as.character(keys$country))

Quick visual check

Code
hist_summ <- df %>% 
  # filter(year == 2015, region =="R2", country == "C5") %>% 
  transmute(source = "historical",
            N_pop = exp(logN),
            t_peak_weeks = exp(logt),
            duration_weeks = exp(logD),
            i_peak_per_person = plogis(logit_i))

# sim_summ <- sim1 %>% mutate(source = "simulated")
sim_summ <- sim_meta %>% mutate(source = "simulated")

bind_rows(hist_summ, sim_summ) %>%
  pivot_longer(c(N_pop, t_peak_weeks, duration_weeks, i_peak_per_person),
               names_to = "var", values_to = "value") %>%
  ggplot(aes(value, fill = source)) +
  geom_density(alpha = 0.35) +
  facet_wrap(~var, scales = "free") +
  labs(title = "Historical vs Forward Simulated (empirical mix of year/region/country)",
       x = NULL, y = "density", fill = NULL)

7 Weekly Incidence & Counts
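
The weekly per-person incidence curves below peak at \(t_{\mathrm{peak}}\); for the exponential shape, the rise and decay rates are set so that the curve falls to a fraction \(\varepsilon\) of the peak at \(\pm D/2\) from the peak:

\[ i(w) \;=\; \begin{cases} i_{\mathrm{peak}}\, e^{\,r\,(w - t_{\mathrm{peak}})}, & w \le t_{\mathrm{peak}},\\ i_{\mathrm{peak}}\, e^{-d\,(w - t_{\mathrm{peak}})}, & w > t_{\mathrm{peak}}, \end{cases} \qquad r = d = \frac{2\log(1/\varepsilon)}{D}, \]

and expected weekly cases are \(\mu_w = N\, i(w)\), with NB2 noise around \(\mu_w\).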

Code
inc_curve_exp <- function(week, t_peak, i_peak, r, d){
  ifelse(
    week <= t_peak,
    i_peak * exp( r * (week - t_peak)),
    i_peak * exp(-d * (week - t_peak))
  )
}
rd_from_D <- function(D, eps=0.05){ L <- log(1/eps); r <- d <- 2*L / D; c(r=r, d=d) }
inc_curve_tri <- function(week, t_peak, i_peak, D){
  ifelse(
    week <= t_peak, i_peak * (week / t_peak),
    ifelse(week > t_peak & week <= D, i_peak * (1 - (week - t_peak)/max(D - t_peak, 1e-6)), 0)
  )
}
simulate_weekly_counts <- function(sim_meta, Tmax = 50, shape_mix = 0.6, phi_nb = 0.35, eps = 0.05) {
  weeks <- 1:Tmax
  Sims  <- vector("list", nrow(sim_meta))
  shape_flag <- rbinom(nrow(sim_meta), 1, prob = shape_mix)
  for (i in seq_len(nrow(sim_meta))) {
    N   <- sim_meta$N_pop[i]
    tpk <- sim_meta$t_peak_weeks[i]
    D   <- sim_meta$duration_weeks[i]
    ipk <- sim_meta$i_peak_per_person[i]
    if (shape_flag[i] == 1) {
      rd    <- rd_from_D(D, eps)
      inc   <- inc_curve_exp(weeks, tpk, ipk, rd["r"], rd["d"])
      shape <- "exp"
    } else {
      inc   <- inc_curve_tri(weeks, tpk, ipk, D)
      shape <- "tri"
    }
    mu <- N * pmax(inc, 0)
    y  <- rnbinom_mu_phi(length(weeks), mu, phi_nb)
    Sims[[i]] <- tibble(sim_id = i, week = weeks, y = y, mu = mu, shape = shape,
                        N = N, t_peak = tpk, D = D, i_peak = ipk)
  }
  bind_rows(Sims)
}
Code
sim_weekly <- simulate_weekly_counts(sim_meta, Tmax=50, shape_mix=0.6, phi_nb=0.35)
sim_sum <- sim_weekly %>%
  group_by(week) %>%
  summarise(y_mean = mean(y),
            y_q05  = quantile(y, 0.05),
            y_q95  = quantile(y, 0.95),
            .groups = "drop")

ggplot(sim_sum, aes(week, y_mean)) +
  geom_ribbon(aes(ymin=y_q05, ymax=y_q95), alpha=0.2) +
  geom_line() + labs(title="Simulated weekly counts (mean ± 90% envelope)", y="cases")

Code
# ---- helpers ----
# NB2 sampler parameterized by mean 'mu' and overdispersion 'phi' (Var = mu + phi * mu^2)
rnbinom_mu_phi <- function(n, mu, phi) {
  # size = 1/phi; allow vector 'mu'
  size <- 1 / pmax(1e-12, phi)
  # rnbinom recycles 'size' if length 1; replicate if you want elementwise
  if (length(mu) != n) stop("rnbinom_mu_phi: length(mu) must equal n")
  stats::rnbinom(n, size = size, mu = pmax(0, mu))
}

# Map total duration D (full width) to symmetric rise/decay rates (r, d)
# Interpretation: at distance D/2 from the peak, the exponential curve is 'eps' of the peak.
# So exp(-r * (D/2)) = eps and exp(-d * (D/2)) = eps  => r = d = 2*log(1/eps)/D
rd_from_D <- function(D, eps = 0.05) {
  L <- log(1 / pmin(pmax(eps, 1e-6), 0.99))
  rate <- 2 * L / pmax(D, 1e-6)
  c(r = rate, d = rate)
}

# Exponential rise/decay incidence curve (per person per week)
# Peaks at 'i_peak' when week == t_peak; otherwise decays exponentially.
inc_curve_exp <- function(week, t_peak, i_peak, r, d) {
  inc <- ifelse(
    week <= t_peak,
    i_peak * exp(r * (week - t_peak)),
    i_peak * exp(-d * (week - t_peak))
  )
  # keep in [0, 1-eps] to avoid impossible per-person probabilities
  pmin(pmax(inc, 0), 0.99)
}

# Triangular (piecewise linear) incidence curve
# Rises linearly 0 -> i_peak up to t_peak, then declines to 0 by week == D.
inc_curve_tri <- function(week, t_peak, i_peak, D) {
  up   <- i_peak * (week / pmax(t_peak, 1))
  down <- i_peak * (1 - (week - t_peak) / pmax(D - t_peak, 1e-6))
  inc  <- ifelse(week <= t_peak, up,
                 ifelse(week <= D, pmax(down, 0), 0))
  pmin(pmax(inc, 0), 0.99)
}

# ---- main simulator ----
# sim_meta must have columns: N_pop, t_peak_weeks, duration_weeks, i_peak_per_person
simulate_weekly_counts <- function(sim_meta,
                                   weeks = NULL,      # e.g., 1:Tmax or a custom integer vector
                                   Tmax  = 50,        # used only if 'weeks' is NULL
                                   shape_mix = 0.6,   # Pr(exponential); 1-shape_mix = triangular
                                   phi_nb = 0.35,     # NB2 overdispersion (common across weeks)
                                   eps    = 0.05) {   # exp tail fraction at +/- D/2
  stopifnot(all(c("N_pop","t_peak_weeks","duration_weeks",
                  "i_peak_per_person") %in% names(sim_meta)))
  if (is.null(weeks)) weeks <- seq_len(Tmax)
  weeks <- as.integer(weeks)

  nS <- nrow(sim_meta)
  shape_flag <- rbinom(nS, 1, prob = shape_mix)  # 1=exp, 0=tri

  out <- vector("list", nS)

  for (i in seq_len(nS)) {
    N   <- sim_meta$N_pop[i]
    tpk <- sim_meta$t_peak_weeks[i]
    D   <- sim_meta$duration_weeks[i]
    ipk <- sim_meta$i_peak_per_person[i]

    if (shape_flag[i] == 1) {
      rd  <- rd_from_D(D, eps = eps)               # symmetric rates from duration
      inc <- inc_curve_exp(weeks, tpk, ipk, rd["r"], rd["d"])
      shape <- "exp"
    } else {
      inc <- inc_curve_tri(weeks, tpk, ipk, D)     # zero beyond week D
      shape <- "tri"
    }

    mu_week <- pmax(0, N * inc)                    # expected weekly cases
    y_week  <- rnbinom_mu_phi(length(weeks), mu = mu_week, phi = phi_nb)

    out[[i]] <- tibble(
      sim_id   = i,
      week     = weeks,
      shape    = shape,
      N_pop    = N,
      t_peak   = tpk,
      D        = D,
      i_peak   = ipk,
      mu       = mu_week,
      cases    = y_week
    )
  }

  bind_rows(out)
}

# ---- example usage ----
# Using 'sim_meta' from your forward simulator:
# (A) fixed horizon 1..50
sim_weekly <- simulate_weekly_counts(sim_meta, Tmax = 50, 
                                     shape_mix = 0.6, phi_nb = 0.35)

# (B) custom calendar weeks (e.g., -10..60 relative to outbreak start)
# sim_weekly <- simulate_weekly_counts(sim_meta, weeks = -10:60)

# Summaries for plotting (pointwise across simulations)
sim_sum <- sim_weekly %>%
  group_by(week) %>%
  summarise(
    cases_mean = mean(cases),
    cases_q05  = quantile(cases, 0.05),
    cases_q95  = quantile(cases, 0.95),
    .groups = "drop"
  )

ggplot(sim_sum, aes(week, cases_mean)) +
  geom_ribbon(aes(ymin = cases_q05, ymax = cases_q95), alpha = 0.2) +
  geom_line() +
  labs(title = "Simulated weekly cases (mean and 90% envelope)", x = "week", y = "cases") +
  theme_minimal()

Code
sim_totals <- sim_weekly %>%
  group_by(sim_id) %>%
  summarise(total_cases = sum(cases), total_mu = sum(mu), .groups = "drop") %>%
  left_join(sim_meta %>% mutate(sim_id = row_number()), by = "sim_id") %>%
  mutate(attack_rate_obs = total_cases / N_pop,
         attack_rate_mu  = total_mu    / N_pop)
Code
## ===========================
## Weekly curves: extra shapes
## ===========================

# NB2 sampler (unchanged)
rnbinom_mu_phi <- function(n, mu, phi) {
  size <- 1 / pmax(1e-12, phi)
  if (length(mu) != n) stop("rnbinom_mu_phi: length(mu) must equal n")
  stats::rnbinom(n, size = size, mu = pmax(0, mu))
}

# Symmetric exp: map duration D to (r, d) with tail fraction eps at +/- D/2
rd_from_D_sym <- function(D, eps = 0.05) {
  L <- log(1 / pmin(pmax(eps, 1e-6), 0.99))
  rate <- 2 * L / pmax(D, 1e-6)
  c(r = rate, d = rate)
}

# Asymmetric exp: split D into a*D before-peak and (1-a)*D after-peak
# Tail fraction eps at t_peak - aD and t_peak + (1-a)D
rd_from_D_asym <- function(D, a = 0.5, eps = 0.05) {
  a <- pmin(pmax(a, 1e-3), 1 - 1e-3)     # keep a in (0,1)
  L <- log(1 / pmin(pmax(eps, 1e-6), 0.99))
  r <- L / pmax(a * D,      1e-6)        # growth rate (left side)
  d <- L / pmax((1 - a) * D, 1e-6)       # decay rate (right side)
  c(r = r, d = d)
}

# Exponential (symmetric) per-person incidence curve
inc_curve_exp_sym <- function(week, t_peak, i_peak, D, eps = 0.05) {
  rd  <- rd_from_D_sym(D, eps)
  inc <- ifelse(week <= t_peak,
                i_peak * exp(rd["r"] * (week - t_peak)),
                i_peak * exp(-rd["d"] * (week - t_peak)))
  pmin(pmax(inc, 0), 0.99)
}

# Exponential (asymmetric) per-person incidence curve
inc_curve_exp_asym <- function(week, t_peak, i_peak, D, a = 0.5, eps = 0.05) {
  rd  <- rd_from_D_asym(D, a, eps)
  inc <- ifelse(week <= t_peak,
                i_peak * exp(rd["r"] * (week - t_peak)),
                i_peak * exp(-rd["d"] * (week - t_peak)))
  pmin(pmax(inc, 0), 0.99)
}

# Triangular per-person incidence curve (unchanged)
inc_curve_tri <- function(week, t_peak, i_peak, D) {
  up   <- i_peak * (week / pmax(t_peak, 1))
  down <- i_peak * (1 - (week - t_peak) / pmax(D - t_peak, 1e-6))
  inc  <- ifelse(week <= t_peak, up,
                 ifelse(week <= D, pmax(down, 0), 0))
  pmin(pmax(inc, 0), 0.99)
}

Gamma renewal curve
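
The renewal-based shape builds incidence from a discretized Gamma generation-time distribution \(w_s\) with shape \(k = (\text{mean}/\text{sd})^2\) and scale \(\theta = \text{sd}^2/\text{mean}\), a piecewise-constant growth rate \(g_t\) (growth \(r\) before the peak, decay \(-d\) after), and the growth-rate-to-reproduction-number relation for a Gamma generation time:

\[ i_t \;=\; R_t \sum_{s \ge 1} w_s\, i_{t-s}, \qquad R_t \;=\; \left(1 + \theta\, g_t\right)^{k}, \]

with the resulting curve rescaled so that it equals \(i_{\mathrm{peak}}\) at \(t_{\mathrm{peak}}\).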

Code
# Gamma GT weights (stable, normalized)
gt_weights <- function(max_lag = 40, mean = 5, sd = 2) {
  mean <- pmax(mean, 1e-6); sd <- pmax(sd, 1e-6)
  k <- (mean / sd)^2; theta <- sd^2 / mean
  lags <- 1:max_lag
  dens <- pgamma(lags + 0.5, shape = k, scale = theta) -
          pgamma(lags - 0.5, shape = k, scale = theta)
  dens[dens < 0] <- 0
  dens <- if (sum(dens) > 0) dens / sum(dens) else dgamma(lags, k, scale = theta) / sum(dgamma(lags, k, scale = theta))
  dens
}

# Map instantaneous growth rate g(t) to Rt(t) for a Gamma GT
Rt_from_g_gamma <- function(g, mean = 5, sd = 2, R_cap = 30) {
  k <- (mean / sd)^2; theta <- sd^2 / mean
  Rt <- ifelse(g >= 0, (1 + theta * g)^k, (1 - theta * (-g))^k)
  Rt[!is.finite(Rt)] <- 1
  pmax(pmin(Rt, R_cap), 1e-6)
}

# Derive r and d from duration D and asymmetry a (fraction of D before the peak)
rates_from_D <- function(D, a = 0.5, eps = 0.05) {
  L <- log(1 / pmin(pmax(eps, 1e-6), 0.99))
  r <- L / pmax(a * D, 1e-6)          # pre-peak growth rate
  d <- L / pmax((1 - a) * D, 1e-6)    # post-peak decay rate
  c(r = r, d = d)
}

# Renewal incidence that *forces* a tent peaked at t_peak by using piecewise constant g(t)
inc_curve_renewal_gamma <- function(week, t_peak, i_peak, D,
                                    a = 0.5, eps = 0.05,
                                    gt_mean = 5, gt_sd = 2,
                                    max_lag = 40, pre_buffer = 10, R_cap = 30) {
  # ensure the window includes a bit before/after the peak
  if (min(week) > t_peak - pre_buffer || max(week) < t_peak + 1) {
    # extend internally but return only original weeks
    w_all <- seq(floor(t_peak - pre_buffer), ceiling(max(max(week), t_peak + D)), by = 1)
    restrict_back <- TRUE
  } else {
    w_all <- week
    restrict_back <- FALSE
  }

  # growth/decay rates from D, a
  rd <- rates_from_D(D, a = a, eps = eps)
  r <- rd["r"]; d <- rd["d"]

  # piecewise instantaneous growth g(t)
  g <- ifelse(w_all <= t_peak, r, -d)

  # Rt(t) from g(t)
  Rt <- Rt_from_g_gamma(g, mean = gt_mean, sd = gt_sd, R_cap = R_cap)

  # gamma weights
  w <- gt_weights(max_lag = max_lag, mean = gt_mean, sd = gt_sd)

  # discrete renewal
  Tn <- length(w_all)
  i <- rep(0, Tn)
  i[1] <- 1e-8
  for (t in 2:Tn) {
    lag_max <- min(t - 1, length(w))
    lambda <- sum(i[t - (1:lag_max)] * w[1:lag_max])
    i[t] <- Rt[t] * lambda
    if (!is.finite(i[t]) || i[t] < 0) i[t] <- 0
  }

  # scale to hit peak exactly at t_peak (nearest index)
  t_idx_peak <- which.min(abs(w_all - t_peak))
  peak_val <- i[t_idx_peak]
  if (!is.finite(peak_val) || peak_val <= 0) return(rep(0, length(week)))
  s <- i_peak / peak_val
  i_scaled <- pmin(pmax(s * i, 0), 0.99)

  # optionally truncate tiny tails
  i_scaled[i_scaled < eps * i_peak] <- 0

  # return on requested grid
  if (restrict_back) {
    out <- approx(x = w_all, y = i_scaled, xout = week, method = "constant", rule = 2)$y
  } else {
    out <- i_scaled
  }
  out
}
Code
# ============================
# Simulate weekly case counts
# ============================
simulate_weekly_counts <- function(sim_meta,
                                   weeks      = NULL,   # integer vector of week indices; default 1:Tmax
                                   Tmax       = 50,
                                   shape_mode = c("mix","exp_sym","exp_asym","tri","renewal_gamma"),
                                   mix_probs  = c(exp_sym = 0.5, tri = 0.5),  # used only if shape_mode=="mix"
                                   # asymmetric exp parameters
                                   asym_a     = 0.5,    # fraction of D before the peak (0<a<1)
                                   eps        = 0.05,   # tail height fraction for exp shapes & truncation in renewal
                                   # NB2 dispersion
                                   phi_nb     = 0.35,
                                   # renewal parameters
                                   gt_mean    = 5, gt_sd = 2, max_lag = 40) {

  stopifnot(all(c("N_pop","t_peak_weeks","duration_weeks","i_peak_per_person") %in% names(sim_meta)))
  if (is.null(weeks)) weeks <- seq_len(Tmax) else weeks <- as.integer(weeks)

  shape_mode <- match.arg(shape_mode)

  # if mixture: draw shape per simulation using 'mix_probs'
  if (shape_mode == "mix") {
    mix_probs <- mix_probs[ mix_probs > 0 ]
    mix_probs <- mix_probs / sum(mix_probs)
    shapes <- sample(names(mix_probs), size = nrow(sim_meta), replace = TRUE, prob = mix_probs)
  } else {
    shapes <- rep(shape_mode, nrow(sim_meta))
  }

  out <- vector("list", nrow(sim_meta))

  for (i in seq_len(nrow(sim_meta))) {
    N   <- sim_meta$N_pop[i]
    tpk <- sim_meta$t_peak_weeks[i]
    D   <- sim_meta$duration_weeks[i]
    ipk <- sim_meta$i_peak_per_person[i]

    inc <- switch(shapes[i],
      "exp_sym" = inc_curve_exp_sym(weeks, tpk, ipk, D, eps = eps),
      "exp_asym" = inc_curve_exp_asym(weeks, tpk, ipk, D, a = asym_a, eps = eps),
      "tri" = inc_curve_tri(weeks, tpk, ipk, D),
      "renewal_gamma" = inc_curve_renewal_gamma(weeks, tpk, ipk, D,
                                                a = asym_a, eps = eps,
                                                gt_mean = gt_mean, gt_sd = gt_sd, max_lag = max_lag),
      stop("Unknown shape: ", shapes[i])
    )

    mu_week <- pmax(0, N * inc)
    y_week  <- rnbinom_mu_phi(length(weeks), mu = mu_week, phi = phi_nb)

    out[[i]] <- tibble(
      sim_id = i, week = weeks, shape = shapes[i],
      N_pop = N, t_peak = tpk, D = D, i_peak = ipk,
      mu = mu_week, cases = y_week
    )
  }

  bind_rows(out)
}
Code
# Using your forward-simulated meta-parameters 'sim_meta'
# 1) Asymmetric exponential (e.g., faster rise, slower decay: a = 0.3 ⇒ longer right tail)
sim_asym  <- simulate_weekly_counts(sim_meta, Tmax = 50, 
                                    shape_mode = "exp_asym", asym_a = 0.3)

# 2) Renewal with Gamma GT (mean 5d, sd 2d); use weeks as, say, 1..60
sim_ren   <- simulate_weekly_counts(sim_meta, weeks = 1:60, 
                                    shape_mode = "renewal_gamma",
                                    gt_mean = 5, gt_sd = 2, eps = 0.02)

# 3) Mixture across four shapes
sim_mix4  <- simulate_weekly_counts(sim_meta, Tmax = 60, shape_mode = "mix",
                                    mix_probs = c(exp_sym = 0.3, 
                                                  exp_asym = 0.3, tri = 0.2, 
                                                  renewal_gamma = 0.2))

# Summarize/plot (as before)
summ <- sim_ren %>%
  group_by(week) %>%
  summarise(cases_mean = mean(cases),
            cases_q05  = quantile(cases, 0.05),
            cases_q95  = quantile(cases, 0.95), .groups = "drop")

ggplot(summ, aes(week, cases_mean)) +
  geom_ribbon(aes(ymin = cases_q05, ymax = cases_q95), alpha = 0.2) +
  geom_line() +
  labs(title = "Gamma–renewal simulated weekly cases", x = "week", y = "cases") +
  theme_minimal()

8 References