s2netR {s2net} | R Documentation |
Trains a generalized extended linear joint trained model using semi-supervised data.
Description
This function is a wrapper for the class s2net. It creates the C++ object and fits the model using the input data.
Usage
s2netR(data, params, loss = "default", frame = "ExtJT", proj = "auto",
fista = NULL, S3 = TRUE)
Arguments
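data | An object of class s2Data containing the labeled (xL, yL) and unlabeled (xU) training data. |
params | An object of class s2Params with the model hyper-parameters (lambda1, lambda2, gamma1, gamma2, gamma3). |
loss | Loss function, given as a string. Defaults to "default"; the example below uses "linear". |
frame | The semi-supervised frame, given as a string. Defaults to "ExtJT" (extended joint trained). |
proj | Should the unlabeled data be shifted to remove the model's effect? Given as a string; defaults to "auto". |
fista | FISTA setup parameters. An object of class s2Fista, as created by s2Fista(). |
S3 | Logical: should the method return an S3 object (the default, TRUE) or a C++ object? |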
Value
Returns an object of S3 class s2netR or, when S3 = FALSE, a C++ object of class s2net.
Author(s)
Juan C. Laria
References
Ryan, K. J., & Culp, M. V. (2015). On semi-supervised linear regression in covariate shift problems. The Journal of Machine Learning Research, 16(1), 3183-3217.
Examples
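# Load the auto_mpg data shipped with s2net
data("auto_mpg")
# Training set: labeled (xL, yL) and unlabeled (xU) observations
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)
# Fit an ExtJT model with linear loss and fixed hyper-parameters
model = s2netR(train,
               s2Params(lambda1 = 0.1,
                        lambda2 = 0,
                        gamma1 = 0.1,
                        gamma2 = 100,
                        gamma3 = 0.1),
               loss = "linear",
               frame = "ExtJT",
               proj = "auto",
               fista = s2Fista(5000, 1e-7, 1, 0.8))
# Validation set: the unlabeled points with their true responses,
# reusing the preprocessing of the training data
valid = s2Data(auto_mpg$P1$xU, auto_mpg$P1$yU, preprocess = train)
ypred = predict(model, valid$xL)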
## Not run:
# Predicted vs. observed responses; the dashed line is y = x
if(require(ggplot2)){
  ggplot() +
    aes(x = ypred, y = valid$yL) + geom_point() +
    geom_abline(intercept = 0, slope = 1, linetype = 2)
}
## End(Not run)
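The example checks the fit only visually. A numeric summary of the same comparison can be obtained with a small helper; the sketch below assumes a root-mean-squared error is an appropriate measure for the linear loss, and `rmse` is a hypothetical helper, not a function exported by s2net.

```r
# Root-mean-squared error between predicted and observed responses.
# `rmse` is a hypothetical helper, not part of the s2net package.
rmse <- function(pred, obs) {
  # Flatten so that vectors and one-column matrices are both accepted
  pred <- as.numeric(pred)
  obs <- as.numeric(obs)
  stopifnot(length(pred) == length(obs))
  sqrt(mean((pred - obs)^2))
}

# Usage with the objects created in the example above:
# rmse(ypred, valid$yL)
```

Flattening with as.numeric() means the helper does not depend on whether predict() returns a vector or a matrix.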