selfTraining {SSLR} — R Documentation
General Interface for Self-training model
Description
Self-training is a simple and effective semi-supervised learning classification method. The self-training classifier is initially trained with a reduced set of labeled examples. Then it is iteratively retrained with its own most confident predictions over the unlabeled examples. Self-training follows a wrapper methodology using a base supervised classifier to establish the possible class of unlabeled instances.
Usage
selfTraining(learner, max.iter = 50, perc.full = 0.7, thr.conf = 0.5)
Arguments
learner
A model from the parsnip package for training a supervised base classifier on a set of instances. This model must provide probability predictions (or, optionally, a distance matrix) and its corresponding classes.
max.iter
Maximum number of iterations to execute the self-labeling process. Default is 50.
perc.full
A number between 0 and 1. If the percentage of newly labeled examples reaches this value, the self-training process is stopped. Default is 0.7.
thr.conf
A number between 0 and 1 that indicates the confidence threshold. At each iteration, only the newly labelled examples with a confidence greater than this value are added to the labeled set. Default is 0.5.
Details
For predicting the most accurate instances per iteration, selfTraining uses the predictions obtained with the specified learner. To train a model using the learner function, a set of instances (or a precomputed matrix between the instances if the x.inst parameter is FALSE) is required, together with the corresponding classes. Additional parameters are provided to the learner function via the learner.pars argument. The model obtained is a supervised classifier ready to predict new instances through the pred function. In the same way, additional parameters for the pred function are provided via the pred.pars argument. The pred function returns the probabilities per class for each new instance. The value of the thr.conf argument controls the confidence required for an instance to be selected to enlarge the labeled set for the next iteration.

The process stops when one of the following criteria is met: the algorithm reaches the number of iterations defined in the max.iter parameter, or the portion of the unlabeled set defined in the perc.full parameter has been moved to the labeled set. In some cases, the process stops early and no instances are added to the original labeled set; in that case, the user should assign a more flexible (lower) value to the thr.conf parameter.
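The iterative scheme described above can be sketched in a few lines of R. This is an illustrative sketch only, not the SSLR implementation; fit_model and predict_prob are hypothetical stand-ins for the learner's fit and probability-prediction calls.

```r
# Illustrative self-training loop (hypothetical helpers, assumed interfaces).
self_training_sketch <- function(x, y, max.iter = 50,
                                 perc.full = 0.7, thr.conf = 0.5) {
  labeled   <- which(!is.na(y))
  unlabeled <- which(is.na(y))
  # Stop once perc.full of the unlabeled set has been moved to the labeled set
  target    <- length(labeled) + floor(perc.full * length(unlabeled))

  for (iter in seq_len(max.iter)) {
    if (length(unlabeled) == 0 || length(labeled) >= target) break

    model <- fit_model(x[labeled, ], y[labeled])      # train base classifier
    probs <- predict_prob(model, x[unlabeled, ])      # class probabilities

    conf <- apply(probs, 1, max)                      # per-instance confidence
    keep <- which(conf > thr.conf)                    # apply thr.conf threshold
    if (length(keep) == 0) break                      # nothing confident: stop

    # Self-label the confident instances with their predicted class
    y[unlabeled[keep]] <- colnames(probs)[max.col(probs[keep, , drop = FALSE])]
    labeled   <- c(labeled, unlabeled[keep])
    unlabeled <- unlabeled[-keep]
  }
  fit_model(x[labeled, ], y[labeled])                 # final enlarged-set model
}
```

If the loop exits without adding any instances, lowering thr.conf makes the confidence filter more permissive, as noted above.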
Value
(When model fit) A list object of class "selfTraining" containing:
- model
The final base classifier trained using the enlarged labeled set.
- instances.index
The indexes of the training instances used to train the model. These indexes include the initial labeled instances and the newly labeled instances. These indexes are relative to the x argument.
- classes
The levels of the y factor.
- pred
The function provided in the pred argument.
- pred.pars
The list provided in the pred.pars argument.
References
David Yarowsky.
Unsupervised word sense disambiguation rivaling supervised methods.
In Proceedings of the 33rd annual meeting on Association for Computational Linguistics,
pages 189-196. Association for Computational Linguistics, 1995.
Examples
library(tidyverse)
library(tidymodels)
library(caret)
library(SSLR)
data(wine)
set.seed(1)
train.index <- createDataPartition(wine$Wine, p = .7, list = FALSE)
train <- wine[ train.index,]
test <- wine[-train.index,]
cls <- which(colnames(wine) == "Wine")
# Keep only a portion of the training instances labeled (p below);
# the rest are marked as unlabeled (NA)
labeled.index <- createDataPartition(train$Wine, p = .2, list = FALSE)
train[-labeled.index,cls] <- NA
# We need a model with probability predictions from parsnip
# https://tidymodels.github.io/parsnip/articles/articles/Models.html
# It must have mode = "classification"
#For example, with Random Forest
rf <- rand_forest(trees = 100, mode = "classification") %>%
set_engine("randomForest")
m <- selfTraining(learner = rf,
                  perc.full = 0.7,
                  thr.conf = 0.5,
                  max.iter = 10) %>%
  fit(Wine ~ ., data = train)
#Accuracy
predict(m, test) %>%
bind_cols(test) %>%
metrics(truth = "Wine", estimate = .pred_class)
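The base learner can be swapped for any other parsnip classification model that yields class probabilities. As a hedged variation (not part of the original example), the same workflow with a boosted-tree learner might look like this, assuming the xgboost engine is installed and the train/test objects from above:

```r
# Variation: boosted-tree base learner (assumes the xgboost engine is available)
bt <- boost_tree(trees = 100, mode = "classification") %>%
  set_engine("xgboost")

m2 <- selfTraining(learner = bt,
                   perc.full = 0.7,
                   thr.conf = 0.5,
                   max.iter = 10) %>%
  fit(Wine ~ ., data = train)

predict(m2, test) %>%
  bind_cols(test) %>%
  metrics(truth = "Wine", estimate = .pred_class)
```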