predict_dt {dbnR}    R Documentation
Performs inference over a test dataset with a GBN
Description
This function performs inference over each row of a data.table (folded with fold_dt when the network is a DBN). If verbose = TRUE, it also plots the results and displays metrics of the accuracy of the predictions. Because each row is predicted on its own, the horizon of the prediction is at most 1.
This function is also called by the generic predict method for "dbn.fit" objects. For long-term forecasting, please refer to the forecast_ts function.
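A Gaussian BN is equivalent to a multivariate normal distribution, so predicting the objective nodes of one row amounts to computing their conditional mean given the observed evidence. The following R sketch illustrates that standard conditioning formula. The joint mean `mu` and covariance `sigma` are assumed inputs, the helper `gaussian_predict_row` is hypothetical (not part of dbnR), and dbnR's internal implementation may differ.

```r
# Hypothetical helper (not a dbnR function): conditional mean of the
# objective nodes given evidence in a multivariate Gaussian.
# Assumes `mu` is a named mean vector and `sigma` a covariance matrix
# with matching row and column names.
gaussian_predict_row <- function(mu, sigma, evidence, obj_nodes) {
  ev_nodes <- names(evidence)
  s_xe <- sigma[obj_nodes, ev_nodes, drop = FALSE]  # cross-covariance
  s_ee <- sigma[ev_nodes, ev_nodes, drop = FALSE]   # evidence covariance
  delta <- evidence[ev_nodes] - mu[ev_nodes]        # deviation from the mean
  # mu_x + S_xe S_ee^{-1} (e - mu_e), solved without an explicit inverse
  drop(mu[obj_nodes] + s_xe %*% solve(s_ee, delta))
}

# Usage: two correlated variables, predict a after observing b = 2
mu <- c(a = 0, b = 0)
sigma <- matrix(c(1, 0.5, 0.5, 1), 2, dimnames = list(c("a", "b"), c("a", "b")))
gaussian_predict_row(mu, sigma, evidence = c(b = 2), obj_nodes = "a")  # -> 1
```

Solving the linear system with `solve(s_ee, delta)` avoids forming the inverse of the evidence covariance, which is the usual numerically safer choice.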
Usage
predict_dt(fit, dt, obj_nodes, verbose = T, look_ahead = F)
Arguments
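fit: the fitted network, either a "dbn.fit" object or a Gaussian "bn.fit" object from bnlearn.
dt: the test dataset. When fit is a DBN, the dataset should be folded with fold_dt, as in the examples.
obj_nodes: character vector with the names of the nodes to be predicted. All of them are predicted at the same time.
verbose: if TRUE, displays the metrics and plots the real values against the predictions.
look_ahead: logical; whether the values of the variables in t_0 should be used as evidence when predicting, even if they are not present in obj_nodes. Setting it to TRUE introduces look-ahead bias, because those present-time values would not be available in a real forecasting scenario.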
Value
a data.table with the prediction results for each row
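When verbose = TRUE, the accuracy metrics are computed by comparing the real values with the predictions. This page does not state which metrics dbnR reports. The sketch below assumes the mean absolute error and the standard deviation of the absolute error for one predicted variable, and the helper prediction_metrics is hypothetical.

```r
# Hypothetical helper (not a dbnR function): accuracy metrics for one
# predicted variable. The choice of metrics is an assumption.
prediction_metrics <- function(real, pred) {
  abs_err <- abs(real - pred)
  c(mae = mean(abs_err), sd_e = sd(abs_err))
}

prediction_metrics(real = c(1, 2, 3), pred = c(1, 3, 5))  # mae = 1, sd_e = 1
```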
Examples
library(dbnR)
library(data.table)  # motor is a data.table, so motor[i] selects rows

size <- 3  # number of time slices in the DBN
data(motor)
dt_train <- motor[200:900]
dt_val <- motor[901:1000]

# With a DBN
obj <- c("pm_t_0")  # predict pm in the present time slice
net <- learn_dbn_struc(dt_train, size)
f_dt_train <- fold_dt(dt_train, size)
f_dt_val <- fold_dt(dt_val, size)
fit <- fit_dbn_params(net, f_dt_train, method = "mle-g")
res <- suppressWarnings(predict_dt(fit, f_dt_val, obj_nodes = obj, verbose = FALSE))

# With a Gaussian BN directly from bnlearn (no folding needed)
obj <- c("pm")
net <- bnlearn::mmhc(dt_train)
fit <- bnlearn::bn.fit(net, dt_train, method = "mle-g")
res <- suppressWarnings(predict_dt(fit, dt_val, obj_nodes = obj, verbose = FALSE))
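The DBN example folds both datasets with fold_dt before fitting and predicting, so that each row holds the present slice t_0 together with its past slices. The base R sketch below approximates that idea for illustration only. The helper fold_sketch is hypothetical, assumes t_1 is one step before t_0, and the real fold_dt may order columns differently; use fold_dt in practice.

```r
# Hypothetical re-implementation for illustration (use dbnR::fold_dt).
# Each output row contains the values at time t (suffix "_t_0") and at
# the `size - 1` previous steps (suffixes "_t_1", "_t_2", ...).
fold_sketch <- function(df, size) {
  n <- nrow(df) - size + 1  # rows that have a full history
  slices <- lapply(seq_len(size) - 1, function(lag) {
    part <- df[(size - lag):(size - lag + n - 1), , drop = FALSE]
    names(part) <- paste0(names(df), "_t_", lag)
    part
  })
  out <- do.call(cbind, slices)
  rownames(out) <- NULL
  out
}

fold_sketch(data.frame(pm = 1:5), size = 3)
# first row: pm_t_0 = 3, pm_t_1 = 2, pm_t_2 = 1
```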