transport_cost.numeric {gridOT}    R Documentation

Optimal Transport Cost

Description

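Calculate the optimal transport cost between two weighted sets of points, between two grids, or for a given transport plan and cost matrix.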

Usage

## S3 method for class 'numeric'
transport_cost(x, y, wx, wy, p = 1, sorted = FALSE, threshold = 1e-15, ...)

transport_cost(x, ...)

## S3 method for class 'otgridtransport'
transport_cost(x, threshold = 1e-15, ...)

## S3 method for class 'otgrid'
transport_cost(x, ...)

## S3 method for class 'data.frame'
transport_cost(x, costm, ...)

Arguments

x

a vector of points; a data frame with columns from, to and mass specifying the optimal transport plan; or an object of class "otgridtransport" or "otgrid". If x is of class "otgrid", ... must contain the arguments of pivot_measure.

y

second vector of points.

wx

weight vector for the first vector of points, x.

wy

weight vector for the second vector of points, y.

p

the power p ≥ 1 of the cost function.

sorted

logical value indicating whether or not x and y are sorted.

threshold

small value that indicates when a value is considered to be zero.

...

further arguments (for pivot_measure if x is an object of class "otgrid").

costm

cost matrix of the transport.
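
To make the data frame form of x and the costm argument concrete, here is a minimal base-R sketch of how a transport plan and a cost matrix combine into a single cost: the mass-weighted sum of the costs along the plan. The toy values are illustrative only, and the sketch assumes that from and to are 1-based row and column indices into costm. It is not taken from the gridOT source.

```r
# Toy transport plan in the from/to/mass layout (illustrative values).
# Assumption: `from` and `to` are 1-based indices into the cost matrix.
plan <- data.frame(
  from = c(1, 1, 2),
  to   = c(1, 2, 2),
  mass = c(0.25, 0.25, 0.5)
)

# Toy cost matrix: costm[i, j] is the cost of moving one unit from i to j.
costm <- matrix(c(0, 1,
                  1, 0), nrow = 2, byrow = TRUE)

# Mass-weighted sum of the costs along the plan.
plan_cost <- sum(plan$mass * costm[cbind(plan$from, plan$to)])
plan_cost  # 0.25
```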

Details

For two-dimensional grids, the pivot measure is used to calculate the optimal transport cost.

For one-dimensional optimal transport, the cost function is given by c(x, y) = | x - y |^p. In this case, the north-west-corner algorithm is used.
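
The one-dimensional case can be sketched in plain R. The function below is a hypothetical helper, not part of gridOT. It sorts both point sets, then applies the north-west corner rule: it repeatedly moves as much mass as possible between the current pair of points and advances past whichever point is exhausted. The name nw_corner_cost and the way threshold is applied are assumptions made for illustration; gridOT's own implementation may differ in detail.

```r
# Hypothetical helper (not part of gridOT): one-dimensional transport cost
# with c(x, y) = |x - y|^p, computed by the north-west corner rule.
nw_corner_cost <- function(x, y, wx, wy, p = 1, threshold = 1e-15) {
  # Sort both point sets and carry the weights along.
  ox <- order(x); x <- x[ox]; wx <- wx[ox]
  oy <- order(y); y <- y[oy]; wy <- wy[oy]

  i <- 1
  j <- 1
  cost <- 0
  while (i <= length(x) && j <= length(y)) {
    m <- min(wx[i], wy[j])                 # mass moved from x[i] to y[j]
    cost <- cost + m * abs(x[i] - y[j])^p
    wx[i] <- wx[i] - m
    wy[j] <- wy[j] - m
    # Advance past any point whose remaining mass is numerically zero
    # (assumed role of `threshold` in this sketch).
    if (wx[i] <= threshold) i <- i + 1
    if (wy[j] <= threshold) j <- j + 1
  }
  cost
}

# Two half-unit masses at 0 and 1 moved to a unit mass at 2:
# 0.5 * |0 - 2| + 0.5 * |1 - 2|
nw_corner_cost(c(0, 1), 2, c(0.5, 0.5), 1)  # 1.5
```

On sorted points the north-west corner rule yields the monotone coupling, which is optimal for this cost when p ≥ 1, so no general linear program is needed in one dimension.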

Value

the optimal transport cost or, in the two-dimensional case, an object of class "otgridtransport" whose element cost contains it.

See Also

pivot_measure

Examples

## one-dimensional example
set.seed(1)
a <- 1:5
wa <- rep(1/5, 5)
b <- 1:6
wb <- runif(6)
wb <- wb / sum(wb)
transport_cost(a, b, wa, wb, p = 1)

## two-dimensional example
x <- otgrid(cbind(0:1, 1:0))
y <- otgrid(cbind(1:0, 0:1))

# first calculate the pivot measure manually
pm <- pivot_measure(x, y)
pm <- transport_cost(pm)
print(pm$cost)

# or just
pm2 <- transport_cost(x, y)
print(pm2$cost)

# or from transport plan and cost matrix
costm <- transport_costmat(pm)
tp <- transport_df(pm)
print(transport_cost(tp$df, costm))

[Package gridOT version 1.0.1 Index]