cv.tropsvm {Rtropical}	R Documentation

Cross-Validation for Tropical Support Vector Machines

Description

Conducts k-fold cross-validation for tropsvm and returns an object of class "cv.tropsvm".
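k-fold cross-validation splits the rows of x into nfold roughly equal folds, fits a model on all folds but one, and scores accuracy on the held-out fold. Repeating this for every fold yields one accuracy value per fold, like the accuracy component listed under Value. The base-R sketch below illustrates that generic procedure only. It is not Rtropical's implementation: the helpers kfold_accuracy, centroid_fit and centroid_predict are hypothetical, and a nearest-centroid classifier stands in for tropsvm purely so that the code runs on its own.

```r
# Generic k-fold cross-validation sketch (hypothetical helper, not Rtropical's code).
# fit_fun(x, y) must return a model; predict_fun(model, newx) must return labels.
kfold_accuracy <- function(x, y, nfold = 10, fit_fun, predict_fun) {
  n <- nrow(x)
  # Give each row a random fold label 1..nfold, keeping folds near-equal in size
  folds <- sample(rep(seq_len(nfold), length.out = n))
  # For each fold k: train on the other folds, then score accuracy on fold k
  vapply(seq_len(nfold), function(k) {
    test <- folds == k
    model <- fit_fun(x[!test, , drop = FALSE], y[!test])
    pred <- predict_fun(model, x[test, , drop = FALSE])
    mean(pred == y[test])
  }, numeric(1))
}

# Hypothetical stand-in classifier (nearest centroid), NOT the tropical SVM
centroid_fit <- function(x, y) {
  y <- as.factor(y)
  # One row of column means per class, in the order of levels(y)
  centers <- do.call(rbind, lapply(levels(y), function(l)
    colMeans(x[y == l, , drop = FALSE])))
  list(levels = levels(y), centers = centers)
}

centroid_predict <- function(model, newx) {
  # Squared Euclidean distance from each row of newx to each class centroid
  d <- sapply(seq_len(nrow(model$centers)), function(i)
    rowSums(sweep(newx, 2, model$centers[i, ])^2))
  d <- matrix(d, nrow = nrow(newx))  # keep matrix shape even for a single row
  # Predict the class whose centroid is closest
  factor(model$levels[max.col(-d, ties.method = "first")], levels = model$levels)
}

# Usage on two well-separated simulated classes
set.seed(1)
x <- rbind(matrix(rnorm(40, mean = 5), ncol = 2),
           matrix(rnorm(40, mean = -5), ncol = 2))
y <- factor(rep(1:2, each = 20))
kfold_accuracy(x, y, nfold = 5, centroid_fit, centroid_predict)
```

Randomizing a balanced vector of fold labels, rather than cutting the rows in order, keeps the fold sizes within one of each other and avoids folds that contain a single class when the data are sorted by label, as they are in the Examples below.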

Usage

cv.tropsvm(x, y, parallel = FALSE, nfold = 10, nassignment = 10, ncores = 2)

Arguments

x

a data matrix, of dimension nobs x nvars; each row is an observation vector.

y

a response vector with one label for each row/component of x.

parallel

a logical value indicating whether parallel computing should be used (default: FALSE).

nfold

a numeric value giving the number of folds used for cross-validation (default: 10).

nassignment

a numeric value giving the size of the parameter grid of assignments searched during cross-validation (default: 10).

ncores

a numeric value giving the number of threads used for parallel computation on multi-core CPUs (default: 2).

Value

An object of S3 class "cv.tropsvm" containing the fitted model, with the following components:

apex

The negative apex of the fitted optimal tropical hyperplane.

assignment

The best assignment tuned by cross-validation.

index

The best classification method tuned by cross-validation.

levels

The names of the categories, consistent with the categories in y.

accuracy

The validation accuracy for each fold.

nfold

The number of folds used in cross-validation.
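Under the max-plus convention, the tropical hyperplane with normal vector omega is the set of points x at which max_i(x_i + omega_i) is attained at least twice, and its apex is -omega; the hyperplane divides space into sectors indexed by the coordinate that attains that maximum. The sketch below computes the sector index of each row of a matrix. It assumes, without confirmation from this page, that the apex component stores omega in this max-plus convention. The helper tropical_sector is hypothetical, and it does not reproduce how tropsvm maps sectors to classes through assignment and index.

```r
# Hypothetical helper: sector of the max-plus tropical hyperplane H_omega
# containing each row of x, i.e. argmax_i (x_i + omega_i).
# Assumes omega is the "negative apex", so the hyperplane's apex is -omega.
tropical_sector <- function(x, omega) {
  x <- as.matrix(x)
  stopifnot(ncol(x) == length(omega))
  # Add omega to every row; points on the hyperplane (ties) go to the first maximizer
  max.col(sweep(x, 2, omega, "+"), ties.method = "first")
}

# Usage
omega <- c(0, 1, -1)
pts <- rbind(c(0, 0, 0), c(3, 0, 0), c(0, 0, 5))
tropical_sector(pts, omega)  # returns 2 1 3
```

Adding the same constant to every coordinate of a point does not change which coordinate attains the maximum, so the sector index is well defined for points of the tropical projective torus, where such rows are identified.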

See Also

summary, predict, coef and the tropsvm function.

Examples


# simulate two Gaussian classes in 20 dimensions (training and test sets)
library(Rfast)
set.seed(101)
e <- 20
n <- 10
N <- 10
s <- 5
x <- rbind(
  rmvnorm(n, mu = c(5, -5, rep(0, e - 2)), sigma = diag(s, e)),
  rmvnorm(n, mu = c(-5, 5, rep(0, e - 2)), sigma = diag(s, e))
)
y <- as.factor(c(rep(1, n), rep(2, n)))
newx <- rbind(
  rmvnorm(N, mu = c(5, -5, rep(0, e - 2)), sigma = diag(s, e)),
  rmvnorm(N, mu = c(-5, 5, rep(0, e - 2)), sigma = diag(s, e))
)
newy <- as.factor(rep(c(1, 2), each = N))

# fit the tropical SVM, tuned by 10-fold cross-validation (the default nfold)
cv_tropsvm_fit <- cv.tropsvm(x, y, parallel = FALSE)

summary(cv_tropsvm_fit)
coef(cv_tropsvm_fit)

# predict labels for the new data
pred <- predict(cv_tropsvm_fit, newx)

# confusion matrix of predicted vs. true labels
table(pred, newy)

# compute testing accuracy
sum(pred == newy) / length(newy)

[Package Rtropical version 1.2.1 Index]