predict.citodnn {cito}    R Documentation

Predict from a fitted dnn model

Description

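Computes predictions from a neural network fitted with dnn, either on the link scale (raw network output) or on the scale of the response variable.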

Usage

## S3 method for class 'citodnn'
predict(object, newdata = NULL, type = c("link", "response"), ...)

Arguments

object

a fitted model of class citodnn, as returned by dnn

newdata

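new data for which predictions are computed, typically a data frame containing the predictor variables used in the model formula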

type

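type of prediction, either "link" or "response" (default "link"). "link" returns predictions on the scale of the network output (the linear predictor); "response" applies the inverse link function of the model's loss and returns predictions on the scale of the response variable.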

...

further arguments passed to or from other methods
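To make the difference between the two type values concrete, the base R sketch below mimics what moving from the link scale to the response scale means for a multi-class model. It assumes that the inverse link is the softmax function; the matrix eta and the helper softmax_rows() are hypothetical illustrations and are not part of cito.

```r
# Hypothetical network output on the link scale: three logits (one per class)
# for each of two observations
eta <- matrix(c(2.0, 0.5, -1.0,
                0.0, 0.0,  0.0),
              nrow = 2, byrow = TRUE)

# Hypothetical inverse link for a multi-class model: softmax applied row by row
softmax_rows <- function(x) {
  # Subtract each row's maximum before exponentiating for numerical stability;
  # the length-nrow vector is recycled so that row i is shifted by its own max
  e <- exp(x - apply(x, 1, max))
  # Divide each row by its sum so that every row sums to 1
  e / rowSums(e)
}

# Response scale: class probabilities for each observation
mu <- softmax_rows(eta)
print(mu)
```

The ordering of the values within a row is the same on both scales; only the response scale is directly interpretable as probabilities.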

Value

a matrix of predictions with one row per observation in newdata

Examples


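if(torch::torch_is_installed()){
  library(cito)

  set.seed(222)
  # Hold out 25 randomly chosen rows of iris as a validation set
  validation_set <- sample(1:nrow(datasets::iris), 25)

  # Build and train the network on the remaining rows
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ])

  # Use the model to predict on the validation set
  predictions <- predict(nn.fit, newdata = datasets::iris[validation_set, ])

  # Scatterplot of observed vs. predicted values
  plot(datasets::iris[validation_set, ]$Sepal.Length, predictions,
       xlab = "Observed Sepal.Length", ylab = "Predicted Sepal.Length")

  # Mean absolute error (MAE) on the validation set
  mean(abs(predictions - datasets::iris[validation_set, ]$Sepal.Length))
}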
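The example above only uses the default type = "link". The sketch below contrasts both values of type on a classification model. It assumes that this version of dnn accepts loss = "softmax" for a factor response and that type = "response" then returns class probabilities; the object names nn.class, pred_link and pred_resp are illustrative.

```r
# Only run if cito and a working torch backend are available
if (requireNamespace("cito", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)
  set.seed(222)

  # Multi-class model with Species as response
  # (assumption: loss = "softmax" is available in this cito version)
  nn.class <- dnn(Species ~ ., data = datasets::iris, loss = "softmax")

  # Raw network output for the first five rows (link scale)
  pred_link <- predict(nn.class, newdata = datasets::iris[1:5, ], type = "link")

  # Same rows on the response scale (assumed to be class probabilities)
  pred_resp <- predict(nn.class, newdata = datasets::iris[1:5, ], type = "response")

  print(pred_link)
  print(pred_resp)
}
```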


[Package cito version 1.0.0 Index]