predict.citodnn {cito}    R Documentation

Predict from a fitted dnn model

Description

Computes predictions from a fitted dnn model (class citodnn) or from a bootstrapped dnn model (class citodnnBootstrap).

Usage

## S3 method for class 'citodnn'
predict(
  object,
  newdata = NULL,
  type = c("link", "response", "class"),
  device = c("cpu", "cuda", "mps"),
  reduce = c("mean", "median", "none"),
  ...
)

## S3 method for class 'citodnnBootstrap'
predict(
  object,
  newdata = NULL,
  type = c("link", "response", "class"),
  device = c("cpu", "cuda", "mps"),
  reduce = c("mean", "median", "none"),
  ...
)

Arguments

object

a model created by dnn (of class citodnn or citodnnBootstrap)

newdata

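new data (e.g. a data.frame containing the predictors used in the model formula) for which predictions are made; if NULL, predictions are made for the training data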

type

type of predictions. The default, "link", returns predictions on the scale of the linear predictor; "response" returns predictions on the scale of the response (e.g. probabilities for a classification model); "class" returns predicted classes (classification tasks only).

device

device on which the predictions are computed ("cpu", "cuda", or "mps").

reduce

how predictions from a bootstrapped model (class citodnnBootstrap) are aggregated across the bootstrap samples: "mean" (the default), "median", or "none" (no aggregation).

...

additional arguments
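For a classification model, the three values of type are related to one another. The following base R sketch assumes a softmax output layer, so that "response" corresponds to the row-wise softmax of the "link" values and "class" to the label with the highest probability. The matrix link holds made-up values, and softmax_rows() is a hypothetical helper, not a cito function.

```r
# Hypothetical helper (not part of cito): row-wise softmax that maps
# link-scale scores to probabilities summing to 1 in each row.
softmax_rows <- function(link) {
  # subtract each row's maximum for numerical stability
  z <- exp(link - apply(link, 1, max))
  z / rowSums(z)
}

# Toy link-scale output: 2 observations x 3 classes (made-up values)
link <- matrix(c(2.0, 0.5, -1.0,
                 0.1, 0.3,  1.5),
               nrow = 2, byrow = TRUE,
               dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# "response" scale: class probabilities
response <- softmax_rows(link)

# "class": label of the most probable class in each row
pred_class <- colnames(link)[max.col(response, ties.method = "first")]
```

Because softmax is monotone within a row, the predicted class is the same whether it is taken from the link or the response scale.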

Value

prediction matrix
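The effect of the reduce argument on a bootstrapped model can be sketched in base R. This sketch assumes the individual bootstrap predictions are stored in a three-dimensional array indexed as [bootstrap sample, observation, output]; that layout, the toy values, and the helper reduce_boot() are hypothetical and need not match how cito stores them internally.

```r
# Toy bootstrap predictions (made-up values). ASSUMED layout:
# [bootstrap sample, observation, output]; cito's internal layout may differ.
set.seed(1)
n_boot <- 5; n_obs <- 4; n_out <- 1
boot_preds <- array(rnorm(n_boot * n_obs * n_out),
                    dim = c(n_boot, n_obs, n_out))

# Hypothetical helper (not a cito function) mimicking the reduce argument
reduce_boot <- function(preds, reduce = c("mean", "median", "none")) {
  reduce <- match.arg(reduce)
  # "none": keep every bootstrap prediction
  if (reduce == "none") return(preds)
  # otherwise aggregate over the first dimension (bootstrap samples)
  apply(preds, c(2, 3), FUN = if (reduce == "mean") mean else median)
}

dim(reduce_boot(boot_preds, "mean"))  # 4 1: one row per observation
```

With "mean" or "median" the result collapses to an observations-by-outputs matrix, whereas "none" keeps the extra bootstrap dimension.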

Examples


if(torch::torch_is_installed()){
library(cito)

set.seed(222)
validation_set <- sample(1:nrow(datasets::iris), 25)

# Build and train network
nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ])

# Use model on validation set
predictions <- predict(nn.fit, datasets::iris[validation_set, ])
# Scatterplot of observed vs. predicted values
plot(datasets::iris[validation_set, ]$Sepal.Length, predictions)
# Mean absolute error (MAE)
mean(abs(predictions - datasets::iris[validation_set, ]$Sepal.Length))
}


[Package cito version 1.1 Index]