predict.RNAmf {RNAmf} | R Documentation |
Prediction of the RNAmf emulator with two or three fidelity levels.
Description
The function computes the posterior mean and variance of RNA models with two or three fidelity levels
from a model fitted by RNAmf_two_level or RNAmf_three_level.
Usage
## S3 method for class 'RNAmf'
predict(object, x, ...)
Arguments
object |
an object of class RNAmf, as returned by RNAmf_two_level or RNAmf_three_level. |
x |
vector or matrix of new input locations to predict. |
... |
for compatibility with generic method |
Details
From the model fitted by RNAmf_two_level or RNAmf_three_level,
the posterior mean and variance are computed from closed-form expressions derived recursively across the fidelity levels.
The formulas depend on the kernel choice.
For details, see Heo and Sung (2024, <doi:10.1080/00401706.2024.2376173>).
Value
A list with the following components:
- mu: vector of predictive posterior means.
- sig2: vector of predictive posterior variances.
- time: a scalar giving the computation time.
See Also
RNAmf_two_level or RNAmf_three_level for fitting the model.
Examples
### two levels example ###
library(lhs)
### Perdikaris function ###
# Low-fidelity Perdikaris test function: sin(8 * pi * x), evaluated elementwise.
f1 <- function(x) {
  angle <- 8 * pi * x
  sin(angle)
}
# High-fidelity Perdikaris test function: (x - sqrt(2)) * sin(8 * pi * x)^2,
# evaluated elementwise.
f2 <- function(x) {
  wave <- sin(8 * pi * x)
  shift <- x - sqrt(2)
  shift * wave^2
}
### training data ###
n1 <- 13
n2 <- 8
### fix seed to reproduce the result ###
set.seed(1)
### generate initial nested design ###
X <- NestedX(c(n1, n2), 1)
X1 <- X[[1]]
X2 <- X[[2]]
### n1 and n2 might be changed from NestedX ###
### assign n1 and n2 again ###
n1 <- nrow(X1)
n2 <- nrow(X2)
y1 <- f1(X1)
y2 <- f2(X2)
### n=100 uniform test data ###
x <- seq(0, 1, length.out = 100)
### fit an RNAmf ###
fit.RNAmf <- RNAmf_two_level(X1, y1, X2, y2, kernel = "sqex")
### predict ###
# Call predict() once and reuse the result: extracting $mu and $sig2 from
# two separate predict() calls would repeat the whole posterior computation.
pred.RNAmf <- predict(fit.RNAmf, x)
predy <- pred.RNAmf$mu
predsig2 <- pred.RNAmf$sig2
### RMSE ###
print(sqrt(mean((predy - f2(x))^2)))
### visualize the emulation performance ###
# Half-width of the plotted band, computed once for both bounds.
# NOTE(review): the n2 / (n2 - 2) factor inflates sig2; presumably a
# small-sample correction -- confirm against the package documentation.
ci.half <- 1.96 * sqrt(predsig2 * length(y2) / (length(y2) - 2))
plot(x, predy,
  type = "l", lwd = 2, col = 3, # emulator and confidence interval
  ylim = c(-2, 1)
)
lines(x, predy + ci.half, col = 3, lty = 2)
lines(x, predy - ci.half, col = 3, lty = 2)
curve(f2(x), add = TRUE, col = 1, lwd = 2, lty = 2) # high fidelity function
points(X1, y1, pch = 1, col = "red") # low-fidelity design
points(X2, y2, pch = 4, col = "blue") # high-fidelity design
### three levels example ###
### Branin function ###
# Multi-fidelity Branin test function.
#
# xx: numeric vector; only xx[1] (x1) and xx[2] (x2) are used.
# l:  fidelity level, one of 1, 2 or 3. Level 3 is the standard Branin
#     function and is the target level in this example (the RMSE and plots
#     compare against l = 3); levels 1 and 2 are transformed variants.
# Returns a single numeric value. Any other level is an error; previously it
# silently returned NULL, which corrupts downstream apply() results.
branin <- function(xx, l) {
  x1 <- xx[1]
  x2 <- xx[2]
  if (l == 1) {
    10*sqrt((-1.275*(1.2*x1+0.4)^2/pi^2+5*(1.2*x1+0.4)/pi+(1.2*x2+0.4)-6)^2 +
    (10-5/(4*pi))*cos((1.2*x1+0.4))+ 10) + 2*(1.2*x1+1.9) - 3*(3*(1.2*x2+2.4)-1) - 1 - 3*x2 + 1
  } else if (l == 2) {
    10*sqrt((-1.275*(x1+2)^2/pi^2+5*(x1+2)/pi+(x2+2)-6)^2 +
    (10-5/(4*pi))*cos((x1+2))+ 10) + 2*(x1-0.5) - 3*(3*x2-1) - 1
  } else if (l == 3) {
    (-1.275*x1^2/pi^2+5*x1/pi+x2-6)^2 + (10-5/(4*pi))*cos(x1)+ 10
  } else {
    stop("branin(): fidelity level `l` must be 1, 2 or 3.", call. = FALSE)
  }
}
# Evaluate the multi-fidelity Branin function at a point given in the unit
# square.
#
# x: numeric vector whose first two entries lie in [0, 1]; they are mapped
#    affinely to x1 in [-5, 10] and x2 in [0, 15] before calling branin().
# l: fidelity level, passed through to branin().
output.branin <- function(x, l) {
  lower <- c(-5, 0)
  upper <- c(10, 15)
  # Vectorized rescaling replaces the element-wise `1:length()` loop.
  unit <- x[1:2]
  branin(lower + unit * (upper - lower), l)
}
### training data ###
n1 <- 20; n2 <- 15; n3 <- 10
### fix seed to reproduce the result ###
set.seed(1)
### generate initial nested design ###
X <- NestedX(c(n1, n2, n3), 2)
X1 <- X[[1]]
X2 <- X[[2]]
X3 <- X[[3]]
### n1, n2 and n3 might be changed from NestedX ###
### assign n1, n2 and n3 again ###
n1 <- nrow(X1)
n2 <- nrow(X2)
n3 <- nrow(X3)
y1 <- apply(X1, 1, output.branin, l = 1)
y2 <- apply(X2, 1, output.branin, l = 2)
y3 <- apply(X3, 1, output.branin, l = 3)
### n=10000 grid test data ###
grid1d <- seq(0, 1, length.out = 100)
x <- as.matrix(expand.grid(grid1d, grid1d))
### fit an RNAmf ###
fit.RNAmf <- RNAmf_three_level(X1, y1, X2, y2, X3, y3, kernel = "sqex")
### predict ###
pred.RNAmf <- predict(fit.RNAmf, x)
predy <- pred.RNAmf$mu
predsig2 <- pred.RNAmf$sig2
### true high-fidelity response on the grid ###
# Computed once and reused for both the RMSE and the plot, instead of
# re-evaluating the function on all 10000 grid points twice.
ytrue <- apply(x, 1, output.branin, l = 3)
### RMSE ###
print(sqrt(mean((predy - ytrue)^2)))
### visualize the emulation performance ###
# expand.grid() varies its first column fastest, so matrix(..., ncol = 100)
# puts x1 along rows and x2 along columns, the layout image() expects.
x1 <- x2 <- grid1d
oldpar <- par(mfrow = c(1, 2))
image(x1, x2, matrix(ytrue, ncol = 100),
  zlim = c(0, 310), main = "Branin function")
image(x1, x2, matrix(predy, ncol = 100),
  zlim = c(0, 310), main = "RNAmf prediction")
par(oldpar)
### predictive variance ###
print(predsig2)