Source code for mtw.mtw

import numpy as np
from .solver import solver
from .examples_utils import inspector_mtw
try:
    import cupy as cp
    get_module = cp.get_array_module
except ImportError:
    def get_module(x):
        return np


class NotFittedError(AttributeError):
    """Raised if an estimator is used before fitting."""


[docs]class MTW: """A class for MultiTask Regression with Wasserstein penalization. Attributes ---------- """
[docs] def __init__(self, M, alpha=1., beta=0., epsilon=0.1, gamma=1., sigma0=0., stable=True, callback=False, maxiter_ot=1000, maxiter=4000, tol_ot=1e-5, tol_cd=1e-4, tol=1e-5, positive=False, n_jobs=1, gpu=False, **kwargs): """Constructs instance of MTW. Parameters ---------- M: array, shape (n_features, n_features) Ground metric matrix defining the Wasserstein distance. alpha: float >= 0. hyperparameter of the Wasserstein penalty. beta : float >= 0. hyperparameter of the l1 penalty. epsilon: float > 0. OT parameter. Weight of the entropy regularization. gamma: float > 0. OT parameter. Weight of the Kullback-Leibler marginal relaxation. sigma0: float >=0. Lower bound of the noise standard deviation. If positive, the l1 penalty is adaptively scaled to the noise std estimation (corresponds to concomitant MTW, or MWE algorithm). If 0, noise std are not inferred (classic MTW). stable: boolean. optional (default False) if True, use log-domain Sinhorn stabilization from the first iter. if False, the solver will automatically switch to log-domain if numerical errors are encountered. callback: boolean. optional. if True, set a printing callback function to the solver. maxiter_ot: int > 0 maximum Sinkhorn iterations maxiter_cd: int > 0 maximum coordinate descent iterations maxiter: int > 0 maximum outer loop iterations tol_ot: float >=0. relative maximum change of the Wasserstein barycenter. tol_cd: float >=0. relative maximum change of the coefficients in coordinate descent. tol: float >=0. relative maximum change of the coeffcients in the outer loop. positive: boolean. if True, coefficients must be positive. n_jobs: int > 1. number of threads used in coordinate descents gpu: boolean. if True, Sinkhorn iterations are performed on gpus using cupy. """ self.callback = callback self.callback_kwargs = kwargs self.n_jobs = n_jobs self.wyy0 = 0. self.M = M self.xp = get_module(M) self.alpha = alpha self.epsilon = epsilon self.gamma = gamma self.beta = beta self.stable = stable self.maxiter_ot = maxiter_ot self.tol_ot = tol_ot self.tol = tol self.maxiter = maxiter self.positive = positive self.n_jobs = n_jobs self.tol_cd = tol_cd self.sigma0 = sigma0 self.gpu = gpu self._set_callback()
def _set_callback(self): """Set callback if `callback` is True.""" self.callback_f = None if self.callback: self.callback_f = inspector_mtw(**self.callback_kwargs)
[docs] def fit(self, X, Y, verbose=True): """Launch MTW solver. Parameters ---------- X: numpy array (n_tasks, n_samples, n_features). Regression data. Y: numpy arrays (n_tasks, n_samples,). Target data. Returns ------- instance of self. """ self.t_ot = 0. self.t_cd = 0. coefs1, coefs2, bar1, bar2, log, sigmas = \ solver(X, Y, M=self.M, alpha=self.alpha, beta=self.beta, epsilon=self.epsilon, gamma=self.gamma, sigma0=self.sigma0, stable=self.stable, tol=self.tol, callback=self.callback_f, maxiter=self.maxiter, tol_ot=self.tol_ot, maxiter_ot=self.maxiter_ot, positive=self.positive, n_jobs=self.n_jobs, tol_cd=self.tol_cd, gpu=self.gpu, verbose=verbose) self.coefs1_ = coefs1.copy() self.coefs2_ = coefs2.copy() self.coefs_ = coefs1 - coefs2 self.sigmas_ = sigmas self.barycenter1_ = bar1 self.barycenter2_ = bar2 self.barycenter_ = bar1 - bar2 self.log_ = log self.t_ot += log["t_ot"] self.t_cd += log["t_cd"] return self