DMTL {DMTL}    R Documentation

Distribution Mapping based Transfer Learning

Description

This function performs distribution mapping based transfer learning (DMTL) regression for given target (primary) and source (secondary) datasets. A predictive model is trained on the data available in the source domain. The target features, whose response values are unknown, are transferred to the source domain via distribution matching, and the corresponding response values are predicted in the source domain using that model. The predicted response values are then transferred back to the original target space by applying distribution matching again. Hence, this function needs an unmatched pair of target datasets (features and response values) and a matched pair of source datasets.
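The distribution matching step can be illustrated with a minimal base R sketch. It assumes matching is done by passing a value through the empirical CDF of one sample and the quantile function of the other. The helper match_distribution is hypothetical and not part of the DMTL package, whose internal implementation may differ.

```r
set.seed(1)

# Hypothetical helper (not part of the DMTL package): map values of `x`
# from the distribution of `from` onto the distribution of `to`.
match_distribution <- function(x, from, to) {
  u <- ecdf(from)(x)                     # position of x within `from`, in [0, 1]
  quantile(to, probs = u, names = FALSE) # value at the same position within `to`
}

target_x <- rnorm(500, mean = 0.3, sd = 0.6)  # a feature in the target domain
source_x <- rnorm(500, mean = 0.0, sd = 0.5)  # the same feature in the source domain

# Transfer the target feature into the source domain.
mapped_x <- match_distribution(target_x, from = target_x, to = source_x)

# After mapping, the summary statistics follow the source sample.
round(c(mean = mean(mapped_x), sd = sd(mapped_x)), 2)
round(c(mean = mean(source_x), sd = sd(source_x)), 2)

# A response predicted in the source space would be mapped back the same way,
# e.g. match_distribution(pred, from = source_y, to = target_y).
```

The mapping is monotone, so the ordering of the target samples is preserved while their marginal distribution is replaced by that of the source.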

Usage

DMTL(
  target_set,
  source_set,
  use_density = FALSE,
  pred_model = "RF",
  model_optimize = FALSE,
  sample_size = 1000,
  random_seed = NULL,
  all_pred = FALSE,
  get_verbose = FALSE,
  allow_parallel = FALSE
)

Arguments

target_set

Named list containing the target datasets, with components X (predictors) and y (response). Predictions are made for the response values corresponding to X, while y is used only to estimate the response distribution parameters.

source_set

Named list containing the source datasets, with components X (predictors) and y (response). These two sets must be matched, and they are used in both distribution estimation and predictive modeling.
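As an illustration, the two input lists could be assembled as follows. This is a sketch under two assumptions that this page does not state explicitly: that "unmatched" allows the target X and y to have different numbers of samples, and that both domains share the same feature columns. The checks at the end are hypothetical and not part of the package.

```r
set.seed(3)
feat <- paste0("f", 1:3)

# Target domain: 200 feature rows to predict, but only 50 known responses.
# (Assumption: "unmatched" means X and y need not correspond row by row.)
target_set <- list(
  X = matrix(rnorm(600, 0.3, 0.6), ncol = 3, dimnames = list(NULL, feat)),
  y = rnorm(50, 0.2, 0.4)
)

# Source domain: X and y must be matched, one response per feature row.
source_X <- matrix(rnorm(900, 0, 0.5), ncol = 3, dimnames = list(NULL, feat))
source_set <- list(
  X = source_X,
  y = -0.2 * source_X[, 1] + 0.3 * source_X[, 2] - source_X[, 3]
)

# Hypothetical sanity checks before calling DMTL() (not part of the package).
stopifnot(
  nrow(source_set$X) == length(source_set$y),                # source is matched
  identical(colnames(target_set$X), colnames(source_set$X))  # same features
)
```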

use_density

Flag for using kernel density as distribution estimate instead of histogram counts. Defaults to FALSE.
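The difference between the two kinds of distribution estimate can be sketched in base R. This is only an illustration of a histogram-based versus a kernel-density-based CDF; the bin count, grid size, and integration scheme below are assumptions, not the settings used inside DMTL.

```r
set.seed(2)
x <- rnorm(1000, mean = 0.3, sd = 0.6)

# Histogram-based CDF: cumulative bin counts, normalised to end at 1.
# The values correspond to the right bin edges, h$breaks[-1].
h <- hist(x, breaks = 50, plot = FALSE)
cdf_hist <- cumsum(h$counts) / sum(h$counts)

# Kernel-density-based CDF: Riemann sum of the smoothed density on the
# equally spaced grid returned by density().
d <- density(x, n = 512)
cdf_kde <- cumsum(d$y) * diff(d$x[1:2])

# Compare both estimates at the same point, here the sample median.
q <- median(x)
est <- c(
  hist = approx(h$breaks[-1], cdf_hist, xout = q)$y,
  kde  = approx(d$x, cdf_kde, xout = q)$y
)
round(est, 3)
```

Both estimates should be close to 0.5 at the median; the kernel density version is smoother, while the histogram version changes in steps at the bin edges.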

pred_model

String indicating the underlying predictive model. The currently available options are:

  • RF for random forest regression. If model_optimize = FALSE, builds a model with n_tree = 200 and m_try = 0.4.

  • SVM for support vector regression. If model_optimize = FALSE, builds a model with kernel = "poly", C = 2, and degree = 3.

  • EN for elastic net regression. If model_optimize = FALSE, builds a model with alpha = 0.8 and lambda generated from a 5-fold cross validation.

model_optimize

Flag for model parameter tuning. If TRUE, performs a grid search to optimize parameters and train with the resulting model. If FALSE, uses a set of predefined parameters. Defaults to FALSE.

sample_size

Sample size for estimating distributions of target and source datasets. Defaults to 1e3.

random_seed

Seed for random number generator (for reproducible outcomes). Defaults to NULL.

all_pred

Flag for returning the prediction values in the source space. If TRUE, the function returns a named list with two components, target and source (predictions in the target space and source space, respectively). Defaults to FALSE.

get_verbose

Flag for displaying the progress when optimizing the predictive model, i.e., when model_optimize = TRUE. Defaults to FALSE.

allow_parallel

Flag for allowing parallel processing when performing grid search, i.e., when model_optimize = TRUE. Defaults to FALSE.

Value

If all_pred = FALSE, a vector containing the final prediction values.

If all_pred = TRUE, a named list with two components, target and source, i.e., predictions in the original target space and in the source space, respectively.
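The list returned with all_pred = TRUE might be inspected as in the sketch below. It assumes the DMTL package is installed (the block is skipped otherwise) and reuses the synthetic data from the Examples section of this page.

```r
# Assumes the DMTL package is installed; the block is skipped otherwise.
if (requireNamespace("DMTL", quietly = TRUE)) {
  set.seed(8644)

  # Same synthetic data as in the Examples section of this page.
  x1 <- matrix(rnorm(3000, 0.3, 0.6), ncol = 3)
  dimnames(x1) <- list(paste0("sample", 1:1000), paste0("f", 1:3))
  y1 <- 0.3*x1[, 1] + 0.1*x1[, 2] - x1[, 3] + rnorm(1000, 0, 0.05)
  x2 <- matrix(rnorm(3000, 0, 0.5), ncol = 3)
  dimnames(x2) <- list(paste0("sample", 1:1000), paste0("f", 1:3))
  y2 <- -0.2*x2[, 1] + 0.3*x2[, 2] - x2[, 3] + rnorm(1000, 0, 0.05)

  pred <- DMTL::DMTL(
    target_set = list(X = x1, y = y1),
    source_set = list(X = x2, y = y2),
    pred_model = "RF",
    all_pred   = TRUE
  )

  str(pred$target)  # predictions in the original target space
  str(pred$source)  # predictions in the source space
}
```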

Note

Examples

set.seed(8644)

## Generate two datasets with different underlying distributions...
x1 <- matrix(rnorm(3000, 0.3, 0.6), ncol = 3)
dimnames(x1) <- list(paste0("sample", 1:1000), paste0("f", 1:3))
y1 <- 0.3*x1[, 1] + 0.1*x1[, 2] - x1[, 3] + rnorm(1000, 0, 0.05)
x2 <- matrix(rnorm(3000, 0, 0.5), ncol = 3)
dimnames(x2) <- list(paste0("sample", 1:1000), paste0("f", 1:3))
y2 <- -0.2*x2[, 1] + 0.3*x2[, 2] - x2[, 3] + rnorm(1000, 0, 0.05)

## Model datasets using DMTL & compare with a baseline model...
library(DMTL)

target <- list(X = x1, y = y1)
source <- list(X = x2, y = y2)
y1_pred <- DMTL(target_set = target, source_set = source, pred_model = "RF")
y1_pred_bl <- RF_predict(x_train = x2, y_train = y2, x_test = x1)

print(performance(y1, y1_pred, measures = c("MSE", "PCC")))
print(performance(y1, y1_pred_bl, measures = c("MSE", "PCC")))


[Package DMTL version 0.1.2 Index]