mlapiEstimation {mlapi}    R Documentation
Base abstract class for all classification/regression models
Description
Base class for all estimators. Defines the minimal set of members and methods (with signatures) that have to be implemented in child classes.
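The contract described above can be illustrated with a minimal sketch in plain R6. Here an abstract base class only fixes the method signatures and fails unless a child class overrides them. `AbstractEstimator` and `MeanModel` are hypothetical names. This is not mlapi's source, which may enforce the contract differently.

```r
library(R6)

# Hypothetical abstract base class (not mlapi's source): it only fixes the
# method signatures, and calling them without an override is an error.
AbstractEstimator <- R6Class(
  classname = "AbstractEstimator",
  public = list(
    fit = function(x, y, ...) {
      stop("fit() must be implemented in a child class")
    },
    predict = function(x, ...) {
      stop("predict() must be implemented in a child class")
    }
  )
)

# Hypothetical child class: overrides both methods and predicts the mean
# of the training target for every row of new data.
MeanModel <- R6Class(
  classname = "MeanModel",
  inherit = AbstractEstimator,
  public = list(
    fit = function(x, y, ...) {
      private$mu <- mean(y)
      invisible(self)
    },
    predict = function(x, ...) {
      rep(private$mu, nrow(x))
    }
  ),
  private = list(mu = NULL)
)

x <- matrix(1:6, ncol = 2)
m <- MeanModel$new()
m$fit(x, c(1, 2, 3))
m$predict(x)  # returns c(2, 2, 2)
```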
Usage
mlapiEstimation
Format
R6Class object.
Methods
$fit(x, y, ...)
Trains the model on the features x and the target y.
$predict(x, ...)
Makes predictions on new data (after the model has been trained).
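These methods follow a train-then-predict lifecycle. The sketch below, in plain R6 and independent of mlapi, shows one way a child class might enforce it by refusing to predict before fit() has stored its parameters. `MedianModel` and its guard are hypothetical, not part of mlapi's API.

```r
library(R6)

# Hypothetical model (not part of mlapi) that predicts the training median
# and refuses to predict before fit() has been called.
MedianModel <- R6Class(
  classname = "MedianModel",
  public = list(
    fit = function(x, y, ...) {
      stopifnot(nrow(x) == length(y))
      private$center <- stats::median(y)
      invisible(self)
    },
    predict = function(x, ...) {
      # Guard: nothing to predict with until fit() has stored a parameter
      if (is.null(private$center)) stop("model is not trained yet; call fit() first")
      rep(private$center, nrow(x))
    }
  ),
  private = list(center = NULL)
)

x <- matrix(c(1, 5, 2, 8, 3, 9), ncol = 2)
model <- MedianModel$new()
before <- try(model$predict(x), silent = TRUE)  # error: not trained yet
model$fit(x, c(10, 30, 20))
model$predict(x)  # returns c(20, 20, 20)
```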
Arguments
- x
A matrix-like object; it should inherit from Matrix or matrix. The allowed classes should be defined in child classes.
- y
The target: usually a vector, but it can also be a matrix-like object. The allowed classes should be defined in child classes.
- ...
Additional parameters with default values.
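One way a child class might enforce these argument rules is sketched here in base R. `as_allowed_input()` and `check_target()` are hypothetical helpers, not mlapi's `check_convert_input()`. They assume that a data.frame should be converted to a dense matrix and that y must have one element (or row) per row of x.

```r
# Hypothetical helper (not mlapi's check_convert_input): accept inputs that
# inherit from one of the allowed classes, convert a data.frame to a dense
# base matrix, and reject anything else.
as_allowed_input <- function(x, allowed = c("matrix", "Matrix")) {
  if (inherits(x, allowed)) return(x)
  if (is.data.frame(x)) return(as.matrix(x))
  stop("x must inherit from one of: ", paste(allowed, collapse = ", "))
}

# Hypothetical target check: y may be a vector or a matrix,
# but it must have one element (or row) per row of x.
check_target <- function(x, y) {
  stopifnot(is.vector(y) || is.matrix(y))
  stopifnot(NROW(y) == nrow(x))
  invisible(TRUE)
}

df <- data.frame(a = c(1, 2), b = c(3, 4))
m <- as_allowed_input(df)  # converted to a 2 x 2 numeric matrix
check_target(m, c(0, 1))   # passes: 2 rows, 2 targets
```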
Examples
# A minimal child class: ordinary least squares without an intercept
SimpleLinearModel = R6::R6Class(
  classname = "mlapiSimpleLinearModel",
  inherit = mlapi::mlapiEstimation,
  public = list(
    initialize = function(tol = 1e-7) {
      private$tol = tol
      # accept only dense base matrices; no sparse format is supported
      super$set_internal_matrix_formats(dense = "matrix", sparse = NULL)
    },
    fit = function(x, y, ...) {
      # convert x (e.g. a data.frame) to the allowed internal format
      x = super$check_convert_input(x)
      stopifnot(is.vector(y))
      stopifnot(is.numeric(y))
      stopifnot(nrow(x) == length(y))
      # remember the number of features to validate input in predict()
      private$n_features = ncol(x)
      private$coefficients = .lm.fit(x, y, tol = private$tol)[["coefficients"]]
    },
    # signature matches the base class: $predict(x, ...)
    predict = function(x, ...) {
      stopifnot(ncol(x) == private$n_features)
      x %*% matrix(private$coefficients, ncol = 1)
    }
  ),
  private = list(
    tol = NULL,
    coefficients = NULL,
    n_features = NULL
  ))
set.seed(1)
model = SimpleLinearModel$new()
x = matrix(sample(100 * 10, replace = TRUE), ncol = 10)
y = sample(c(0, 1), 100, replace = TRUE)
# fit on a data.frame; check_convert_input() converts it to a matrix
model$fit(as.data.frame(x), y)
res1 = model$predict(x)
# check pipe-compatible S3 interface: predict(data, model)
res2 = predict(x, model)
identical(res1, res2)
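The example above calls `predict(x, model)` with the data first. The following is a hedged sketch of how such a data-first S3 method can work: `predict()` dispatches on the class of its first argument and forwards the call to the model's own `$predict()`. The class `toy_matrix`, the method `predict.toy_matrix()` and `toy_model` are hypothetical. mlapi's actual S3 methods are not reproduced here.

```r
# Hypothetical data-first S3 method (not mlapi's source): predict()
# dispatches on the class of its first argument (the data), then forwards
# to the model object's own $predict() method.
predict.toy_matrix <- function(object, model, ...) {
  model$predict(unclass(object), ...)
}

# A toy "model": a list holding a predict function. An R6 object would be
# called the same way, since both support model$predict(...).
toy_model <- list(predict = function(x, ...) rowSums(x))

x <- structure(matrix(1:4, ncol = 2), class = "toy_matrix")
res1 <- toy_model$predict(unclass(x))
res2 <- predict(x, toy_model)  # same call shape as predict(x, model) above
identical(res1, res2)
```

Putting the data first is what makes the call pipe-friendly: in R 4.1 or later, `x |> predict(toy_model)` is rewritten to `predict(x, toy_model)`.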
[Package mlapi version 0.1.1 Index]