predict.rmda {robustDA}R Documentation

Prediction method for 'rmda' class objects.

Description

The prediction method for 'rmda' class objects which allows to predict the labels for test observations.

Usage

## S3 method for class 'rmda'
predict(object, X, ...)

Arguments

object

a supervised classifier genarated by the rmda function (a 'rmda' object).

X

the test data.

...

additional options for internal functions.

Value

A list with:

- cls: the predicted class labels,

- P: the posterior probabilities that observations belong to the classes.

Author(s)

Charles Bouveyron & Stéphane Girard

References

C. Bouveyron and S. Girard, Robust supervised classification with mixture models: Learning from data with uncertain labels, Pattern Recognition, vol. 42 (11), pp. 2649-2658, 2009.

Examples

set.seed(12345)

## Simulated data
N = 600
n = N/4
S1 = S2 = S3 = S4 = 2*diag(2)
m1 = 1.5*c(-4,0)
m4 = 1.5*c(0,-4)
m3 = 1.5*c(0,4)
m2 = 1.5*c(4,0)
Z.data = rbind(mvrnorm(n,m1,S1),mvrnorm(n,m2,S2),
  mvrnorm(n,m3,S3),mvrnorm(n,m4,S4))
Z.cls = c(rep(1,n),rep(1,n),rep(2,n),rep(2,n))

# Split in training and test sets
ind = sample(1:N,N)
X.data = Z.data[ind[1:(3*N/4)],]
X.cls = Z.cls[ind[1:(3*N/4)]]
Y.data = Z.data[ind[(3*N/4+1):N],]
Y.cls = Z.cls[ind[(3*N/4+1):N]]

## Adding noise label
cls = X.cls
nois = rbinom(length(cls),1,0.3)
lbl = cls
lbl[cls==1 & nois] = 2
lbl[cls==2 & nois] = 1

# Plot
par(mfrow=c(2,2))
plot(X.data,col=X.cls,pch=(18:19)[X.cls],
  main='Learning set with actual labels',xlab='',ylab='')
plot(X.data,col=lbl,pch=(18:19)[lbl],
  main='Learning set with noisy labels',xlab='',ylab='')


## Classification with LDA
c.lda = lda(X.data,lbl)
res.lda <- predict(c.lda,Y.data)$class

## Classification with MDA
c.mda = MclustDA(X.data,lbl,G=2)
res.mda = predict(c.mda,Y.data)$cl
plot(Y.data,col=res.mda,pch=(18:19)[res.mda],
     main='Classification of test set with MDA',xlab='',ylab='')

## Classification with RMDA
c.rmda <- rmda(X.data,lbl,K=4,model='VEV')
res.rmda <- predict(c.rmda,Y.data)
plot(Y.data,col=res.rmda$cls,pch=(18:19)[res.rmda$cls],
     main='Classification of test set with RMDA',xlab='',ylab='')

## Classification results
cat("* Correct classification rates on test data:\n")
cat("\tLDA:\t",sum(res.lda == Y.cls) / length(Y.cls),"\n")
cat("\tMDA:\t",sum(res.mda == Y.cls) / length(Y.cls),"\n")
cat("\tRMDA:\t",sum(res.rmda$cls == Y.cls) / length(Y.cls),"\n")

[Package robustDA version 1.2 Index]