distCompare {WpProj}    R Documentation
Compares Optimal Transport Distances Between WpProj and Original Models
Description
Computes the Wasserstein distance between each WpProj model and the original model (the target). The comparison can be made on the parameters, on the predictions, or on both.
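For reference, this is the standard definition of the p-Wasserstein distance between distributions μ and ν. Here p corresponds to the `power` argument, Π(μ, ν) is the set of couplings with marginals μ and ν, and a Euclidean ground norm is assumed. The second line is the special case of two empirical distributions, each with n equally weighted atoms, where the optimal coupling reduces to a permutation σ. This only illustrates the quantity being approximated; it does not describe how WpProj computes it.

```latex
W_p(\mu, \nu) = \left( \inf_{\pi \in \Pi(\mu, \nu)} \int \lVert x - y \rVert^{p} \, \mathrm{d}\pi(x, y) \right)^{1/p}

W_p\!\left(\tfrac{1}{n}\textstyle\sum_{i} \delta_{x_i}, \tfrac{1}{n}\textstyle\sum_{j} \delta_{y_j}\right)
  = \left( \min_{\sigma \in S_n} \frac{1}{n} \sum_{i=1}^{n} \lVert x_i - y_{\sigma(i)} \rVert^{p} \right)^{1/p}
```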
Usage
distCompare(
models,
target = list(parameters = NULL, predictions = NULL),
power = 2,
method = "exact",
quantity = c("parameters", "predictions"),
parallel = NULL,
transform = function(x) {
return(x)
},
...
)
Arguments
models: A list of models fitted with WpProj methods.
target: The target to compare the models against. A list with elements "parameters", used when comparing parameters, and "predictions", used when comparing predictions.
power: The power parameter of the Wasserstein distance.
method: Which approximation to the Wasserstein distance to use. Should be one of the outputs of
quantity: Which quantities to compare: "parameters", "predictions", or both.
parallel: Parallel backend to use for the
transform: Transformation function applied to the predictions.
...: Other options passed to the
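The sketch below shows the effect of the `power` argument in the simplest setting: two equally weighted samples of the same size on the real line. In one dimension the optimal coupling pairs the sorted values, so no transport solver is needed. The helper `wasserstein_1d` is hypothetical and is not part of WpProj. WpProj compares multivariate draws, which is where approximations such as the `method` options come in. Treat this as an illustration under those assumptions, not as the package's implementation.

```r
# Hypothetical helper (not part of WpProj): exact p-Wasserstein distance
# between two equally weighted samples of the same size on the real line.
wasserstein_1d <- function(x, y, power = 2) {
  stopifnot(is.numeric(x), is.numeric(y), length(x) == length(y), power >= 1)
  # In one dimension the optimal coupling pairs the order statistics,
  # so W_p is the p-th root of the mean p-th power of the sorted gaps.
  mean(abs(sort(x) - sort(y))^power)^(1 / power)
}

# Usage: two small samples of a single coefficient
a <- c(0.1, 0.4, 0.2, 0.9)
b <- c(0.3, 0.5, 0.1, 0.7)
w1 <- wasserstein_1d(a, b, power = 1)
w2 <- wasserstein_1d(a, b, power = 2)
```

A larger `power` weights the largest gaps more heavily. For equal-size samples, W_2 is never smaller than W_1.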
Details
In the parameters and predictions data frames, dist is the Wasserstein distance, nactive is the number of active variables in the model, groups is the name identifying the model, and method is the method used to calculate the distance (e.g., exact or sinkhorn). If the list passed to models is named, those names are used as the group names. Otherwise, the group names are created from the call of the corresponding WpProj method.
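The naming rule above can be sketched as follows. `choose_group_names` is a hypothetical helper, not WpProj code. The `fallback` argument stands in for the names WpProj derives from each model's call. How the package treats a partially named list is an assumption here; this sketch falls back whenever any name is missing.

```r
# Hypothetical sketch of the group-naming rule described in Details;
# WpProj's own handling (e.g. of partially named lists) may differ.
# `fallback` stands in for names derived from each model's call.
choose_group_names <- function(models, fallback) {
  nms <- names(models)
  # Use the list names only if every model has a non-empty name
  if (is.null(nms) || any(is.na(nms) | nms == "")) {
    return(fallback)
  }
  nms
}

# Character strings stand in for fitted WpProj models
models_named   <- list(L1 = "fit1", BP = "fit2")
models_unnamed <- list("fit1", "fit2")
g_named   <- choose_group_names(models_named,   fallback = c("model 1", "model 2"))
g_unnamed <- choose_group_names(models_unnamed, fallback = c("model 1", "model 2"))
```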
Value
An object of class distcompare with slots parameters, predictions, and p. The slots parameters and predictions are data frames; see Details for their columns. The slot p is the power parameter of the Wasserstein distance used in the distance calculation.
Examples
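library(WpProj)

if (rlang::is_installed("stats")) {
  set.seed(1)  # for reproducibility
  n <- 32      # number of observations
  p <- 10      # number of covariates
  s <- 21      # number of posterior draws
  x <- matrix(stats::rnorm(p * n), nrow = n, ncol = p)
  beta <- (1:10) / 10
  y <- x %*% beta + stats::rnorm(n)

  # Simulated posterior draws of the coefficients (one column per draw)
  post_beta <- matrix(beta, nrow = p, ncol = s) + stats::rnorm(p * s, 0, 0.1)
  post_mu <- x %*% post_beta

  # Projection using the default method with a lasso penalty
  fit1 <- WpProj(X = x, eta = post_mu, power = 2.0,
                 options = list(penalty = "lasso"))

  # Binary-program projection; solver options request the MCP penalty
  fit2 <- WpProj(X = x, eta = post_mu, theta = post_beta, power = 2.0,
                 method = "binary program", solver = "lasso",
                 options = list(solver.options = list(penalty = "mcp")))

  # The list is named, so "L1" and "BP" become the group names
  dc <- distCompare(models = list("L1" = fit1, "BP" = fit2),
                    target = list(parameters = post_beta,
                                  predictions = post_mu))
  plot(dc)
}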