
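Calculates the Bures-Wasserstein (BW) barycenter of a list of symmetric positive-definite (SPD) matrices. The BW barycenter is the (weighted) Fréchet mean under the BW metric: the SPD matrix that minimizes the weighted sum of squared BW distances to the input matrices. This function computes it with an iterative, damped fixed-point algorithm.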
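
For reference, the standard definitions behind this description follow. The page does not state which fixed-point map the function iterates, so the update below is one common choice from the optimal-transport literature and not necessarily the one implemented here.

```latex
% Squared Bures-Wasserstein distance between SPD matrices A and B
d_{BW}^2(A, B) = \operatorname{tr}(A) + \operatorname{tr}(B)
  - 2\,\operatorname{tr}\!\left[\left(A^{1/2} B A^{1/2}\right)^{1/2}\right]

% Weighted barycenter of S_1, ..., S_N (the matrices in S_list)
\bar{S} = \operatorname*{arg\,min}_{M \succ 0} \sum_{i=1}^{N} w_i \, d_{BW}^2(M, S_i),
\qquad w_i \ge 0, \quad \sum_{i=1}^{N} w_i = 1

% The minimizer satisfies the fixed-point equation
\bar{S} = \sum_{i=1}^{N} w_i \left(\bar{S}^{1/2} S_i \bar{S}^{1/2}\right)^{1/2}

% One standard fixed-point iteration for solving it
M_{k+1} = M_k^{-1/2} \left( \sum_{i=1}^{N} w_i \left(M_k^{1/2} S_i M_k^{1/2}\right)^{1/2} \right)^{2} M_k^{-1/2}
```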

Usage

bures_wasserstein_barycenter(
  S_list,
  weights = NULL,
  initial_mean = NULL,
  max_iter = 50,
  tol = 1e-07,
  regularize_epsilon = 1e-06,
  verbose = FALSE,
  damping = 0.5
)

Arguments

S_list

A list of SPD matrices (p x p).

weights

Optional. A numeric vector of non-negative weights, one for each matrix in `S_list`. If NULL (default), uniform weights (1/N) are used. Weights should sum to 1; if they do not, they are normalized.
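
The weight rules above can be sketched as a small helper. `normalize_bw_weights` is hypothetical and not part of the package; it assumes that a length mismatch, missing or negative values, or an all-zero vector count as "invalid inputs", for which the documented behavior is to return NULL.

```r
# Hypothetical helper mirroring the documented weight rules (not package code).
normalize_bw_weights <- function(weights, n) {
  if (is.null(weights)) return(rep(1 / n, n))  # NULL -> uniform 1/N
  # Assumed validity checks: one finite, non-negative weight per matrix,
  # and not all zero.
  if (!is.numeric(weights) || length(weights) != n || anyNA(weights) ||
      any(weights < 0) || sum(weights) <= 0) {
    return(NULL)
  }
  weights / sum(weights)  # rescale so the weights sum to 1
}

normalize_bw_weights(NULL, 4)        # 0.25 0.25 0.25 0.25
normalize_bw_weights(c(2, 1, 1), 3)  # 0.50 0.25 0.25
```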

initial_mean

Optional. A p x p SPD matrix to use as the initial estimate for the barycenter. If NULL, the (weighted) arithmetic mean of `S_list` is used after regularization.

max_iter

Integer, maximum number of iterations for the fixed-point algorithm. Default: 50.

tol

Numeric, tolerance for convergence. The algorithm stops when the Frobenius norm of the difference between successive estimates of the barycenter is below this tolerance. Default: 1e-7.

regularize_epsilon

Numeric, small positive value for regularizing input matrices and intermediate results to ensure positive definiteness. Default: 1e-6.
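
The page does not say how `regularize_epsilon` is applied. A common approach, sketched below as an assumption, is to symmetrize the matrix and floor its eigenvalues at epsilon; another is to add `epsilon * diag(p)`. `regularize_spd` is a hypothetical name, not the package's internal function.

```r
# Hypothetical regularizer (assumption, not package code): symmetrize, then
# floor the eigenvalues at epsilon so the result is strictly positive definite.
regularize_spd <- function(A, epsilon = 1e-6) {
  A <- (A + t(A)) / 2
  e <- eigen(A, symmetric = TRUE)
  e$vectors %*% diag(pmax(e$values, epsilon), nrow(A)) %*% t(e$vectors)
}

# A rank-deficient matrix (eigenvalues 2 and 0) is only positive semi-definite.
A <- matrix(c(1, 1, 1, 1), 2, 2)
min(eigen(A, symmetric = TRUE)$values)                  # 0 up to rounding
min(eigen(regularize_spd(A), symmetric = TRUE)$values)  # about 1e-6
```

Flooring eigenvalues leaves well-conditioned matrices essentially unchanged, whereas adding `epsilon * diag(p)` shifts every eigenvalue; for an epsilon this small the difference is usually negligible.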

verbose

Logical, if TRUE, prints iteration information. Default: FALSE.

damping

Numeric, damping factor applied to each fixed-point update step. Default: 0.5.
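
Putting the arguments together, the following is a minimal sketch of a damped fixed-point iteration, not the package's implementation. `bw_barycenter_sketch` and `sym_fun` are hypothetical names. The starting point (regularized weighted arithmetic mean) and the Frobenius-norm stopping rule follow the argument descriptions above; the eigenvalue-floor regularization, the damping rule `(1 - damping) * M + damping * T(M)`, and returning the last iterate when `max_iter` is reached are assumptions.

```r
# Apply a scalar function to the eigenvalues of a symmetric matrix.
sym_fun <- function(A, f) {
  A <- (A + t(A)) / 2
  e <- eigen(A, symmetric = TRUE)
  e$vectors %*% diag(f(e$values), nrow(A)) %*% t(e$vectors)
}

# Hypothetical sketch of a damped fixed-point BW barycenter (not package code).
bw_barycenter_sketch <- function(S_list, weights = NULL, max_iter = 50,
                                 tol = 1e-7, eps = 1e-6, damping = 0.5) {
  n <- length(S_list)
  w <- if (is.null(weights)) rep(1 / n, n) else weights / sum(weights)
  # Assumed regularization: floor eigenvalues at eps.
  floor_eps <- function(A) sym_fun(A, function(v) pmax(v, eps))
  S_list <- lapply(S_list, floor_eps)
  # Start from the regularized weighted arithmetic mean.
  M <- floor_eps(Reduce(`+`, Map(`*`, w, S_list)))
  for (k in seq_len(max_iter)) {
    M_half <- sym_fun(M, sqrt)
    M_inv_half <- sym_fun(M, function(v) 1 / sqrt(v))
    # Weighted sum of (M^{1/2} S_i M^{1/2})^{1/2}.
    inner <- Reduce(`+`, Map(function(wi, S) {
      wi * sym_fun(M_half %*% S %*% M_half, sqrt)
    }, w, S_list))
    # Fixed-point map T(M) = M^{-1/2} inner^2 M^{-1/2}.
    T_M <- M_inv_half %*% inner %*% inner %*% M_inv_half
    # Assumed damping: convex combination of old estimate and T(M).
    M_new <- floor_eps((1 - damping) * M + damping * T_M)
    # Documented stopping rule: Frobenius norm of the change below tol.
    if (norm(M_new - M, type = "F") < tol) return(M_new)
    M <- M_new
  }
  M  # assumption: return the last iterate if max_iter is reached
}

S1 <- matrix(c(2, 1, 1, 2), 2, 2)
S2 <- matrix(c(3, 0, 0, 3), 2, 2)
bw_barycenter_sketch(list(S1, S2))
```

Because the damped step is a convex combination of two SPD matrices, each iterate stays SPD; a smaller `damping` gives steadier but slower progress toward the fixed point.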

Value

A p x p SPD matrix representing the Bures-Wasserstein barycenter. Returns NULL if computation fails or inputs are invalid.

Examples

# S1 <- matrix(c(2,1,1,2), 2,2)
# S2 <- matrix(c(3,0,0,3), 2,2)
# S_list_bw <- list(S1, S2)
# bw_mean <- bures_wasserstein_barycenter(S_list_bw, verbose = TRUE)
# print(bw_mean)

# Weighted example
# S3 <- matrix(c(1.5,0.5,0.5,1.5), 2,2)
# bw_mean_weighted <- bures_wasserstein_barycenter(list(S1, S2, S3), weights = c(0.5, 0.25, 0.25))
# print(bw_mean_weighted)
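
To sanity-check a returned barycenter, one option is to measure how far it is from satisfying the fixed-point equation M = sum_i w_i (M^{1/2} S_i M^{1/2})^{1/2}. `bw_fixed_point_residual` is a hypothetical helper, not a package function. For commuting (for example, diagonal) inputs the barycenter has the closed form (sum_i w_i S_i^{1/2})^2, which gives a known answer to test against.

```r
# Symmetric matrix square root via eigendecomposition.
sym_sqrt <- function(A) {
  A <- (A + t(A)) / 2
  e <- eigen(A, symmetric = TRUE)
  e$vectors %*% diag(sqrt(pmax(e$values, 0)), nrow(A)) %*% t(e$vectors)
}

# Hypothetical checker: Frobenius-norm residual of the barycenter
# fixed-point equation for a candidate M.
bw_fixed_point_residual <- function(M, S_list, weights = NULL) {
  n <- length(S_list)
  w <- if (is.null(weights)) rep(1 / n, n) else weights / sum(weights)
  M_half <- sym_sqrt(M)
  rhs <- Reduce(`+`, Map(function(wi, S) {
    wi * sym_sqrt(M_half %*% S %*% M_half)
  }, w, S_list))
  norm(M - rhs, type = "F")
}

# Diagonal inputs: closed form ((1 + 3) / 2)^2 = 4 and ((2 + 4) / 2)^2 = 9.
S_a <- diag(c(1, 4))
S_b <- diag(c(9, 16))
closed_form <- diag(c(4, 9))
bw_fixed_point_residual(closed_form, list(S_a, S_b))  # close to 0 (rounding only)
```

With the package loaded, calling `bw_fixed_point_residual(bw_mean, S_list_bw)` on the output of the first example is expected to give a small value if the iteration converged.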