model_predict {OptHoldoutSize}    R Documentation
Make predictions
Description
Make predictions according to a given model
Usage
model_predict(
data_test,
trained_model,
return_type,
threshold = NULL,
model_family = NULL,
...
)
Arguments
data_test
Data for which predictions are to be computed.
trained_model
Trained model used to make the predictions (for example, the output of model_train).
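return_type
Type of prediction to return; the example below uses "class" to obtain predicted class labels.
threshold
Classification threshold; the example below uses 0.5 together with return_type = "class".
model_family
Model family of trained_model, as passed to model_train; the example below uses "log_reg".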
...
Additional arguments passed on to other functions.
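The roles of return_type, threshold and model_family can be illustrated with a minimal base-R sketch. It assumes that, for model_family = "log_reg", the trained model is a binomial glm and that return_type = "class" means thresholding predicted probabilities. The function name predict_sketch and the "prob" return type are hypothetical and are not part of OptHoldoutSize; this is not the package's implementation.

```r
# Hypothetical sketch, not the OptHoldoutSize implementation:
# assumes trained_model is a binomial glm when model_family == "log_reg".
predict_sketch <- function(data_test, trained_model, return_type,
                           threshold = NULL, model_family = NULL) {
  if (!identical(model_family, "log_reg")) {
    stop("This sketch only handles model_family = \"log_reg\"")
  }
  # Predicted probabilities of the positive class (second factor level)
  probs <- unname(predict(trained_model, newdata = data_test, type = "response"))
  if (return_type == "prob") {
    return(probs)
  }
  if (return_type == "class") {
    if (is.null(threshold)) {
      stop("threshold is required for return_type = \"class\"")
    }
    # Label an observation 1 when its predicted probability exceeds threshold
    return(as.integer(probs > threshold))
  }
  stop("Unknown return_type: ", return_type)
}

# Usage on simulated data
set.seed(1)
df <- data.frame(x = rnorm(200))
df$Y <- factor(rbinom(200, 1, plogis(2 * df$x)))
fit <- glm(Y ~ x, family = binomial, data = df)
pred <- predict_sketch(df, fit, return_type = "class",
                       threshold = 0.5, model_family = "log_reg")
print(table(pred, df$Y))
```

Raising the threshold in such a sketch trades false positives for false negatives, which is why it is exposed as an argument rather than fixed at 0.5.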
Value
Vector of predictions
Examples
## Set seed for reproducibility
seed <- 1234
set.seed(seed)
# Initialisation of patient data
n_iter <- 500 # Number of point estimates to be calculated
nobs <- 5000 # Number of observations, i.e. patients
npreds <- 7 # Number of predictors
# Model family
family <- "log_reg"
# Baseline behaviour is an oracle Bayes-optimal predictor on only one variable
max_base_powers <- 1
base_vars <- 1
# Fraction of the data to use as the holdout set
frac_ho <- 0.1
# Set ground-truth coefficients and the coefficients of the baseline predictor
coefs_general <- rnorm(npreds,sd=1/sqrt(npreds))
coefs_base <- gen_base_coefs(coefs_general, max_base_powers = max_base_powers)
# Generate dataset
X <- gen_preds(nobs, npreds)
# Generate labels
newdata <- gen_resp(X, coefs = coefs_general)
Y <- newdata$classes
# Combined dataset
pat_data <- cbind(X, Y)
pat_data$Y <- factor(pat_data$Y)
# Split data into intervention and holdout sets
mask <- split_data(pat_data, frac_ho)
data_interv <- pat_data[!mask,]
data_hold <- pat_data[mask,]
# Train model
trained_model <- model_train(data_hold, model_family = family)
thresh <- 0.5
# Make class predictions on the intervention set
class_pred <- model_predict(data_interv, trained_model,
                            return_type = "class",
                            threshold = thresh, model_family = family)
# Simulate baseline (oracle) predictions on the holdout set
base_pred <- oracle_pred(data_hold, coefs_base[base_vars, ], num_vars = base_vars)
# Contingency table for model-based predictor (on intervention set)
print(table(class_pred, data_interv$Y))
# Contingency table for baseline predictor (on holdout set)
print(table(base_pred, data_hold$Y))
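The contingency tables printed above can be summarised as accuracies, the proportion of observations on the table diagonal. The sketch below assumes each table is square, with predicted classes (rows) and true classes (columns) listed in the same order; this fails if a predictor outputs only one class. The helper name table_accuracy is hypothetical and not part of OptHoldoutSize.

```r
# Hypothetical helper (not part of OptHoldoutSize): accuracy from a
# contingency table whose rows (predictions) and columns (true labels)
# list the classes in the same order.
table_accuracy <- function(tab) {
  if (nrow(tab) != ncol(tab)) {
    stop("Table must be square; a predictor may have output only one class")
  }
  # Correct predictions lie on the diagonal
  sum(diag(tab)) / sum(tab)
}

# Toy example with known counts: 6 of 8 predictions are correct
pred  <- c(0, 0, 1, 1, 1, 0, 1, 0)
truth <- factor(c(0, 1, 1, 1, 0, 0, 1, 0))
tab <- table(pred, truth)
print(tab)
print(table_accuracy(tab))
```

Applied to the example above, table_accuracy(table(class_pred, data_interv$Y)) and table_accuracy(table(base_pred, data_hold$Y)) should give the accuracies of the model-based and baseline predictors; note that they are evaluated on different subsets (intervention and holdout).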
[Package OptHoldoutSize version 0.1.0.0 Index]