load_vae {vmsae}    R Documentation
Load Pretrained VAE Decoder
Description
Load a pretrained Variational Autoencoder (VAE) decoder from disk. This function reads the saved PyTorch model weights and the corresponding GEOID list, and constructs a Decoder S4 object with the loaded parameters.
Usage
load_vae(model_name, save_dir = NULL)
Arguments
model_name: Character. The name of the trained VAE model (without

save_dir: Character. The directory where the trained VAE model is saved. Defaults to the current directory if NULL.
Details
This function assumes that the model was trained and saved using train_vae(), and that the decoder weights are stored in a file compatible with torch::load() (via reticulate). It extracts the decoder input/output weights and biases, along with the region GEOIDs, and returns them as an S4 object of class Decoder.
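To illustrate the kind of object described above, here is a minimal, self-contained sketch of an S4 class holding one slot per extracted component, plus a function that maps a latent vector to one output per region. Everything in it is a hypothetical stand-in: the class name DecoderSketch, the helper decode_sketch(), the slot names (W_in, b_in, W_out, b_out, geoid), the matrix orientations, and the ReLU hidden activation are assumptions for illustration, not vmsae's actual Decoder implementation.

```r
library(methods)

# Hypothetical stand-in for vmsae's Decoder class. Slot names, matrix
# orientations, and the ReLU activation are illustrative assumptions.
setClass("DecoderSketch", slots = c(
  W_in  = "matrix",    # hidden x latent weights of the input layer
  b_in  = "numeric",   # input-layer bias, one entry per hidden unit
  W_out = "matrix",    # regions x hidden weights of the output layer
  b_out = "numeric",   # output-layer bias, one entry per region
  geoid = "character"  # region identifiers, aligned with rows of W_out
))

# Decode one latent vector z into one value per region, labelled by GEOID.
decode_sketch <- function(dec, z) {
  h <- pmax(dec@W_in %*% z + dec@b_in, 0)        # hidden layer with ReLU (assumed)
  out <- as.vector(dec@W_out %*% h + dec@b_out)  # linear output layer
  names(out) <- dec@geoid
  out
}

# Toy parameters: 2 latent dimensions, 3 hidden units, 2 regions.
dec <- new("DecoderSketch",
  W_in  = matrix(c(1,  0,
                   0,  1,
                   1, -1), nrow = 3, byrow = TRUE),
  b_in  = c(0, 0, 0),
  W_out = matrix(c(1, 1, 0,
                   0, 0, 1), nrow = 2, byrow = TRUE),
  b_out = c(0.5, -0.5),
  geoid = c("29001", "29003")  # illustrative GEOIDs only
)

decode_sketch(dec, c(2, 1))
# returns c("29001" = 3.5, "29003" = 0.5)
```

Naming each decoded value by its GEOID keeps the outputs tied to their regions, which is presumably why the GEOID list is stored alongside the weights.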
Value
An object of class Decoder, containing the decoder weights and region identifiers.
Examples
## Not run:
library(vmsae)
# install_environment() is time-consuming on the first run
install_environment()
load_environment()
decoder <- load_vae(model_name = "mo_county")
## End(Not run)