predict_diagram_ksvm {TDApplied}	R Documentation

Predict the outcome labels for a list of persistence diagrams using a pre-trained diagram ksvm model.

Description

Returns the predicted response vector of the model on the new diagrams.

Usage

predict_diagram_ksvm(
  new_diagrams,
  model,
  K = NULL,
  num_workers = parallelly::availableCores(omit = 1)
)

Arguments

new_diagrams

a list of persistence diagrams, each of which is either the output of a persistent homology calculation like ripsDiag/calculate_homology/PyH, or the output of diagram_to_df. Only one of 'new_diagrams' and 'K' needs to be supplied.

model

the output of a diagram_ksvm function call, of class 'diagram_ksvm'.

K

an optional cross-Gram matrix of the new diagrams and the diagrams in 'model', default NULL. If not NULL then 'new_diagrams' does not need to be supplied.

num_workers

the number of cores used for parallel computation, default is one less than the number of cores on the machine.

Details

This function is a wrapper around the kernlab predict function for ksvm models (predict.ksvm).

Value

a vector containing the output of predict.ksvm on the cross Gram matrix of the new diagrams and the support vector diagrams stored in the model.

Author(s)

Shael Brown - shaelebrown@gmail.com

See Also

diagram_ksvm for training an SVM model on a training set of persistence diagrams and labels.

Examples


if(require("TDAstats"))
{
  # create four diagrams
  D1 <- TDAstats::calculate_homology(TDAstats::circle2d[sample(1:100,20),],
                      dim = 1,threshold = 2)
  D2 <- TDAstats::calculate_homology(TDAstats::circle2d[sample(1:100,20),],
                      dim = 1,threshold = 2)
  D3 <- TDAstats::calculate_homology(TDAstats::sphere3d[sample(1:100,20),],
                      dim = 1,threshold = 2)
  D4 <- TDAstats::calculate_homology(TDAstats::sphere3d[sample(1:100,20),],
                      dim = 1,threshold = 2)
  g <- list(D1,D2,D3,D4)

  # create response vector
  y <- as.factor(c("circle","circle","sphere","sphere"))

  # fit model without cross validation
  model_svm <- diagram_ksvm(diagrams = g,cv = 1,dim = c(0),
                            y = y,sigma = c(1),t = c(1),
                            num_workers = 2)

  # create two new diagrams
  D5 <- TDAstats::calculate_homology(TDAstats::circle2d[sample(1:100,20),],
                      dim = 1,threshold = 2)
  D6 <- TDAstats::calculate_homology(TDAstats::sphere3d[sample(1:100,20),],
                      dim = 1,threshold = 2)
  g_new <- list(D5,D6)

  # predict with precomputed Gram matrix
  K <- gram_matrix(diagrams = g_new,other_diagrams = model_svm$diagrams,
                   dim = model_svm$best_model$dim,sigma = model_svm$best_model$sigma,
                   t = model_svm$best_model$t,num_workers = 2)
  predict_diagram_ksvm(K = K,model = model_svm,num_workers = 2)
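
  # alternatively (a sketch based on the Usage section above): supply the
  # new diagrams directly and leave K at its default NULL, so the caller
  # does not have to precompute the cross-Gram matrix.
  # 'pred_direct' is an illustrative variable name, not part of the package.
  pred_direct <- predict_diagram_ksvm(new_diagrams = g_new,model = model_svm,
                                      num_workers = 2)
  pred_direct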
}

[Package TDApplied version 3.0.3 Index]