predict.smm {mildsvm}    R Documentation
Predict method for smm object
Description
Predict method for smm object.
Usage
## S3 method for class 'smm'
predict(
  object,
  new_data,
  type = c("class", "raw"),
  layer = "instance",
  new_instances = "instance_name",
  new_bags = "bag_name",
  kernel = NULL,
  ...
)
Arguments
  object: an object of class smm.
  new_data: A data frame to predict from. This needs to have all of the features that the data was originally fitted with.
  type: If "class", return predicted class labels in a column named .pred_class; if "raw", return raw predicted scores in a column named .pred.
  layer: If …
  new_instances: A character or character vector. Can specify a singular character that provides the column name for the instance names in new_data (default "instance_name"), or a vector giving the instance name for each row of new_data, as in the Examples.
  new_bags: A character or character vector. Only relevant when fit with …
  kernel: An optional pre-computed kernel matrix at the instance level, or NULL (the default).
  ...: Arguments passed to or from other methods.
Details
When the object was fitted using the formula method, the parameters new_bags and new_instances are not necessary, as long as the names match those used in the original function call.
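A minimal sketch of the formula case described above. It assumes that the formula method of smm() takes the training data frame through `data` and the name of its instance column through an `instances` argument; that argument name is an assumption, not something this page confirms. Because the data passed to predict() keeps the column name "instance_name", which is also the default for new_instances, that argument is left out of the predict() call.

```r
library(mildsvm)

set.seed(8)
# Toy training data: 10 instances of 20 samples each, with the instance
# labels stored in a column named "instance_name"
n_instances <- 10
n_samples <- 20
y <- rep(c(1, -1), each = n_samples * n_instances / 2)
train <- data.frame(
  instance_name = as.character(rep(1:n_instances, each = n_samples)),
  y = y,
  x1 = rnorm(length(y), mean = 1 * (y == 1)),
  x2 = rnorm(length(y), mean = 2 * (y == 1)),
  x3 = rnorm(length(y), mean = 3 * (y == 1))
)

# Formula fit. ASSUMPTION: `instances` names the instance column in `data`
mdl_f <- smm(y ~ x1 + x2 + x3, data = train, instances = "instance_name",
             control = list(sigma = 1/3))

# new_data reuses the "instance_name" column, matching the default of
# new_instances, so that argument is omitted here
pred <- predict(mdl_f, new_data = train, type = "raw")
head(pred)
```

If new test data stored its instance labels under a different column name, new_instances would have to be supplied explicitly.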
Value
A tibble with nrow(new_data) rows. If type = 'class', the tibble will have a column named .pred_class. If type = 'raw', the tibble will have a column named .pred.
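The return shape described above can be checked directly. This sketch refits the toy model from the Examples section, using only the calls shown there, and asserts the documented contract: one output row per row of new_data, with .pred for type = "raw" and .pred_class for type = "class". The last check reflects what the Examples' distinct() call relies on, namely that rows belonging to the same instance share one prediction.

```r
library(mildsvm)

set.seed(8)
# Same toy data as in the Examples section
n_instances <- 10
n_samples <- 20
y <- rep(c(1, -1), each = n_samples * n_instances / 2)
instances <- as.character(rep(1:n_instances, each = n_samples))
x <- data.frame(x1 = rnorm(length(y), mean = 1 * (y == 1)),
                x2 = rnorm(length(y), mean = 2 * (y == 1)),
                x3 = rnorm(length(y), mean = 3 * (y == 1)))
mdl <- smm(x, y, instances, control = list(sigma = 1/3))

pred_raw <- predict(mdl, new_data = x, type = "raw", new_instances = instances)
pred_class <- predict(mdl, new_data = x, type = "class", new_instances = instances)

# One output row per input row, with the documented column names
stopifnot(
  nrow(pred_raw) == nrow(x),
  nrow(pred_class) == nrow(x),
  ".pred" %in% names(pred_raw),
  ".pred_class" %in% names(pred_class)
)

# Every row of a given instance carries the same raw score
one_score_per_instance <- tapply(pred_raw$.pred, instances,
                                 function(p) length(unique(p)) == 1)
stopifnot(all(one_score_per_instance))
```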
Author(s)
Sean Kent
See Also
smm() for fitting the smm object.
Examples
set.seed(8)
n_instances <- 10
n_samples <- 20
y <- rep(c(1, -1), each = n_samples * n_instances / 2)
instances <- as.character(rep(1:n_instances, each = n_samples))
x <- data.frame(x1 = rnorm(length(y), mean = 1*(y==1)),
                x2 = rnorm(length(y), mean = 2*(y==1)),
                x3 = rnorm(length(y), mean = 3*(y==1)))
mdl <- smm(x, y, instances, control = list(sigma = 1/3))
# instance level predictions (training data)
suppressWarnings(library(dplyr))
data.frame(instance_name = instances, y = y, x) %>%
  bind_cols(predict(mdl, type = "raw", new_data = x, new_instances = instances)) %>%
  bind_cols(predict(mdl, type = "class", new_data = x, new_instances = instances)) %>%
  distinct(instance_name, y, .pred, .pred_class)
# test data
new_inst <- rep(c("11", "12"), each = 30)
new_y <- rep(c(1, -1), each = 30)
new_x <- data.frame(x1 = rnorm(length(new_inst), mean = 1*(new_inst=="11")),
                    x2 = rnorm(length(new_inst), mean = 2*(new_inst=="11")),
                    x3 = rnorm(length(new_inst), mean = 3*(new_inst=="11")))
# instance level predictions (test data)
data.frame(instance_name = new_inst, y = new_y, new_x) %>%
  bind_cols(predict(mdl, type = "raw", new_data = new_x, new_instances = new_inst)) %>%
  bind_cols(predict(mdl, type = "class", new_data = new_x, new_instances = new_inst)) %>%
  distinct(instance_name, y, .pred, .pred_class)