ART {ARTtransfer}    R Documentation

ART: Adaptive and Robust Transfer Learning

Description

ART is a flexible transfer-learning framework that uses information from auxiliary data sources to improve model performance on a primary task. It guards against negative transfer by including the non-transfer model in the candidate pool, so performance remains stable even when the auxiliary datasets provide little or no useful information. ART supports both regression and classification tasks, and aggregates predictions from the auxiliary models and the primary model using an adaptive exponential weighting mechanism. Variable importance is also provided to indicate the contribution of each variable to the final model.

Usage

ART(
  X,
  y,
  X_aux,
  y_aux,
  X_test,
  func,
  lam = 1,
  maxit = 5000L,
  eps = 1e-06,
  type = c("regression", "classification"),
  is_coef = TRUE,
  importance = TRUE,
  ...
)

Arguments

X

A matrix for the primary dataset (target domain) predictors.

y

A vector for the primary dataset (target domain) responses.

X_aux

A list of matrices for the auxiliary datasets (source domains) predictors.

y_aux

A list of vectors for the auxiliary datasets (source domains) responses.

X_test

A matrix for the test dataset predictors.

func

A function used to fit the model on each dataset. The function must have the following signature: func(X, y, X_val, y_val, X_test, min_prod = 1e-5, max_prod = 1-1e-5, ...). The function should return a list with the following elements:

  • dev: The deviance (or loss) on the validation set, if validation data (X_val and y_val) are provided.

  • pred: The predictions on the test set if X_test is provided.

  • coef (optional): The model coefficients (only for regression models when is_coef = TRUE).

Pre-built wrapper functions, such as fit_lm, fit_logit, fit_glmnet_lm, fit_glmnet_logit, fit_random_forest, fit_gbm, and fit_nnet, can be used. Users may also provide their own model-fitting functions, but the input and output structure must follow the described signature and format.
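As an illustration of the required input and output structure, the following is a minimal sketch of a user-supplied fitting function written in base R. The name fit_ols_sketch is hypothetical and not part of the package. The NULL defaults for X_val, y_val, and X_test, the use of mean squared error as dev, and the inclusion of the intercept in coef are assumptions made for this sketch; the pre-built wrappers may behave differently.

```r
# Hypothetical custom wrapper: ordinary least squares via base R.
# Assumptions (not confirmed by the package documentation):
#   - validation/test inputs may be NULL when not supplied
#   - dev is the mean squared error on the validation set
#   - coef includes the intercept as its first element
fit_ols_sketch <- function(X, y, X_val = NULL, y_val = NULL, X_test = NULL,
                           min_prod = 1e-5, max_prod = 1 - 1e-5, ...) {
  # min_prod and max_prod are accepted only to match the required
  # signature; this regression sketch does not use them.
  beta <- stats::lm.fit(cbind(1, X), y)$coefficients

  # Validation loss, computed only when validation data are supplied.
  dev <- NULL
  if (!is.null(X_val) && !is.null(y_val)) {
    pred_val <- drop(cbind(1, X_val) %*% beta)
    dev <- mean((y_val - pred_val)^2)
  }

  # Test-set predictions, computed only when X_test is supplied.
  pred <- NULL
  if (!is.null(X_test)) {
    pred <- drop(cbind(1, X_test) %*% beta)
  }

  list(dev = dev, pred = pred, coef = beta)
}

# Small self-contained check on simulated data.
set.seed(1)
X_sim <- matrix(rnorm(200), nrow = 50, ncol = 4)
y_sim <- drop(X_sim %*% c(1, -1, 0.5, 0)) + rnorm(50, sd = 0.1)
out <- fit_ols_sketch(X_sim[1:40, ], y_sim[1:40],
                      X_val = X_sim[41:50, ], y_val = y_sim[41:50],
                      X_test = X_sim[41:50, ])
str(out)
```

A function of this shape could then be passed as func, provided its return values match what ART expects from the pre-built wrappers.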

lam

A regularization parameter for weighting the auxiliary models. Default is 1.

maxit

The maximum number of iterations of the aggregation procedure. Default is 5000.

eps

A convergence threshold for stopping the iterations. Default is 1e-6.

type

A string specifying the task type. Options are "regression" or "classification". Default is "regression".

is_coef

Logical; if TRUE, coefficients from the model are returned. Default is TRUE.

importance

Logical; if TRUE, variable importance is calculated. Only applicable if 'is_coef' is TRUE. Default is TRUE.

...

Additional arguments passed to the model-fitting function.

Details

The ART function performs adaptive and robust transfer learning by iteratively combining predictions from models fitted on the primary dataset and on the auxiliary datasets. At each iteration, an aggregation step updates the weight assigned to each dataset's predictions; the final predictions are a weighted combination of the source and target models.
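To give some intuition for weighted aggregation, the following is a minimal sketch of a generic exponential weighting scheme in base R. It is not ART's actual update rule, which this page does not specify. The function exp_weights, the loss values, and the predictions are all made up for illustration, and the use of lam as a multiplier on the loss is an assumption.

```r
# Generic exponential weighting (illustrative only, not ART's exact rule):
# a candidate with a smaller validation loss receives a larger weight.
exp_weights <- function(dev, lam = 1) {
  z <- -lam * dev
  w <- exp(z - max(z))  # subtract the maximum for numerical stability
  w / sum(w)            # normalise so the weights sum to one
}

# Made-up validation losses for the target model and two auxiliary models.
dev <- c(target = 1.0, aux1 = 0.8, aux2 = 3.0)
w <- exp_weights(dev, lam = 1)
print(w)

# Made-up test predictions: one column per candidate, one row per test point.
preds <- cbind(target = c(1.0, 2.0), aux1 = c(1.2, 2.1), aux2 = c(5.0, 6.0))

# Aggregated prediction: weighted average across candidates.
agg <- drop(preds %*% w)
print(agg)
```

In this sketch the poorly performing candidate (aux2) receives a small weight, which is the general idea behind limiting negative transfer; because the target-only model is always among the candidates, the aggregate can fall back on it.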

The auxiliary datasets ('X_aux' and 'y_aux') must be provided as lists, with each element corresponding to a dataset from a different source domain.
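The list structure described above can be sketched in base R as follows. The data are simulated noise used only to show the shapes involved, and the requirement that every auxiliary matrix has the same number of columns as X is an assumption of this sketch rather than something stated on this page.

```r
set.seed(42)
p <- 5  # number of predictors (assumed shared by all datasets)

# Primary (target) dataset.
X <- matrix(rnorm(30 * p), nrow = 30)
y <- rnorm(30)

# Two auxiliary (source) datasets, stored as parallel lists:
# X_aux[[k]] and y_aux[[k]] belong to the same source domain.
X_aux <- list(matrix(rnorm(40 * p), nrow = 40),
              matrix(rnorm(25 * p), nrow = 25))
y_aux <- list(rnorm(40), rnorm(25))

# Basic consistency checks before passing the data on.
same_length <- length(X_aux) == length(y_aux)
same_ncol   <- all(vapply(X_aux, ncol, integer(1)) == ncol(X))
rows_match  <- all(mapply(function(Xk, yk) nrow(Xk) == length(yk),
                          X_aux, y_aux))
stopifnot(same_length, same_ncol, rows_match)
```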

Value

A list containing:

pred_ART

The predictions for the test dataset.

coef_ART

The coefficients of the final model, if 'is_coef' is TRUE.

W_ART

The final weights for each dataset (including the primary dataset).

iter_ART

The number of iterations performed, either until convergence or until 'maxit' is reached.

VI_ART

The variable importance, if 'importance' is TRUE.

Examples

# Example usage: simulate data with generate_data(), then fit ART with fit_lm
library(ARTtransfer)
dat <- generate_data(n0=50, K=3, nk=50, K_noise=2, nk_noise=30, p=10,
       mu_trgt=1, xi_aux=0.5, ro=0.5, err_sig=1)
fit <- ART(dat$X, dat$y, dat$X_aux, dat$y_aux, dat$X_test,
           func=fit_lm, lam=1, type="regression")


[Package ARTtransfer version 1.0.0 Index]