pmmh {bayesSSM}    R Documentation
Particle Marginal Metropolis-Hastings (PMMH) for State-Space Models
Description
This function implements a Particle Marginal Metropolis-Hastings (PMMH) algorithm to perform Bayesian inference in state-space models. It first runs a pilot chain to tune the proposal distribution and the number of particles for the particle filter, and then runs the main PMMH chains.
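The pilot stage can be pictured as turning pilot draws into a proposal covariance. The sketch below is a minimal illustration under stated assumptions: the helper tune_proposal_cov() is hypothetical (not part of bayesSSM), the pilot draws are simulated rather than produced by a real pilot run, and it uses the classic 2.38^2 / d scaling from random-walk Metropolis theory, which may differ from the rule bayesSSM actually applies. Tuning the number of particles is not shown.

```r
# Hypothetical helper, not part of bayesSSM: turn pilot draws (on the
# transformed scale) into a random-walk proposal covariance using the
# classic 2.38^2 / d scaling for Gaussian random-walk Metropolis.
tune_proposal_cov <- function(pilot_draws, jitter = 1e-8) {
  pilot_draws <- as.matrix(pilot_draws)
  d <- ncol(pilot_draws)
  # Empirical covariance of the pilot draws, plus a small diagonal jitter
  # so the result stays positive definite
  (2.38^2 / d) * stats::cov(pilot_draws) + diag(jitter, d)
}

# Usage with simulated pilot draws of (phi, log sigma_x, log sigma_y)
set.seed(1)
pilot <- cbind(
  phi = rnorm(500, mean = 0.8, sd = 0.10),
  log_sigma_x = rnorm(500, mean = 0, sd = 0.20),
  log_sigma_y = rnorm(500, mean = log(0.5), sd = 0.15)
)
prop_cov <- tune_proposal_cov(pilot)
# One random-walk proposal around the current transformed parameters
current <- c(0.8, 0, log(0.5))
proposal <- current + drop(rnorm(3) %*% chol(prop_cov))
```

The Cholesky factor is used so that the proposal increment has covariance prop_cov; any positive definite factorization would do.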
Usage
pmmh(
  y,
  m,
  init_fn,
  transition_fn,
  log_likelihood_fn,
  log_priors,
  pilot_init_params,
  burn_in,
  num_chains = 4,
  obs_times = NULL,
  algorithm = c("SISAR", "SISR", "SIS"),
  resample_fn = c("stratified", "systematic", "multinomial"),
  param_transform = NULL,
  tune_control = default_tune_control(),
  verbose = FALSE,
  return_latent_state_est = FALSE,
  seed = NULL,
  num_cores = 1
)
Arguments
y |
A numeric vector or matrix of observations. Each row (or element, for a vector) represents the observation at one time step. |
m |
An integer specifying the total number of MCMC iterations. |
init_fn |
A function that initializes the particle states. It should take 'num_particles' as an argument for initializing the particles and return a vector or matrix of initial particle states. It can include any model-specific parameters as named arguments. |
transition_fn |
A function describing the state transition model. It should take 'particles' as an argument and return the propagated particles. The function can optionally depend on time by including a time step argument 't'. It can include any model-specific parameters as named arguments. |
log_likelihood_fn |
A function that computes the log-likelihoods for the particles. It should take a 'y' argument for the observation and a 'particles' argument for the current particles, and return a numeric vector of log-likelihood values, one per particle. The function can optionally depend on time by including a time step argument 't'. It can include any model-specific parameters as named arguments. |
log_priors |
A named list of functions, one per parameter, each computing the log-prior density of that parameter. |
pilot_init_params |
A list of initial parameter values, one element per chain. Should be a list of length 'num_chains'. |
burn_in |
An integer indicating the number of initial MCMC iterations to discard as burn-in. |
num_chains |
An integer specifying the number of PMMH chains to run. |
obs_times |
A numeric vector indicating the time points at which observations in 'y' are available, with one entry per observation. |
algorithm |
A character string specifying the particle filtering algorithm to use. Must be one of "SISAR", "SISR", or "SIS"; the first, "SISAR", is the default. |
resample_fn |
A character string specifying the resampling method. Must be one of "stratified", "systematic", or "multinomial"; the first, "stratified", is the default. |
param_transform |
An optional character vector (or named list, as in the examples) specifying the transformation applied to each parameter before proposing. The proposal is made using a multivariate normal distribution on the transformed scale, and parameters are then mapped back to their original scale before evaluation. Currently supports transformations such as "log" and "identity" (see the examples). |
tune_control |
A list of pilot tuning controls (e.g., 'pilot_m' and 'pilot_burn_in'), typically created with default_tune_control(). |
verbose |
A logical value indicating whether to print information about the pilot-run tuning. Defaults to FALSE. |
return_latent_state_est |
A logical value indicating whether to return the latent state estimates for each time step. Defaults to FALSE. |
seed |
An optional integer to set the seed for reproducibility. |
num_cores |
An integer specifying the number of cores to use for parallel processing. Defaults to 1. Each chain is assigned to its own core, so the number of cores cannot exceed the number of chains ('num_cores' <= 'num_chains'). |
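The three schemes accepted by 'resample_fn' can be sketched in base R as follows. These are textbook definitions written for illustration; the function names are hypothetical and this is not bayesSSM's internal code. Each function takes normalized weights and returns ancestor indices.

```r
# Textbook versions of the three schemes named by 'resample_fn'; these are
# illustrative and not bayesSSM's internal code. Each takes normalized
# weights w and returns N ancestor indices in 1..N.
resample_multinomial <- function(w) {
  N <- length(w)
  sample.int(N, size = N, replace = TRUE, prob = w)
}

resample_stratified <- function(w) {
  N <- length(w)
  # One independent uniform inside each stratum [(i - 1) / N, i / N)
  u <- (seq_len(N) - 1 + runif(N)) / N
  # pmin guards against cumsum(w) ending a rounding error below 1
  pmin(findInterval(u, cumsum(w)) + 1L, N)
}

resample_systematic <- function(w) {
  N <- length(w)
  # A single uniform offset shared by all N strata
  u <- (seq_len(N) - 1 + runif(1)) / N
  pmin(findInterval(u, cumsum(w)) + 1L, N)
}

# Usage: particles with zero weight are never selected
set.seed(42)
w <- c(0, 0.5, 0, 0.5)
resample_multinomial(w)
resample_stratified(w)
resample_systematic(w)
```

Stratified and systematic resampling differ only in whether each stratum gets its own uniform draw or all strata share one offset; both typically add less Monte Carlo noise than multinomial resampling.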
Details
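The PMMH algorithm (Andrieu et al., 2010) is a Metropolis-Hastings algorithm in which the exact likelihood, generally intractable for nonlinear or non-Gaussian state-space models, is replaced by an unbiased estimate computed with a particle filter (see also particle_filter). Because this estimate is unbiased, the chain still targets the exact posterior distribution of the parameters. Values are proposed using a multivariate normal distribution on the transformed scale (see 'param_transform'), and the proposal covariance is estimated from the pilot chain.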
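The mechanism can be illustrated with a minimal base-R sketch. It is not bayesSSM's implementation: the helpers call_with(), pf_loglik(), and pmmh_sketch() are hypothetical, and the sketch assumes a univariate state, a bootstrap filter that propagates the initial particles once before the first observation and resamples multinomially at every step, and a Gaussian random walk on the transformed scale with the log-Jacobian of the "log" transform added to the target. Pilot tuning, multiple chains, and 'obs_times' are omitted.

```r
# Minimal PMMH sketch in base R; NOT bayesSSM's implementation. Assumes a
# univariate state, a bootstrap filter with multinomial resampling at every
# step, and a Gaussian random walk on the transformed parameter scale.
call_with <- function(fn, args, theta) {
  # Pass only the parameters that appear among fn's formal arguments
  keep <- intersect(names(theta), names(formals(fn)))
  do.call(fn, c(args, as.list(theta[keep])))
}

pf_loglik <- function(y, theta, n_particles, init_fn, transition_fn,
                      log_likelihood_fn) {
  particles <- init_fn(n_particles)
  loglik <- 0
  for (t in seq_along(y)) {
    particles <- call_with(transition_fn, list(particles = particles), theta)
    logw <- call_with(log_likelihood_fn,
                      list(y = y[t], particles = particles), theta)
    max_logw <- max(logw)
    if (!is.finite(max_logw)) return(-Inf)
    w <- exp(logw - max_logw)
    # Add log of the mean unnormalized weight (log-sum-exp for stability)
    loglik <- loglik + max_logw + log(mean(w))
    particles <- particles[sample.int(n_particles, n_particles,
                                      replace = TRUE, prob = w)]
  }
  loglik
}

pmmh_sketch <- function(y, theta0, is_log, log_priors, prop_cov, n_iter,
                        n_particles, init_fn, transition_fn,
                        log_likelihood_fn) {
  to_theta <- function(z) { z[is_log] <- exp(z[is_log]); z }
  log_target <- function(z) {
    theta <- to_theta(z)
    lp <- sum(mapply(function(f, v) f(v), log_priors[names(theta)], theta))
    if (!is.finite(lp)) return(-Inf)
    # Jacobian of theta = exp(z) adds sum(z) over log-scale components
    lp + sum(z[is_log]) + pf_loglik(y, theta, n_particles, init_fn,
                                    transition_fn, log_likelihood_fn)
  }
  z <- theta0
  z[is_log] <- log(theta0[is_log])
  cur <- log_target(z)
  chol_prop <- chol(prop_cov)
  draws <- matrix(NA_real_, n_iter, length(z),
                  dimnames = list(NULL, names(theta0)))
  for (i in seq_len(n_iter)) {
    z_prop <- z + drop(rnorm(length(z)) %*% chol_prop)
    prop <- log_target(z_prop)
    # Accept/reject; the current noisy estimate is reused, never refreshed
    if (log(runif(1)) < prop - cur) {
      z <- z_prop
      cur <- prop
    }
    draws[i, ] <- to_theta(z)
  }
  draws
}

# Usage on data simulated from the model in the Examples section
set.seed(1)
y <- numeric(10)
x <- rnorm(1)
for (t in 1:10) {
  x <- 0.8 * x + sin(x) + rnorm(1, sd = 1)
  y[t] <- cos(x) + rnorm(1, sd = 0.5)
}
draws <- pmmh_sketch(
  y, theta0 = c(phi = 0.8, sigma_x = 1, sigma_y = 0.5),
  is_log = c(phi = FALSE, sigma_x = TRUE, sigma_y = TRUE),
  log_priors = list(
    phi = function(v) dnorm(v, 0, 1, log = TRUE),
    sigma_x = function(v) dexp(v, 1, log = TRUE),
    sigma_y = function(v) dexp(v, 1, log = TRUE)
  ),
  prop_cov = diag(0.05, 3), n_iter = 200, n_particles = 100,
  init_fn = function(n) rnorm(n),
  transition_fn = function(particles, phi, sigma_x) {
    phi * particles + sin(particles) + rnorm(length(particles), sd = sigma_x)
  },
  log_likelihood_fn = function(y, particles, sigma_y) {
    dnorm(y, mean = cos(particles), sd = sigma_y, log = TRUE)
  }
)
colMeans(draws)
```

Reusing the stored likelihood estimate for the current state, rather than re-running the filter, is what keeps the chain targeting the exact posterior; refreshing it would give a different (approximate) algorithm.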
Value
A list containing:
theta_chain
A data frame of post-burn-in parameter samples.
latent_state_chain
If 'return_latent_state_est' is TRUE, a list of matrices containing the latent state estimates for each time step.
diagnostics
Diagnostics containing ESS and Rhat for each parameter (see ess and rhat for documentation).
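To show what the Rhat entry in 'diagnostics' measures, here is a basic split-Rhat in base R, following the formula in Gelman et al., Bayesian Data Analysis (3rd ed.). The function split_rhat() is hypothetical; bayesSSM's rhat may compute a different variant (for example, a rank-normalized one), so see its documentation for the exact definition used.

```r
# Basic split-Rhat (hypothetical helper, not bayesSSM's rhat()).
# chains: numeric matrix with one column per chain, rows = iterations.
split_rhat <- function(chains) {
  n <- nrow(chains) %/% 2
  # Split each chain into two halves to also detect within-chain drift
  halves <- cbind(chains[seq_len(n), , drop = FALSE],
                  chains[n + seq_len(n), , drop = FALSE])
  W <- mean(apply(halves, 2, var))   # mean within-chain variance
  B <- n * var(colMeans(halves))     # between-chain variance
  var_plus <- (n - 1) / n * W + B / n
  sqrt(var_plus / W)
}

# Usage: four well-mixed chains versus one chain stuck at a shifted mean
set.seed(7)
mixed <- matrix(rnorm(4000), ncol = 4)
stuck <- mixed + rep(c(0, 0, 0, 3), each = 1000)
split_rhat(mixed)  # close to 1
split_rhat(stuck)  # well above 1
```

Values noticeably above 1 indicate that the chains have not yet mixed, which is why a convergence warning is expected when 'm' is small.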
References
Andrieu et al. (2010). Particle Markov chain Monte Carlo methods. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 72(3):269–342. doi: 10.1111/j.1467-9868.2009.00736.x
Examples
library(bayesSSM)
# Initial state distribution: x_0 ~ N(0, 1)
init_fn <- function(num_particles) {
  rnorm(num_particles, mean = 0, sd = 1)
}
# Nonlinear state transition with Gaussian process noise
transition_fn <- function(particles, phi, sigma_x) {
  phi * particles + sin(particles) +
    rnorm(length(particles), mean = 0, sd = sigma_x)
}
# Observation density: y_t ~ N(cos(x_t), sigma_y^2)
log_likelihood_fn <- function(y, particles, sigma_y) {
  dnorm(y, mean = cos(particles), sd = sigma_y, log = TRUE)
}
# Priors: N(0, 1) on phi and Exp(1) on both standard deviations
log_prior_phi <- function(phi) {
  dnorm(phi, mean = 0, sd = 1, log = TRUE)
}
log_prior_sigma_x <- function(sigma) {
  dexp(sigma, rate = 1, log = TRUE)
}
log_prior_sigma_y <- function(sigma) {
  dexp(sigma, rate = 1, log = TRUE)
}
log_priors <- list(
  phi = log_prior_phi,
  sigma_x = log_prior_sigma_x,
  sigma_y = log_prior_sigma_y
)
# Generate data from the model
t_val <- 10
x <- numeric(t_val)
y <- numeric(t_val)
phi <- 0.8
sigma_x <- 1
sigma_y <- 0.5
init_state <- rnorm(1, mean = 0, sd = 1)
x[1] <- phi * init_state + sin(init_state) + rnorm(1, mean = 0, sd = sigma_x)
# Observations use the same cos() link as log_likelihood_fn
y[1] <- cos(x[1]) + rnorm(1, mean = 0, sd = sigma_y)
for (t in 2:t_val) {
  x[t] <- phi * x[t - 1] + sin(x[t - 1]) + rnorm(1, mean = 0, sd = sigma_x)
  y[t] <- cos(x[t]) + rnorm(1, mean = 0, sd = sigma_y)
}
# Prepend the initial state to keep the full true latent path
x <- c(init_state, x)
# In practice, use many more MCMC iterations ('m')
pmmh_result <- pmmh(
  y = y,
  m = 1000,
  init_fn = init_fn,
  transition_fn = transition_fn,
  log_likelihood_fn = log_likelihood_fn,
  log_priors = log_priors,
  pilot_init_params = list(
    c(phi = 0.8, sigma_x = 1, sigma_y = 0.5),
    c(phi = 1, sigma_x = 0.5, sigma_y = 1)
  ),
  burn_in = 100,
  num_chains = 2,
  param_transform = list(
    phi = "identity",
    sigma_x = "log",
    sigma_y = "log"
  ),
  tune_control = default_tune_control(pilot_m = 500, pilot_burn_in = 100)
)
# A convergence warning is expected with so few MCMC iterations.
# Inspect the posterior draws and the convergence diagnostics
head(pmmh_result$theta_chain)
pmmh_result$diagnostics
# Suppose observations are available at t = 1, 2, 3, 5, ..., 10 (missing at t = 4)
obs_times <- c(1, 2, 3, 5, 6, 7, 8, 9, 10)
y <- y[obs_times]
# Pass the observation times to pmmh() via 'obs_times'
pmmh_result <- pmmh(
  y = y,
  m = 1000,
  init_fn = init_fn,
  transition_fn = transition_fn,
  log_likelihood_fn = log_likelihood_fn,
  log_priors = log_priors,
  pilot_init_params = list(
    c(phi = 0.8, sigma_x = 1, sigma_y = 0.5),
    c(phi = 1, sigma_x = 0.5, sigma_y = 1)
  ),
  burn_in = 100,
  num_chains = 2,
  obs_times = obs_times,
  param_transform = list(
    phi = "identity",
    sigma_x = "log",
    sigma_y = "log"
  ),
  tune_control = default_tune_control(pilot_m = 500, pilot_burn_in = 100)
)
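The 'transition_fn' and 'log_likelihood_fn' arguments may also take a time step argument 't'. The sketch below shows hypothetical time-dependent variants of the example model functions; it assumes the filter supplies the current time index as 't', and the seasonal term and growing observation noise are purely illustrative. Such functions could presumably be passed to pmmh() in place of the ones above, but how 't' is populated (for instance, together with 'obs_times') should be checked against particle_filter.

```r
# Hypothetical time-dependent variants of the example model functions.
# Assumes the filter passes the current time index as 't'; the seasonal
# term and growing noise are purely illustrative.
transition_fn_t <- function(particles, phi, sigma_x, t) {
  # Deterministic seasonal forcing with an assumed period of 12 steps
  phi * particles + sin(particles) + 0.5 * cos(2 * pi * t / 12) +
    rnorm(length(particles), mean = 0, sd = sigma_x)
}
log_likelihood_fn_t <- function(y, particles, sigma_y, t) {
  # Observation noise that grows linearly with the time index
  dnorm(y, mean = cos(particles), sd = sigma_y * (1 + 0.1 * t), log = TRUE)
}

# Quick check outside pmmh(): one propagation and weighting step at t = 3
set.seed(3)
p <- rnorm(5)
p_next <- transition_fn_t(p, phi = 0.8, sigma_x = 1, t = 3)
logw <- log_likelihood_fn_t(0.2, p_next, sigma_y = 0.5, t = 3)
```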