generate_randomizations_mc {fastrerandomize}    R Documentation

Draws a random sample of acceptable randomizations from all possible complete randomizations using Monte Carlo sampling

Description

This function performs sampling with replacement to generate randomizations in a memory-efficient way. It processes randomizations in batches to avoid memory issues and filters them based on covariate balance. The function uses JAX for fast computation and memory management.
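To see why sampling is used instead of enumeration, the number of possible complete randomizations can be computed directly in base R. This is only an illustrative calculation for the sizes used in the Examples; the variable n_possible is not part of the package.

```r
# Number of ways to assign 50 of 100 units to treatment
n_units <- 100
n_treated <- 50
n_possible <- choose(n_units, n_treated)

# On the order of 1e29: far too many assignments to enumerate or store
print(n_possible)
```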

Usage

generate_randomizations_mc(
  n_units,
  n_treated,
  X,
  randomization_accept_prob = 1,
  threshold_func = NULL,
  max_draws = 1e+05,
  batch_size = 1000,
  approximate_inv = TRUE,
  verbose = TRUE,
  conda_env = "fastrerandomize",
  conda_env_required = TRUE
)

Arguments

n_units

An integer specifying the total number of experimental units.

n_treated

An integer specifying the number of units to be assigned to treatment.

X

A numeric matrix of covariates used for balance checking. Cannot be NULL.

randomization_accept_prob

A numeric value between 0 and 1 specifying the probability threshold for accepting randomizations based on balance. Default is 1.

threshold_func

A JAX function that computes a balance measure for each randomization. It must be vectorized using jax$vmap with in_axes = list(NULL, 0L, NULL, NULL) and must take four inputs: covariates (the matrix X), treatment_assignment (a vector of 0s and 1s), n0 (scalar), and n1 (scalar). Default is VectorizedFastHotel2T2, which uses Hotelling's T-squared statistic.
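For intuition about what a balance measure computes for a single randomization, here is a minimal base-R sketch of a Hotelling-type T-squared statistic. It takes the same four inputs described above, but hotelling_t2 is a hypothetical helper, not the package's VectorizedFastHotel2T2, and using cov(covariates) as the scaling matrix is an assumption made for illustration.

```r
# Hypothetical base-R helper (not part of fastrerandomize)
hotelling_t2 <- function(covariates, treatment_assignment, n0, n1) {
  # Difference in covariate means between treated (1) and control (0) units
  diff_means <-
    colMeans(covariates[treatment_assignment == 1, , drop = FALSE]) -
    colMeans(covariates[treatment_assignment == 0, , drop = FALSE])
  # Assumption: scale by the inverse sample covariance of the covariates
  S_inv <- solve(cov(covariates))
  scale <- (n0 * n1) / (n0 + n1)
  as.numeric(scale * t(diff_means) %*% S_inv %*% diff_means)
}

set.seed(1)
X_demo <- matrix(rnorm(20 * 3), 20, 3)      # 20 units, 3 covariates
w_demo <- sample(rep(c(0, 1), each = 10))   # 10 control, 10 treated
t2_demo <- hotelling_t2(X_demo, w_demo, n0 = 10, n1 = 10)
```

Smaller values indicate better covariate balance between the two groups.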

max_draws

An integer specifying the maximum number of randomizations to draw.

batch_size

An integer specifying how many randomizations to process at once. Lower values use less memory but may be slower.

approximate_inv

A logical value indicating whether to use an approximate inverse (diagonal of the covariance matrix) instead of the full matrix inverse when computing balance metrics. This can speed up computations for high-dimensional covariates. Default is TRUE.
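The difference between the two options can be sketched in base R. This assumes the approximation means inverting only the diagonal of the covariance matrix (that is, ignoring covariances between covariates); the package's internal computation may differ in detail, and all names below are illustrative.

```r
set.seed(2)
X_demo <- matrix(rnorm(50 * 4), 50, 4)   # 50 units, 4 covariates
S <- cov(X_demo)

S_inv_full <- solve(S)            # full matrix inverse
S_inv_diag <- diag(1 / diag(S))   # diagonal approximation: variances only

# Difference in means for an arbitrary split into two groups of 25
d <- colMeans(X_demo[1:25, ]) - colMeans(X_demo[26:50, ])

m_full <- as.numeric(t(d) %*% S_inv_full %*% d)   # uses all covariances
m_diag <- sum(d^2 / diag(S))                      # no matrix inversion needed
```

The diagonal version reduces to a sum of squared standardized mean differences, which avoids the matrix inversion entirely.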

verbose

A logical value indicating whether to print detailed information about batch processing progress and GPU memory usage. Default is TRUE.

conda_env

A character string specifying the name of the conda environment to use via reticulate. Default is "fastrerandomize".

conda_env_required

A logical indicating whether the specified conda environment must be strictly used. If TRUE, an error is thrown if the environment is not found. Default is TRUE.

Details

The function works by:

  1. Generating batches of random permutations.

  2. Computing balance measures for each permutation using the provided threshold function.

  3. Keeping only the best-balanced permutations, as determined by randomization_accept_prob.

  4. Managing memory by clearing unused objects and caches between batches.

The function uses smaller data types (int8, float16) where possible to reduce memory usage. It also includes assertions to verify array shapes and dimensions throughout.
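The four steps above can be sketched in plain R. This is a simplified illustration under stated assumptions, not the package's JAX implementation: mc_rerandomize_sketch is a hypothetical function, the balance measure is a simple Mahalanobis-type distance, and the number of retained draws is assumed to be accept_prob * max_draws.

```r
# Hypothetical base-R sketch; names and logic are illustrative only
mc_rerandomize_sketch <- function(n_units, n_treated, X, accept_prob,
                                  max_draws, batch_size) {
  n_keep <- max(1, floor(accept_prob * max_draws))  # draws retained overall
  S_inv <- solve(cov(X))
  balance <- function(w) {
    d <- colMeans(X[w == 1L, , drop = FALSE]) -
      colMeans(X[w == 0L, , drop = FALSE])
    as.numeric(t(d) %*% S_inv %*% d)
  }
  base <- c(rep(1L, n_treated), rep(0L, n_units - n_treated))
  kept_w <- matrix(integer(0), nrow = 0, ncol = n_units)
  kept_m <- numeric(0)
  for (start in seq(1, max_draws, by = batch_size)) {
    n_batch <- min(batch_size, max_draws - start + 1)
    # Step 1: a batch of random permutations of the assignment vector
    batch <- t(replicate(n_batch, sample(base)))
    # Step 2: balance measure for each permutation in the batch
    m <- apply(batch, 1, balance)
    # Step 3: merge with earlier survivors and keep only the best-balanced
    kept_w <- rbind(kept_w, batch)
    kept_m <- c(kept_m, m)
    ord <- head(order(kept_m), n_keep)
    kept_w <- kept_w[ord, , drop = FALSE]
    kept_m <- kept_m[ord]
    # Step 4: drop the batch so memory use stays bounded
    rm(batch, m)
  }
  list(candidate_randomizations = kept_w,
       M_candidate_randomizations = kept_m)
}

set.seed(3)
X_demo <- matrix(rnorm(20 * 2), 20, 2)
res <- mc_rerandomize_sketch(20, 10, X_demo, accept_prob = 0.1,
                             max_draws = 200, batch_size = 50)
```

Because only the current batch and the running set of survivors are held at any time, peak memory depends on batch_size and the number of retained draws rather than on max_draws.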

Value

The function returns a list with two elements:

  candidate_randomizations: an array of randomization vectors.

  M_candidate_randomizations: an array of their balance measures.

See Also

generate_randomizations for the full randomization generation function; generate_randomizations_exact for the exact version.

Examples


## Not run: 
library(fastrerandomize)

# Generate synthetic data
X <- matrix(rnorm(100*5), 100, 5) # 100 units, 5 covariates

# Keep roughly the best-balanced 1% of 100,000 draws (about 1,000 randomizations)
# for 100 units with 50 treated
rand_less_strict <- generate_randomizations_mc(
               n_units = 100, 
               n_treated = 50, 
               X = X, 
               randomization_accept_prob=0.01, 
               max_draws = 100000,
               batch_size = 1000)

# Use a stricter balance criterion (accept 0.1% of up to 1,000,000 draws)
rand_more_strict <- generate_randomizations_mc(
               n_units = 100, 
               n_treated = 50, 
               X = X, 
               randomization_accept_prob=0.001, 
               max_draws = 1000000,
               batch_size = 1000)

## End(Not run)


[Package fastrerandomize version 0.2 Index]