predict.DNAmf {DNAmf}    R Documentation

Predictive posterior mean and variance for a DNAmf object with a nonseparable kernel.

Description

The function computes the predictive posterior mean and variance of the DNAmf model at new input locations, using closed-form expressions based on the chosen nonseparable kernel.

Usage

## S3 method for class 'DNAmf'
predict(object, x, targett = 0, nimpute = 50, ...)
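Because predict.DNAmf is an S3 method, it is normally invoked through the generic predict(), which dispatches on the class of object. The following toy sketch illustrates that mechanism with a hypothetical class "toyfit" and a hypothetical method predict.toyfit, neither of which is part of DNAmf; for a real fit, predict(fit.DNAmf, x, targett = 0) reaches predict.DNAmf in the same way.

```r
# Hypothetical toy class, NOT part of DNAmf: illustrates S3 dispatch only.
toy <- structure(list(coef = 2), class = "toyfit")

# A function named predict.<class> acts as an S3 method for the generic predict().
predict.toyfit <- function(object, x, targett = 0, ...) {
  object$coef * x + targett
}

# predict() inspects class(toy) and dispatches to predict.toyfit().
predict(toy, x = c(1, 2, 3))
```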

Arguments

object

A fitted DNAmf object.

x

A vector or matrix of new input locations at which to predict.

targett

A numeric value of the target tuning parameter at which to predict. Default is 0.

nimpute

Number of imputations for non-nested designs. Default is 50.

...

Additional arguments for compatibility with the generic method predict.

Details

The predict.DNAmf function internally calls closed_form, which in turn calls h1_sqex, h2_sqex, and h2_sqex_single for kernel="sqex", or h1_matern, h2_matern, and h2_matern_single for kernel="matern1.5" or kernel="matern2.5", to recursively compute the closed-form posterior mean and variance at each level.
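For reference, the three kernel choices correspond to the squared-exponential family and the Matérn family with smoothness 1.5 and 2.5. The sketch below writes these correlations in common textbook form as functions of a distance h and a lengthscale theta. This parameterization is an assumption, and the helpers corr_sqex, corr_matern15, and corr_matern25 are hypothetical; they are not the internal h1_*/h2_* functions, which may scale theta differently.

```r
# Common textbook forms (assumed parameterization; DNAmf may scale theta differently).
# Squared-exponential correlation.
corr_sqex <- function(h, theta) exp(-(h / theta)^2)

# Matern correlation with smoothness 1.5.
corr_matern15 <- function(h, theta) {
  r <- sqrt(3) * abs(h) / theta
  (1 + r) * exp(-r)
}

# Matern correlation with smoothness 2.5.
corr_matern25 <- function(h, theta) {
  r <- sqrt(5) * abs(h) / theta
  (1 + r + r^2 / 3) * exp(-r)
}

# Compare the three families on a grid of distances.
h <- seq(0, 1, by = 0.25)
round(cbind(h,
            sqex      = corr_sqex(h, 0.3),
            matern1.5 = corr_matern15(h, 0.3),
            matern2.5 = corr_matern25(h, 0.3)), 3)
```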

Given a model fitted by DNAmf, the posterior mean and variance are computed from closed-form expressions derived recursively across levels. The exact formulas depend on the kernel choice.
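As an illustration of what a closed-form posterior looks like at a single level, the sketch below implements standard Gaussian process (simple kriging) prediction with a constant mean mu0, a one-dimensional squared-exponential correlation, and a small nugget: the mean is mu0 + k' K^{-1} (y - mu0) and the variance is tau2 (1 - k' K^{-1} k). The helper gp_posterior and its arguments are hypothetical; DNAmf's recursive multi-level formulas are more involved and are not reproduced here.

```r
# Single-level GP posterior (simple kriging). Hypothetical helper, not DNAmf code.
gp_posterior <- function(X, y, xnew, theta, tau2 = 1, nugget = 1e-8, mu0 = mean(y)) {
  # 1-D squared-exponential correlation matrix between input vectors A and B.
  corr <- function(A, B) exp(-(outer(A, B, "-") / theta)^2)
  K  <- corr(X, X) + diag(nugget, length(X))  # n x n training correlations
  k  <- corr(xnew, X)                         # n_new x n cross-correlations
  Ki <- solve(K)
  mu <- mu0 + drop(k %*% Ki %*% (y - mu0))    # posterior mean
  # Posterior variance: diag of tau2 * (1 - k K^{-1} k'), clipped at 0 for round-off.
  sig2 <- tau2 * pmax(1 - rowSums((k %*% Ki) * k), 0)
  list(mu = mu, sig2 = sig2)
}

# Usage: predict at a training input and at a point between inputs.
X <- seq(0, 1, length.out = 8)
y <- sin(2 * pi * X)
post <- gp_posterior(X, y, xnew = c(X[3], 0.5), theta = 0.2)
post
```

At a training input the mean nearly interpolates the data and the variance is close to zero, while between inputs the variance grows.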

If the fitted model was constructed with non-nested designs (nested=FALSE), the function generates nimpute sets of imputed pseudo outputs via imputer.
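When predictions are available under several imputations, a common way to merge them is to average the means and add the between-imputation spread of the means to the average variance (the law of total variance). The sketch below assumes the nimpute predictive means and variances are stored row-wise in two matrices; combine_imputations is a hypothetical helper, and DNAmf's internal combination rule may differ.

```r
# Hypothetical helper (assumed rule; DNAmf may combine imputations differently).
# Rows = imputations, columns = prediction inputs.
combine_imputations <- function(mu_mat, sig2_mat) {
  mu <- colMeans(mu_mat)  # average of the imputation-wise means
  # Law of total variance: E[Var] + Var[E] (1/m version of the between term).
  sig2 <- colMeans(sig2_mat) + colMeans(sweep(mu_mat, 2, mu)^2)
  list(mu = mu, sig2 = sig2)
}

# Usage with two imputations at two inputs.
res <- combine_imputations(rbind(c(1, 2), c(3, 2)),
                           rbind(c(0.5, 0.1), c(0.5, 0.1)))
res
```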

For further details, see Heo, Boutelet, and Sung (2025+, <arXiv:2506.08328>).

Value

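A list containing the predictive posterior mean and variance at each level, together with the computation time. The components used in the Examples are:

mu, sig2

The predictive posterior mean and variance at the target tuning parameter targett.

mu_1, ..., mu_5 and sig2_1, ..., sig2_5

The predictive posterior means and variances at each of the five fitted levels of the example model.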
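The per-level components can be collected programmatically rather than listed by hand. The sketch below uses a placeholder list with made-up numbers whose component names follow the Examples (mu_1, sig2_1, ..., mu, sig2); it assumes those names carry over to fits with other numbers of levels, and it builds the same 99% pointwise interval used in the Examples' plotting code.

```r
# Placeholder standing in for the output of predict(); numbers are made up.
pred <- list(mu_1 = c(0.1, 0.2), sig2_1 = c(0.01, 0.02),
             mu_2 = c(0.3, 0.4), sig2_2 = c(0.03, 0.04),
             mu   = c(0.5, 0.6), sig2   = c(0.05, 0.06))

# Number of fitted levels, inferred from the "mu_<l>" component names (assumption).
L <- sum(grepl("^mu_[0-9]+$", names(pred)))
mu_levels   <- lapply(seq_len(L), function(l) pred[[paste0("mu_", l)]])
sig2_levels <- lapply(seq_len(L), function(l) pred[[paste0("sig2_", l)]])

# 99% pointwise interval at the target, as in the Examples' plotting code.
z <- qnorm(0.995)
band <- cbind(lower = pred$mu - z * sqrt(pred$sig2),
              upper = pred$mu + z * sqrt(pred$sig2))
band
```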

See Also

DNAmf for the user-level function.

Examples

### Non-Additive example ###
library(DNAmf)  # DNAmf() and its predict() method
library(RNAmf)

### Non-Additive Function ###
fl <- function(x, t){
  term1 <- sin(10 * pi * x / (5+t))
  term2 <- 0.2 * sin(8 * pi * x)
  term1 + term2
}

### training data ###
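n1 <- 13; n2 <- 10; n3 <- 7; n4 <- 4; n5 <- 1  # number of design points at each level
m1 <- 2.5; m2 <- 2.0; m3 <- 1.5; m4 <- 1.0; m5 <- 0.5  # mesh size (tuning parameter t) at each level
d <- 1  # input dimension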
eps <- sqrt(.Machine$double.eps)
x <- seq(0,1,0.01)

### fix seed to reproduce the result ###
set.seed(1)

### generate initial nested design ###
NestDesign <- NestedX(c(n1,n2,n3,n4,n5),d)

X1 <- NestDesign[[1]]
X2 <- NestDesign[[2]]
X3 <- NestDesign[[3]]
X4 <- NestDesign[[4]]
X5 <- NestDesign[[5]]

y1 <- fl(X1, t=m1)
y2 <- fl(X2, t=m2)
y3 <- fl(X3, t=m3)
y4 <- fl(X4, t=m4)
y5 <- fl(X5, t=m5)

### fit a DNAmf ###
fit.DNAmf <- DNAmf(X=list(X1, X2, X3, X4, X5), y=list(y1, y2, y3, y4, y5), kernel="sqex",
                   t=c(m1,m2,m3,m4,m5), multi.start=10, constant=TRUE)

### predict ###
pred.DNAmf <- predict(fit.DNAmf, x, targett=0)
predydiffu <- pred.DNAmf$mu
predsig2diffu <- pred.DNAmf$sig2

### RMSE ###
print(sqrt(mean((predydiffu-fl(x, t=0))^2))) # 0.1162579

### visualize the emulation performance ###
oldpar <- par(mfrow = c(2,3))
create_plot_base <- function(i, mesh_size, x, pred_mu, pred_sig2,
                             X_points = NULL, y_points = NULL, add_points = TRUE, yylim) {
  lower <- pred_mu - qnorm(0.995) * sqrt(pred_sig2)
  upper <- pred_mu + qnorm(0.995) * sqrt(pred_sig2)

  plot(x, pred_mu, type = "n", ylim = c(-yylim, yylim), xlab = "", ylab = "",
       main = paste0("Mesh size = ", mesh_size), axes = FALSE)
  box()

  polygon(c(x, rev(x)), c(upper, rev(lower)),
          col = adjustcolor("blue", alpha.f = 0.2), border = NA)
  lines(x, pred_mu, col = "blue", lwd = 2)
  lines(x, fl(x, mesh_size), lty = 2, col = "black", lwd = 2)

  if (add_points && !is.null(X_points) && !is.null(y_points)) {
    points(X_points, y_points, col = "red", pch = 16, cex = 1.3)
  }
}

mesh_sizes <- c(m1, m2, m3, m4, m5, 0)
mu_list <- list(pred.DNAmf$mu_1, pred.DNAmf$mu_2, pred.DNAmf$mu_3,
                pred.DNAmf$mu_4, pred.DNAmf$mu_5, pred.DNAmf$mu)
sig2_list <- list(pred.DNAmf$sig2_1, pred.DNAmf$sig2_2, pred.DNAmf$sig2_3,
                  pred.DNAmf$sig2_4, pred.DNAmf$sig2_5, pred.DNAmf$sig2)
X_list <- list(X1, X2, X3, X4, X5, NULL)
y_list <- list(y1, y2, y3, y4, y5, NULL)

plots <- mapply(function(i, m, mu, sig2, X, y) {
  create_plot_base(i, m, x, mu, sig2, X, y, add_points = !is.null(X), yylim=1.5)
}, i = 1:6, m = mesh_sizes, mu = mu_list, sig2 = sig2_list,
X = X_list, y = y_list, SIMPLIFY = FALSE)
par(oldpar)


[Package DNAmf version 0.1.0 Index]