predict.DNAmf {DNAmf}    R Documentation
Predictive posterior mean and variance for a DNAmf object with a nonseparable kernel.
Description
The function computes the predictive posterior mean and variance of the DNAmf model at new input locations, using closed-form expressions based on the chosen nonseparable kernel.
Usage
## S3 method for class 'DNAmf'
predict(object, x, targett = 0, nimpute = 50, ...)
Arguments
object: A fitted DNAmf object.
x: A vector or matrix of new input locations at which to predict.
targett: A numeric value of the target tuning parameter at which to predict.
nimpute: Number of imputations for non-nested designs. Default is 50.
...: Additional arguments for compatibility with the generic method.
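The argument x accepts either a vector or a matrix. The base-R sketch below prepares both forms. It assumes, as is usual for emulators of this kind, that rows are prediction points and each matrix column corresponds to one input dimension, matching the layout of the design matrices used for fitting; the object names are hypothetical.

```r
# One input dimension (d = 1): a plain vector of prediction locations...
x_vec <- seq(0, 1, by = 0.01)
# ...or the same locations as a one-column matrix
# (assumed layout: one row per prediction point).
x_mat1 <- matrix(x_vec, ncol = 1)

# Two input dimensions (d = 2): a regular 11 x 11 grid stored as a
# matrix, assuming one column per input dimension.
g <- seq(0, 1, length.out = 11)
x_mat2 <- as.matrix(expand.grid(x1 = g, x2 = g))

dim(x_mat1)  # 101 rows, 1 column
dim(x_mat2)  # 121 rows, 2 columns
```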
Details
The predict.DNAmf function internally calls closed_form, which in turn calls h1_sqex, h2_sqex, and h2_sqex_single for kernel="sqex", or h1_matern, h2_matern, and h2_matern_single for kernel="matern1.5" or kernel="matern2.5", to recursively compute the closed-form posterior mean and variance at each level.
Given a model fitted by DNAmf, the posterior mean and variance are computed from closed-form expressions derived recursively across levels. The exact formulas depend on the chosen kernel.
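For intuition about what a closed-form posterior looks like at a single level, the sketch below computes the standard zero-mean Gaussian process posterior mean and variance with a squared-exponential kernel. It is a generic illustration only, not the recursive DNAmf expressions evaluated by closed_form, h1_sqex, or h2_sqex. The kernel parameterization exp(-sum_j (a_j - b_j)^2 / theta_j), the nugget g, the plug-in variance estimate, and the helper names sqex_kernel and gp_posterior are all assumptions for this example.

```r
# Squared-exponential correlation between the rows of A and the rows of B,
# parameterized as exp(-sum_j (a_j - b_j)^2 / theta_j) (assumed form).
sqex_kernel <- function(A, B, theta) {
  A <- as.matrix(A); B <- as.matrix(B)
  D <- matrix(0, nrow(A), nrow(B))
  for (j in seq_len(ncol(A))) {
    D <- D + outer(A[, j], B[, j], "-")^2 / theta[j]
  }
  exp(-D)
}

# Zero-mean GP posterior at new inputs, with a small nugget g for
# numerical stability and a plug-in estimate of the process variance tau2.
gp_posterior <- function(X, y, xnew, theta, g = 1e-8) {
  X <- as.matrix(X); xnew <- as.matrix(xnew)
  K  <- sqex_kernel(X, X, theta) + diag(g, nrow(X))
  Ki <- solve(K)
  kx <- sqex_kernel(xnew, X, theta)             # m x n cross-correlations
  tau2 <- drop(t(y) %*% Ki %*% y) / length(y)   # plug-in process variance
  mu   <- drop(kx %*% Ki %*% y)                 # k(x)' K^{-1} y
  sig2 <- tau2 * pmax(1 + g - rowSums((kx %*% Ki) * kx), 0)
  list(mu = mu, sig2 = sig2)
}

# Toy one-dimensional usage: at the training inputs, the posterior mean
# reproduces y and the posterior variance is close to zero.
X <- matrix(seq(0, 1, length.out = 8), ncol = 1)
y <- sin(2 * pi * X[, 1])
post <- gp_posterior(X, y, xnew = X, theta = 0.01)
```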
If the fitted model was constructed with non-nested designs (nested=FALSE), the function generates nimpute sets of imputed pseudo outputs via imputer.
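When predictions are available from several imputed data sets, one common way to summarize them is to treat the imputations as an equally weighted mixture and apply the law of total variance: the pooled mean is the average of the means, and the pooled variance is the average variance plus the spread of the means. The sketch below shows that generic rule; it is an assumption for illustration and is not taken from the predict.DNAmf source. The function name pool_imputations is hypothetical.

```r
# Pool predictions from nimpute imputed data sets as an equally weighted
# mixture (law of total variance). Generic rule, assumed for illustration.
# mu_mat, sig2_mat: nimpute x m matrices, one row per imputation and
# one column per prediction location.
pool_imputations <- function(mu_mat, sig2_mat) {
  mu      <- colMeans(mu_mat)
  within  <- colMeans(sig2_mat)                # average predictive variance
  between <- colMeans(sweep(mu_mat, 2, mu)^2)  # spread of the predictive means
  list(mu = mu, sig2 = within + between)
}

# Toy input: two imputations, two prediction locations.
mu_mat   <- rbind(c(0, 1), c(2, 3))
sig2_mat <- matrix(1, nrow = 2, ncol = 2)
pooled <- pool_imputations(mu_mat, sig2_mat)
pooled$mu    # 1 2
pooled$sig2  # 2 2
```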
For further details, see Heo, Boutelet, and Sung (2025+, <arXiv:2506.08328>).
Value
A list containing the predictive posterior mean and variance at each level, together with the computation time:
- mu_1, sig2_1, ..., mu_L, sig2_L: Vectors of the predictive posterior mean and variance at each level.
- mu: A vector of the predictive posterior mean at the target tuning parameter targett.
- sig2: A vector of the predictive posterior variance at the target tuning parameter targett.
- time: Total computation time in seconds.
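Because the per-level components follow the naming pattern mu_1, ..., mu_L and sig2_1, ..., sig2_L, they can be collected programmatically for any number of levels. The sketch below uses a hypothetical stand-in list pred with that documented structure (two levels, three prediction points, made-up values) in place of real predict output, and builds the same 99% pointwise Gaussian interval used in the plotting code of the Examples.

```r
# Hypothetical stand-in with the documented structure (L = 2 levels).
pred <- list(mu_1 = c(0.1, 0.2, 0.3), sig2_1 = c(0.01, 0.01, 0.01),
             mu_2 = c(0.0, 0.1, 0.2), sig2_2 = c(0.02, 0.02, 0.02),
             mu   = c(0.0, 0.1, 0.2), sig2   = c(0.03, 0.03, 0.03),
             time = 0.5)

# Number of levels L, inferred from the names "mu_1", "mu_2", ...
L <- sum(grepl("^mu_[0-9]+$", names(pred)))
mu_levels   <- lapply(seq_len(L), function(l) pred[[paste0("mu_", l)]])
sig2_levels <- lapply(seq_len(L), function(l) pred[[paste0("sig2_", l)]])

# 99% pointwise predictive interval at the target tuning parameter.
z <- qnorm(0.995)
lower <- pred$mu - z * sqrt(pred$sig2)
upper <- pred$mu + z * sqrt(pred$sig2)
```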
See Also
DNAmf for the user-level function.
Examples
### Non-Additive example ###
library(RNAmf)
### Non-Additive Function ###
fl <- function(x, t){
  term1 <- sin(10 * pi * x / (5+t))
  term2 <- 0.2 * sin(8 * pi * x)
  term1 + term2
}
### training data ###
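n1 <- 13; n2 <- 10; n3 <- 7; n4 <- 4; n5 <- 1          # sample sizes at each level
m1 <- 2.5; m2 <- 2.0; m3 <- 1.5; m4 <- 1.0; m5 <- 0.5  # tuning parameter (mesh size) at each level
d <- 1                                                 # input dimension
eps <- sqrt(.Machine$double.eps)
x <- seq(0,1,0.01)                                     # new input locations for prediction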
### fix seed to reproduce the result ###
set.seed(1)
### generate initial nested design ###
NestDesign <- NestedX(c(n1,n2,n3,n4,n5),d)
X1 <- NestDesign[[1]]
X2 <- NestDesign[[2]]
X3 <- NestDesign[[3]]
X4 <- NestDesign[[4]]
X5 <- NestDesign[[5]]
y1 <- fl(X1, t=m1)
y2 <- fl(X2, t=m2)
y3 <- fl(X3, t=m3)
y4 <- fl(X4, t=m4)
y5 <- fl(X5, t=m5)
### fit a DNAmf ###
fit.DNAmf <- DNAmf(X=list(X1, X2, X3, X4, X5), y=list(y1, y2, y3, y4, y5), kernel="sqex",
                   t=c(m1,m2,m3,m4,m5), multi.start=10, constant=TRUE)
### predict ###
pred.DNAmf <- predict(fit.DNAmf, x, targett=0)
predydiffu <- pred.DNAmf$mu
predsig2diffu <- pred.DNAmf$sig2
### RMSE ###
print(sqrt(mean((predydiffu-fl(x, t=0))^2))) # 0.1162579
### visualize the emulation performance ###
oldpar <- par(mfrow = c(2,3))
create_plot_base <- function(i, mesh_size, x, pred_mu, pred_sig2,
                             X_points = NULL, y_points = NULL, add_points = TRUE, yylim) {
  # 99% pointwise predictive interval
  lower <- pred_mu - qnorm(0.995) * sqrt(pred_sig2)
  upper <- pred_mu + qnorm(0.995) * sqrt(pred_sig2)
  plot(x, pred_mu, type = "n", ylim = c(-yylim, yylim), xlab = "", ylab = "",
       main = paste0("Mesh size = ", mesh_size), axes = FALSE)
  box()
  # shaded interval, predictive mean (solid), true function (dashed)
  polygon(c(x, rev(x)), c(upper, rev(lower)),
          col = adjustcolor("blue", alpha.f = 0.2), border = NA)
  lines(x, pred_mu, col = "blue", lwd = 2)
  lines(x, fl(x, mesh_size), lty = 2, col = "black", lwd = 2)
  # training points at this level, if any
  if (add_points && !is.null(X_points) && !is.null(y_points)) {
    points(X_points, y_points, col = "red", pch = 16, cex = 1.3)
  }
}
mesh_sizes <- c(m1, m2, m3, m4, m5, 0)
mu_list <- list(pred.DNAmf$mu_1, pred.DNAmf$mu_2, pred.DNAmf$mu_3,
                pred.DNAmf$mu_4, pred.DNAmf$mu_5, pred.DNAmf$mu)
sig2_list <- list(pred.DNAmf$sig2_1, pred.DNAmf$sig2_2, pred.DNAmf$sig2_3,
                  pred.DNAmf$sig2_4, pred.DNAmf$sig2_5, pred.DNAmf$sig2)
X_list <- list(X1, X2, X3, X4, X5, NULL)
y_list <- list(y1, y2, y3, y4, y5, NULL)
plots <- mapply(function(i, m, mu, sig2, X, y) {
  create_plot_base(i, m, x, mu, sig2, X, y, add_points = !is.null(X), yylim=1.5)
}, i = 1:6, m = mesh_sizes, mu = mu_list, sig2 = sig2_list,
   X = X_list, y = y_list, SIMPLIFY = FALSE)
par(oldpar)