predict.citodnn {cito}    R Documentation
Predict from a fitted dnn model
Description
Computes predictions from a fitted dnn model (class citodnn) or from a bootstrapped dnn model (class citodnnBootstrap).
Usage
## S3 method for class 'citodnn'
predict(
object,
newdata = NULL,
type = c("link", "response", "class"),
device = c("cpu", "cuda", "mps"),
reduce = c("mean", "median", "none"),
...
)
## S3 method for class 'citodnnBootstrap'
predict(
object,
newdata = NULL,
type = c("link", "response", "class"),
device = c("cpu", "cuda", "mps"),
reduce = c("mean", "median", "none"),
...
)
Arguments
object: A fitted model of class citodnn or citodnnBootstrap.
newdata: New data for which predictions are computed.
type: Type of predictions. The default, "link", returns predictions on the scale of the linear predictor; "response" returns predictions on the scale of the response; and "class" returns predicted classes (classification tasks only).
device: Device on which the predictions are computed ("cpu", "cuda", or "mps").
reduce: For bootstrapped models (citodnnBootstrap), how the predictions of the individual bootstrap models are aggregated: "mean" (the default), "median", or "none" (no aggregation).
...: Additional arguments.
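To make the three values of type concrete, the following base-R sketch shows how link-scale outputs, response-scale probabilities, and class labels relate for a classification model. It assumes a softmax output layer and uses a made-up matrix of link-scale values (logits); it does not call cito and is not cito's internal code, since the actual inverse link depends on the loss the model was fitted with.

```r
# Hypothetical link-scale output (logits) for 2 observations and 3 classes,
# as a softmax-based classifier might produce; the values are made up.
link <- matrix(c(2.0, 0.5, -1.0,
                 0.1, 0.3,  1.5),
               nrow = 2, byrow = TRUE,
               dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# "response": apply the inverse link (assumed to be softmax) row-wise
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))  # subtract each row's max for numerical stability
  e / rowSums(e)                  # normalise each row so it sums to 1
}
response <- softmax_rows(link)

# "class": pick the column with the largest value in each row
class_pred <- factor(colnames(link)[apply(link, 1, which.max)],
                     levels = colnames(link))
class_pred
```

Because softmax is monotone within each row, taking the largest value on the link scale or on the response scale yields the same class.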
Value
prediction matrix
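For a bootstrapped model, the reduce argument determines how the per-model predictions are combined into the returned result. The base-R sketch below illustrates the three options on simulated predictions. The n x k x B array layout and the helper reduce_predictions() are hypothetical and chosen only for illustration; they are not taken from cito's implementation.

```r
# Hypothetical predictions from B = 3 bootstrap models for n = 4 observations
# and k = 1 output; in practice these would come from the fitted models.
set.seed(42)
n <- 4; k <- 1; B <- 3
pred_list <- lapply(seq_len(B), function(b) matrix(rnorm(n * k), nrow = n, ncol = k))

# Stack into an n x k x B array (one slice per bootstrap model)
pred_array <- array(unlist(pred_list), dim = c(n, k, B))

# Hypothetical helper mirroring the options of `reduce`:
# aggregate across bootstrap models (the third dimension) element-wise
reduce_predictions <- function(pred_array, reduce = c("mean", "median", "none")) {
  reduce <- match.arg(reduce)
  switch(reduce,
         mean   = apply(pred_array, c(1, 2), mean),
         median = apply(pred_array, c(1, 2), stats::median),
         none   = pred_array)
}

reduce_predictions(pred_array, "mean")        # 4 x 1 matrix
reduce_predictions(pred_array, "median")      # 4 x 1 matrix
dim(reduce_predictions(pred_array, "none"))   # 4 1 3
```

Using the median instead of the mean makes the aggregate less sensitive to individual bootstrap models whose predictions deviate strongly from the rest.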
Examples
if(torch::torch_is_installed()){
library(cito)
set.seed(222)
iris_data <- datasets::iris
# Hold out 25 randomly chosen rows as a validation set
validation_set <- sample(seq_len(nrow(iris_data)), 25)
# Build and train the network on the remaining rows
nn.fit <- dnn(Sepal.Length ~ ., data = iris_data[-validation_set, ])
# Use the model on the validation set
predictions <- predict(nn.fit, iris_data[validation_set, ])
# Scatterplot of observed vs. predicted values
plot(iris_data[validation_set, ]$Sepal.Length, predictions)
# Mean absolute error (MAE)
mean(abs(predictions - iris_data[validation_set, ]$Sepal.Length))
}
[Package cito version 1.1 Index]