f_fit_gradient_01 {mactivate} | R Documentation |
Fit Multivariate Regression Model with mactivate Using Gradient Descent
Description
Use simple gradient descent to locate model parameters, i.e., primary effects, multiplicative effects, and activation parameters, W.
Usage
f_fit_gradient_01(
X,
y,
m_tot,
U = NULL,
m_start = 1,
mact_control = f_control_mactivate(),
verbosity = 2)
Arguments
X |
Numerical matrix, |
y |
Numerical vector of length N, the number of rows of X. |
m_tot |
Scalar non-negative integer. Total number of columns of activation layer, W. |
U |
Numerical matrix, |
m_start |
Currently not used. |
mact_control |
Named list of class |
verbosity |
Scalar integer. |
Details
Please make sure to read the Details section of the f_dmss_dW help page before using this function.
Value
An unnamed list of class mactivate_fit_gradient_01 of length m_tot + 1. Each node is a named list containing fitted parameter estimates. The first top-level node of this object contains parameter estimates when fitting ‘primary effects’ only (W has no columns), the second contains parameter estimates for fitting with 1 column of W, and so on.
See Also
Essentially equivalent to, but likely slower than: f_fit_hybrid_01. See f_fit_gradient_logistic_01 for logistic data (binomial response).
Examples
# Record the start time; it is used on the last line of the example to report total runtime.
xxnow <- Sys.time()
# Attach the mactivate package (the package this help page belongs to).
library(mactivate)
# ---- Simulate data: d Gaussian inputs plus one multiplicative effect ----
set.seed(777)
d <- 4
N <- 2000
X <- matrix(rnorm(N * d, 0, 1), N, d)
colnames(X) <- paste0("x", seq_len(d))

# Primary (additive) effects: coefficients alternate between -1/2 and 1/2.
b <- rep_len(c(-1/2, 1/2), d)

# Multiplicative term built from the first two (unstandardized) inputs.
xxA <- (X[, 1] + 1/3) * (X[, 2] - 1/3)
# xxA <- (X[, 1] + 0/3) * (X[, 2] - 0/3)

# Noise-free response: primary effects plus twice the multiplicative term.
# Note: X %*% b is an N x 1 matrix, so ystar (and y below) are N x 1 matrices.
ystar <- X %*% b + 2 * xxA

m_tot <- 4

# Model formulas, written as formula literals instead of eval(parse(text = ...)).
# 'true' model: every column of the data frame, including the interaction xxA.
xtrue_formula <- y ~ .
# Misspecified model: the same columns with the interaction term removed.
xnoint_formula <- y ~ . - xxA

# Observed response: add Gaussian noise with standard deviation 3.
yerrs <- rnorm(N, 0, 3)
y <- ystar + yerrs
## y <- (y - mean(y)) / sd(y)

# Standardize X column by column: subtract the column mean, divide by the column sd.
# sweep() makes the per-column operation explicit (no reliance on transpose + recycling).
Xall <- sweep(sweep(X, 2, apply(X, 2, mean), "-"), 2, apply(X, 2, sd), "/")
yall <- y
Nall <- N

# Fold index: observations alternate between fold 1 and fold 2.
xxfoldNumber <- rep_len(1:2, N)
ufolds <- sort(unique(xxfoldNumber))
ufolds
# ---- Baseline linear models fitted and evaluated on the full data set ----
# One data frame holding the response, the standardized inputs and the
# multiplicative term xxA.
dfx <- data.frame(y = yall, Xall, xxA)
tail(dfx)

# Misspecified LM: the interaction term xxA is excluded.
xlm <- lm(xnoint_formula, data = dfx)
summary(xlm)
yhat <- predict(xlm, newdata = dfx)
# In-sample RMSE of the no-interaction model.
sqrt(mean((yall - yhat)^2))

# Correctly specified LM: xxA is included as a regressor.
xlm <- lm(xtrue_formula, data = dfx)
summary(xlm)
yhat <- predict(xlm, newdata = dfx)
# In-sample RMSE of the 'true' model.
sqrt(mean((yall - yhat)^2))
# ---- Fit using gradient m-activation: control settings ----
# Value later passed as `m_tot` to f_fit_gradient_01().
m_tot <- 4

# Control list later passed as `mact_control` to f_fit_gradient_01().
# NOTE(review): the meaning of the individual settings is defined by
# f_control_mactivate(); see its help page rather than relying on the names.
xcmact_gradient <- f_control_mactivate(
  param_sensitivity = 10^11,
  bool_free_w       = TRUE,
  w0_seed           = 0.05,
  w_col_search      = "alternate",
  bool_headStart    = TRUE,
  ss_stop           = 10^(-12),
  escape_rate       = 1.02,  # alternative value noted by the original author: 1.0002
  Wadj              = 1/1,
  force_tries       = 0,
  lambda            = 0
)
# ---- Fit on the training fold ----
# U is set to the same standardized matrix that is used for X.
Uall <- Xall
head(Uall)

# Hold out the first fold for testing; train on all remaining observations.
xthis_fold <- ufolds[1]
xndx_test <- which(xxfoldNumber %in% xthis_fold)
# seq_len() rather than 1:Nall, so that Nall == 0 would give an empty index
# instead of c(1, 0).
xndx_train <- setdiff(seq_len(Nall), xndx_test)

# drop = FALSE keeps X and U as matrices even if only one row were selected.
X_train <- Xall[xndx_train, , drop = FALSE]
y_train <- yall[xndx_train]
U_train <- Uall[xndx_train, , drop = FALSE]

# m_start is listed under Arguments as "Currently not used"; it is passed
# here only to show the full call. verbosity = 0 is the value used for this example.
xxls_out <- f_fit_gradient_01(
  X = X_train,
  y = y_train,
  m_tot = m_tot,
  U = U_train,
  m_start = 1,
  mact_control = xcmact_gradient,
  verbosity = 0
)
# ---- Check test error on the held-out fold ----
U_test <- Uall[xndx_test, , drop = FALSE]
X_test <- Xall[xndx_test, , drop = FALSE]
y_test <- yall[xndx_test]

# One column of held-out predictions per value of mcols = 0, 1, ..., m_tot.
yhatTT <- matrix(NA, length(xndx_test), m_tot + 1)
for (iimm in 0:m_tot) {
  yhatTT[, iimm + 1] <- predict(object = xxls_out, X0 = X_test, U0 = U_test, mcols = iimm)
}

# Test RMSE for each column of yhatTT. The result vector is preallocated
# instead of being grown from NULL inside the loop.
errs_by_m <- numeric(ncol(yhatTT))
for (iimm in seq_len(ncol(yhatTT))) {
  errs_by_m[iimm] <- sqrt(mean((y_test - yhatTT[, iimm])^2))
  # iimm is the column index (mcols + 1). Each report ends with a newline so
  # that successive values do not run together on a single line.
  cat(iimm, "::", errs_by_m[iimm], "\n")
}

##### plot test RMSE vs m
plot(0:(length(errs_by_m) - 1), errs_by_m, type = "l", xlab = "m", ylab = "RMSE Cost")
# ---- Hold-out comparison against ordinary linear models ----
xtrue_formula_use <- xtrue_formula

# No-interaction model: fit on the training rows, score on the test rows.
xlm <- lm(xnoint_formula, data = dfx[xndx_train, ])
yhat <- predict(xlm, newdata = dfx[xndx_test, ])
rmse_noint <- sqrt(mean((y_test - yhat)^2))
cat("\n\n", "No interaction model RMSE:", rmse_noint, "\n")

# 'true' model (includes xxA): same train/test split.
xlm <- lm(xtrue_formula_use, data = dfx[xndx_train, ])
yhat <- predict(xlm, newdata = dfx[xndx_test, ])
rmse_true <- sqrt(mean((y_test - yhat)^2))
cat("\n\n", "'true' model RMSE:", rmse_true, "\n")

# Elapsed time since xxnow was recorded at the top of the example, in seconds.
elapsed_secs <- difftime(Sys.time(), xxnow, units = "secs")
cat("Runtime:", elapsed_secs, "\n")