train_vae {vmsae}    R Documentation
Train VAE for CAR Prior
Description
Trains a Variational Autoencoder (VAE) to learn the spatial structure implied by the Conditional Autoregressive (CAR) prior. The trained VAE parameters are saved and can later be used as a generator within Hamiltonian Monte Carlo (HMC) sampling.
Usage
train_vae(
W,
GEOID,
model_name,
save_dir,
n_samples = 10000,
batch_size = 256,
epoch = 10000,
lr_init = 0.001,
lr_min = 1e-07,
verbose = TRUE,
use_gpu = TRUE
)
Arguments
W: Matrix. A proximity or adjacency matrix representing spatial relationships between units.
GEOID: Character vector. Identifiers for the spatial units (e.g., region or area codes).
model_name: Character. Name under which the trained VAE model is saved.
save_dir: Character. Directory in which to save the trained VAE model and associated metadata. Defaults to the current working directory.
n_samples: Integer. Number of samples to draw from the CAR prior for training. Default is 10000.
batch_size: Integer. Batch size for VAE training. Default is 256.
epoch: Integer. Number of training epochs. Default is 10000.
lr_init: Numeric. Initial learning rate. Default is 0.001.
lr_min: Numeric. Minimum learning rate, reached at the final epoch. Default is 1e-07.
verbose: Logical. If TRUE, training progress is printed. Default is TRUE.
use_gpu: Logical. If TRUE, a GPU is used when available. Default is TRUE.
Details
The function requires a Python environment configured via the reticulate interface, with the VAE training itself implemented in Python. Internally it calls py$train_vae(), which is defined in the sourced Python modules (see load_environment).
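As a quick check that reticulate can see a Python interpreter before training, the following minimal sketch uses standard reticulate functions (the vmsae environment itself is set up with install_environment() and load_environment(), as in the Examples below):

library(reticulate)
py_available(initialize = TRUE)  # TRUE if a Python interpreter can be started
py_config()                      # details of the active Python configuration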
Value
A named list containing:
loss: Total training loss.
RCL: Reconstruction error.
KLD: Kullback–Leibler divergence.
Examples
## Not run:
library(vmsae)
library(sf)
# install_environment() is time-consuming on the first run
install_environment()
load_environment()
acs_data <- read_sf(system.file("example", "mo_county.shp", package = "vmsae"))
W <- readRDS(system.file("example", "W.Rds", package = "vmsae"))
loss <- train_vae(W = W,
                  GEOID = acs_data$GEOID,
                  model_name = "test",
                  save_dir = tempdir(),
                  n_samples = 1000,  # set to a larger value in practice, e.g. 10000
                  batch_size = 256,
                  epoch = 1000)      # set to a larger value in practice, e.g. 10000
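
# Inspect the returned diagnostics. A hedged sketch: it assumes the list
# components are per-epoch numeric vectors, which this page does not guarantee.
str(loss)  # named list with components loss, RCL and KLD
plot(loss$loss, type = "l", xlab = "Epoch", ylab = "Total loss")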
## End(Not run)