bwdNN {dnn}    R Documentation

Back propagation for dnn Models

Description

bwdNN is an R function that performs back propagation for a deep neural network (DNN) model built with dNNmodel.

Usage

#
# To apply back propagation to a feed-forward model, 
#
# use 
#
   bwdNN(dy, cache, model)
#
# to calculate the derivatives dL/dW

Arguments

dy

the derivative of the cost function with respect to the output layer of the fwdNN function.

cache

the cached output of fwdNN.

model

a model returned by the dNNmodel function.

Details

Here 'dy' plays an important role in the back propagation of bwdNN, since the loss function of the probability model takes the output layer of the dnn (denoted yhat) as one of its parameters. 'dy' is the partial derivative of the loss function (the negative log-likelihood) with respect to yhat, that is, dy = dL/d(yhat). For example, if the dnn predicts the probability (yhat = p) for a mixture of two populations with densities f1 and f2, then the likelihood is f = p*f1 + (1-p)*f2 and the loss is L = -log(p*f1 + (1-p)*f2). Hence dy = dL/dp = -(f1 - f2)/f.
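The mixture example can be checked directly in base R. The sketch below computes dy = -(f1 - f2)/f analytically and compares it with a central finite difference of L. The choice of f1 and f2 as the normal densities N(0, 1) and N(2, 1), and the random values of y and p, are assumptions made only for illustration; in practice f1 and f2 come from your own probability model and p from the DNN output.

```r
# Sketch: dy = dL/dp for the two-component mixture loss.
# Assumption (illustration only): f1 = N(0, 1) density, f2 = N(2, 1) density.
set.seed(1)
y  <- rnorm(5)                        # observations
p  <- runif(5, 0.1, 0.9)              # DNN output yhat = p, one per row
f1 <- dnorm(y, mean = 0, sd = 1)
f2 <- dnorm(y, mean = 2, sd = 1)

# Loss L = -log(p * f1 + (1 - p) * f2), evaluated elementwise in p
loss <- function(p) -log(p * f1 + (1 - p) * f2)

# Analytic derivative: dy = dL/dp = -(f1 - f2) / f, stored as a column matrix
f  <- p * f1 + (1 - p) * f2
dy <- matrix(-(f1 - f2) / f, ncol = 1)

# Central finite-difference approximation of the same derivative
h      <- 1e-6
dy_num <- (loss(p + h) - loss(p - h)) / (2 * h)
print(cbind(analytic = dy[, 1], numeric = dy_num))
```

A column matrix like dy is the shape passed to bwdNN in the example below.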

'cache' holds the cached input of each layer, generated by the fwdNN function.
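How dy and the cached layer inputs combine through the chain rule can be illustrated with a minimal base-R sketch of a two-layer network with sigmoid activations and no bias terms. The helpers sigmoid, fwd_sketch and bwd_sketch are hypothetical: they do not reproduce the internal cache format of fwdNN or the code of bwdNN, and only show the general mechanism of turning dy into a list of derivatives dL/dW.

```r
sigmoid <- function(z) 1 / (1 + exp(-z))

# Hypothetical forward pass: caches the input A and pre-activation Z of each
# layer (NOT the dnn package's internal cache format)
fwd_sketch <- function(x, W) {
  cache <- list()
  a <- x
  for (l in seq_along(W)) {
    z <- a %*% W[[l]]
    cache[[l]] <- list(A = a, Z = z)
    a <- sigmoid(z)
  }
  list(yhat = a, cache = cache)
}

# Hypothetical backward pass: starts from dy = dL/d(yhat), returns list of dL/dW
bwd_sketch <- function(dy, cache, W) {
  dW <- vector("list", length(W))
  da <- dy
  for (l in rev(seq_along(W))) {
    s  <- sigmoid(cache[[l]]$Z)
    dz <- da * s * (1 - s)              # chain rule through the sigmoid
    dW[[l]] <- t(cache[[l]]$A) %*% dz   # dL/dW for layer l
    da <- dz %*% t(W[[l]])              # gradient passed back to layer l - 1
  }
  dW
}

set.seed(2)
x   <- matrix(runif(15), nrow = 5, ncol = 3)
W   <- list(matrix(rnorm(3 * 4), 3, 4), matrix(rnorm(4 * 1), 4, 1))
out <- fwd_sketch(x, W)
dy  <- matrix(runif(5, -0.1, 0.1), nrow = 5)   # dummy dy, as in the example
dW  <- bwd_sketch(dy, out$cache, W)
str(dW)
```

Each element of the returned list has the same dimensions as the corresponding weight matrix, mirroring the list of derivatives described under Value.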

The function bwdCheck calculates the numerical derivatives dL/dW, which can be used to check whether the back propagation is correct; see the example below.
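The idea behind a numerical check of this kind can be sketched in base R with central differences, (L(W + h) - L(W - h)) / (2h), applied to one weight at a time. The helper num_grad, the toy one-layer logistic model and its Bernoulli negative log-likelihood are assumptions for illustration only; the loss and step size actually used by bwdCheck are not documented on this page.

```r
sigmoid <- function(z) 1 / (1 + exp(-z))

# Hypothetical helper: central-difference gradient of loss_fn with respect to
# every entry of every matrix in the list W
num_grad <- function(loss_fn, W, h = 1e-6) {
  lapply(seq_along(W), function(l) {
    g <- W[[l]] * 0                     # same dimensions as W[[l]]
    for (i in seq_along(W[[l]])) {
      Wp <- W; Wm <- W
      Wp[[l]][i] <- Wp[[l]][i] + h
      Wm[[l]][i] <- Wm[[l]][i] - h
      g[i] <- (loss_fn(Wp) - loss_fn(Wm)) / (2 * h)
    }
    g
  })
}

set.seed(3)
x <- matrix(runif(15), nrow = 5, ncol = 3)
y <- c(0, 1, 1, 0, 1)
W <- list(matrix(rnorm(3), nrow = 3, ncol = 1))

# Toy loss: negative Bernoulli log-likelihood of a one-layer logistic model
nll <- function(W) {
  p <- sigmoid(x %*% W[[1]])
  -sum(y * log(p) + (1 - y) * log(1 - p))
}

# Analytic back propagation for this toy model: dL/dW = t(x) %*% (p - y)
p <- sigmoid(x %*% W[[1]])
dW_analytic <- t(x) %*% (p - y)
dW_numeric  <- num_grad(nll, W)[[1]]
print(cbind(analytic = dW_analytic[, 1], numeric = dW_numeric[, 1]))
```

Agreement to several decimal places suggests the analytic derivatives are correct; a large discrepancy usually points to a sign or indexing error in the backward pass.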

Value

A list containing the derivatives of the weight parameters W is returned.

Author(s)

Bingshu E. Chen (bingshu.chen@queensu.ca)

See Also

dNNmodel, fwdNN, plot.dNNmodel, print.dNNmodel, summary.dNNmodel

Examples

### define a dnn model, calculate the feed forward network
   library(dnn)
   model = dNNmodel(units = c(8, 6, 1), 
           activation = c("elu", "sigmoid", "sigmoid"), input_shape = 3)
   print(model)
   x = matrix(runif(15), nrow = 5, ncol = 3)
   cache = fwdNN(x, model)
   # dy = dL/dp, where L is the cost function, such as the negative
   # log-likelihood, and p is the output layer of the DNN
   dy = matrix(runif(5, -0.1, 0.1), nrow = 5)  # a dummy dy for bwdNN input
   y  = predict(model, x) + dy
   
   # back propagation: analytic derivatives dL/dW
   dW = bwdNN(dy, cache, model)
   # numerical derivatives dL/dW, used to check dW
   dw = bwdCheck(x, y, model)
   print(dW[[1]])
   print(dw[[1]])

[Package dnn version 0.0.6 Index]