predict.citodnn {cito} | R Documentation |
Predict from a fitted dnn model
## S3 method for class 'citodnn'
predict(object, newdata = NULL, type = c("link", "response"), ...)
object | a model created by |
newdata | new data for predictions |
type | type of prediction to return, either "link" or "response" |
... | additional arguments |
prediction matrix
if (torch::torch_is_installed()) {
  library(cito)
  set.seed(222)
  # Hold out 25 randomly chosen rows of iris as a validation set
  validation_set <- sample(seq_len(nrow(datasets::iris)), 25)
  # Build and train the network on the remaining rows
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ])
  # Use the model on the validation set
  predictions <- predict(nn.fit, datasets::iris[validation_set, ])
  # Scatterplot of observed against predicted values
  plot(datasets::iris[validation_set, ]$Sepal.Length, predictions)
  # Mean absolute error (MAE)
  mean(abs(predictions - datasets::iris[validation_set, ]$Sepal.Length))
}
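The example above leaves type at its default. As a minimal sketch, not taken from the package documentation, the two documented values of type could be requested explicitly and compared on the same held-out rows. The names holdout, fit, pred_link and pred_response are illustrative only, and no particular relationship between the two results is assumed here.

```r
# Sketch only: request both documented values of `type` for the same rows.
# Runs only when cito is available and the torch backend is installed.
if (requireNamespace("cito", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)
  set.seed(222)

  # Hypothetical names: `holdout` indexes the rows kept back from training
  holdout <- sample(seq_len(nrow(datasets::iris)), 25)
  fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-holdout, ])

  # Same model and same new data, differing only in `type`
  pred_link     <- predict(fit, datasets::iris[holdout, ], type = "link")
  pred_response <- predict(fit, datasets::iris[holdout, ], type = "response")

  # Inspect the first few rows of each result side by side
  print(head(cbind(link = as.vector(pred_link),
                   response = as.vector(pred_response))))
}
```

Printing both columns side by side makes it easy to see whether the choice of type matters for a given model.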