mtw.MTW

class mtw.MTW(M, alpha=1.0, beta=0.0, epsilon=0.1, gamma=1.0, sigma0=0.0, stable=True, callback=False, maxiter_ot=1000, maxiter=4000, tol_ot=1e-05, tol_cd=0.0001, tol=1e-05, positive=False, n_jobs=1, gpu=False, **kwargs)

A class for MultiTask Regression with Wasserstein penalization.
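For orientation, the problem this class addresses can be written schematically as follows. This is a sketch, stated up to normalization constants and ignoring how signed coefficients are split into positive and negative parts before transport. Here $X^t$ and $Y^t$ are task $t$'s data as passed to fit(), $\theta^t$ its coefficients, $\bar\theta$ a shared barycenter, $\alpha$ and $\beta$ the hyperparameters documented under Parameters, and $\widetilde W_{M,\varepsilon,\gamma}$ an entropic, KL-relaxed Wasserstein distance built from M, epsilon and gamma. The notation is illustrative, not taken from the package.

```latex
\min_{\theta^1,\dots,\theta^T,\;\bar\theta}\;
\sum_{t=1}^{T}\left(
  \frac{1}{2n}\bigl\lVert Y^t - X^t\theta^t\bigr\rVert_2^2
  \;+\; \alpha\,\widetilde{W}_{M,\varepsilon,\gamma}\bigl(\theta^t,\bar\theta\bigr)
  \;+\; \beta\,\bigl\lVert\theta^t\bigr\rVert_1
\right)
```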

__init__(self, M, alpha=1.0, beta=0.0, epsilon=0.1, gamma=1.0, sigma0=0.0, stable=True, callback=False, maxiter_ot=1000, maxiter=4000, tol_ot=1e-05, tol_cd=0.0001, tol=1e-05, positive=False, n_jobs=1, gpu=False, **kwargs)

Construct an instance of MTW.

Parameters
M: array, shape (n_features, n_features)

Ground metric matrix defining the Wasserstein distance.

alpha: float >= 0.

Hyperparameter of the Wasserstein penalty.

beta: float >= 0.

Hyperparameter of the l1 penalty.

epsilon: float > 0.

OT parameter. Weight of the entropy regularization.

gamma: float > 0.

OT parameter. Weight of the Kullback-Leibler marginal relaxation.

sigma0: float >= 0.

Lower bound of the noise standard deviation. If positive, the l1 penalty is adaptively scaled to the estimated noise standard deviation (this corresponds to concomitant MTW, or the MWE algorithm). If 0, the noise standard deviations are not inferred (classic MTW).

stable: boolean, optional (default True)

If True, use log-domain Sinkhorn stabilization from the first iteration. If False, the solver automatically switches to the log domain when numerical errors are encountered.

callback: boolean, optional (default False)

If True, attach a printing callback function to the solver.

maxiter_ot: int > 0

Maximum number of Sinkhorn iterations.

maxiter_cd: int > 0

Maximum number of coordinate descent iterations.

maxiter: int > 0

Maximum number of outer loop iterations.

tol_ot: float >= 0.

Tolerance on the relative maximum change of the Wasserstein barycenter.

tol_cd: float >= 0.

Tolerance on the relative maximum change of the coefficients in coordinate descent.

tol: float >= 0.

Tolerance on the relative maximum change of the coefficients in the outer loop.

positive: boolean.

If True, the coefficients are constrained to be positive.

n_jobs: int >= 1.

Number of threads used in coordinate descent.

gpu: boolean.

If True, Sinkhorn iterations are performed on GPUs using CuPy.
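The OT-related parameters (M, epsilon, gamma, stable, maxiter_ot, tol_ot) can be illustrated with a minimal NumPy sketch of entropic optimal transport with KL-relaxed marginals between two histograms, run in the log domain. It is not MTW's implementation: MTW works with a barycenter across tasks, whereas this computes a single transport plan. `grid_ground_metric` and `unbalanced_sinkhorn_log` are hypothetical helpers; placing features on a 1D grid, rescaling M by its median, and stopping on the change of the dual potential (rather than of the barycenter, as tol_ot is documented) are all assumptions.

```python
import numpy as np


def grid_ground_metric(n_features):
    """Hypothetical helper: squared Euclidean metric for features on a 1D grid.

    Feature i is assumed to sit at position i; the matrix is rescaled so that
    its median off-diagonal entry equals 1.
    """
    positions = np.arange(n_features, dtype=float)
    M = (positions[:, None] - positions[None, :]) ** 2  # M[i, j] = (i - j)^2
    return M / np.median(M[M > 0])


def _logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    x_max = np.max(x, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return np.squeeze(out, axis=axis)


def unbalanced_sinkhorn_log(a, b, M, epsilon=0.1, gamma=1.0,
                            maxiter_ot=1000, tol_ot=1e-5):
    """Entropic OT plan between histograms a and b with KL-relaxed marginals.

    Generalized Sinkhorn run on dual potentials f, g (u = exp(f / epsilon),
    v = exp(g / epsilon)), so the kernel exp(-M / epsilon), which underflows
    for small epsilon, is never formed.
    """
    fi = gamma / (gamma + epsilon)  # exponent induced by the KL relaxation
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros(len(a))
    g = np.zeros(len(b))
    for _ in range(maxiter_ot):
        f_prev = f
        # log (K v)_i = logsumexp_j((g_j - M_ij) / epsilon)
        f = fi * epsilon * (log_a - _logsumexp((g[None, :] - M) / epsilon, axis=1))
        # log (K^T u)_j = logsumexp_i((f_i - M_ij) / epsilon)
        g = fi * epsilon * (log_b - _logsumexp((f[:, None] - M) / epsilon, axis=0))
        # |delta f| / epsilon = |delta log u|: roughly a relative change of u below tol_ot
        if np.max(np.abs(f - f_prev)) < tol_ot * epsilon:
            break
    # Transport plan P_ij = u_i K_ij v_j = exp((f_i + g_j - M_ij) / epsilon)
    return np.exp((f[:, None] + g[None, :] - M) / epsilon)
```

Letting gamma grow large sends the exponent fi to 1 and recovers standard (balanced) Sinkhorn. With epsilon = 1e-3 and the 5-feature metric above, np.exp(-M / epsilon) already contains exact zeros, which illustrates why a log-domain mode such as the one controlled by `stable` is useful.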
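The parameters sigma0, beta, positive, tol_cd and tol act on the regression side of the solver. The sketch below shows three ingredients commonly used for this kind of problem; they are assumptions about how such a solver could look, not MTW's code: a per-task noise standard deviation estimate clipped from below at sigma0 (the usual concomitant-estimator form), the soft-thresholding step of l1 coordinate descent with a variant for positive=True, and a relative maximum change used as a stopping test. All function names and the (n_features, n_tasks) layout of the coefficients are hypothetical.

```python
import numpy as np


def update_noise_std(X, Y, coefs, sigma0):
    """Hypothetical helper: per-task noise std estimate, clipped at sigma0.

    Assumed shapes: X (n_tasks, n_samples, n_features), Y (n_tasks, n_samples),
    coefs (n_features, n_tasks).
    """
    n_samples = X.shape[1]
    residuals = Y - np.einsum("tnp,pt->tn", X, coefs)  # per-task residuals
    sigmas = np.linalg.norm(residuals, axis=1) / np.sqrt(n_samples)
    return np.maximum(sigmas, sigma0)


def soft_threshold(z, threshold, positive=False):
    """Prox of threshold * |x|; with positive=True it also enforces x >= 0."""
    if positive:
        return np.maximum(z - threshold, 0.0)
    return np.sign(z) * np.maximum(np.abs(z) - threshold, 0.0)


def relative_max_change(new, old, eps=1e-12):
    """max |new - old| / max |old|, guarded against an all-zero `old`."""
    return np.max(np.abs(new - old)) / max(np.max(np.abs(old)), eps)
```

In a loop of this kind, coordinate descent would stop once relative_max_change falls below tol_cd, and the outer loop once it falls below tol.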

Methods

__init__(self, M[, alpha, beta, epsilon, …])

Construct an instance of MTW.

fit(self, X, Y[, verbose])

Launch the MTW solver.

fit(self, X, Y, verbose=True)

Launch the MTW solver.

Parameters
X: numpy array, shape (n_tasks, n_samples, n_features).

Regression data.

Y: numpy array, shape (n_tasks, n_samples).

Target data.

Returns
self: the fitted MTW instance.
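To make the expected input shapes concrete, the sketch below simulates data in the layout documented for fit(): X of shape (n_tasks, n_samples, n_features) and Y of shape (n_tasks, n_samples). `make_multitask_data`, the one-position shift of the active features between tasks, and the noise level are illustrative assumptions, not part of mtw.

```python
import numpy as np


def make_multitask_data(n_tasks=3, n_samples=20, n_features=30, n_active=3, seed=0):
    """Hypothetical helper: random multi-task regression data.

    Returns X (n_tasks, n_samples, n_features), Y (n_tasks, n_samples) and the
    true coefficients (n_features, n_tasks). Task t uses features t .. t + n_active - 1.
    """
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_tasks, n_samples, n_features))
    coefs = np.zeros((n_features, n_tasks))
    for t in range(n_tasks):
        coefs[t:t + n_active, t] = 1.0  # supports shifted by one feature per task
    noise = 0.1 * rng.standard_normal((n_tasks, n_samples))
    Y = np.einsum("tnp,pt->tn", X, coefs) + noise  # Y[t] = X[t] @ coefs[:, t] + noise
    return X, Y, coefs


X, Y, true_coefs = make_multitask_data()
```

Together with a ground metric M of shape (n_features, n_features), such arrays can be passed as `MTW(M, alpha=..., beta=...).fit(X, Y)`; since fit returns the instance itself, construction and fitting can be chained.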