Wasserstein Task Embedding for Measuring Task Similarities

Xinran Liu1, Yikun Bai1, Yuzhe Lu1, Andrea Soltoggio2, Soheil Kolouri1
Abstract

Measuring similarities between different tasks is critical in a broad spectrum of machine learning problems, including transfer, multi-task, continual, and meta-learning. Most current approaches to measuring task similarities are architecture-dependent: 1) relying on pre-trained models, or 2) training networks on tasks and using forward transfer as a proxy for task similarity. In this paper, we leverage the optimal transport theory and define a novel task embedding for supervised classification that is model-agnostic, training-free, and capable of handling (partially) disjoint label sets. In short, given a dataset with ground-truth labels, we perform a label embedding through multi-dimensional scaling and concatenate dataset samples with their corresponding label embeddings. Then, we define the distance between two datasets as the 2-Wasserstein distance between their updated samples. Lastly, we leverage the 2-Wasserstein embedding framework to embed tasks into a vector space in which the Euclidean distance between the embedded points approximates the proposed 2-Wasserstein distance between tasks. We show that the proposed embedding leads to a significantly faster comparison of tasks compared to related approaches like the Optimal Transport Dataset Distance (OTDD). Furthermore, we demonstrate the effectiveness of our proposed embedding through various numerical experiments and show statistically significant correlations between our proposed distance and the forward and backward transfer between tasks on a wide variety of image recognition datasets.

Affiliations

1Department of Computer Science, Vanderbilt University
2Department of Computer Science, Loughborough University

Introduction

Learning from a broad spectrum of tasks and transferring knowledge between them is a cornerstone of intelligence, and primates perfectly exemplify this characteristic. Modern Machine Learning (ML) is rapidly moving toward multi-task learning, and there is great interest in methods that can integrate, rapidly adapt, and seamlessly transfer knowledge between tasks. When learning from multiple possibly heterogeneous tasks, it is essential to understand the relationships between the tasks and their fundamental properties. It is, therefore, highly desirable to define (dis)similarity measures between tasks that will allow one to cluster tasks, have better control over the forward and backward transfer, and ultimately require less supervision for learning tasks.

There has been an increasing interest in assessing task similarities and their relationship with forward and backward knowledge transfer among tasks. For instance, various recent works look into the selection of good source tasks/models for a given target task to maximize the forward transfer to the target task achille2019task2vec; zamir2018taskonomy; bao2019information; bhattacharjee2020p2l; fifty2021efficiently. Others have demonstrated the relationship between negative backward transfer (i.e., catastrophic forgetting) and task similarities nguyen2019toward.

Many existing methods for measuring task similarities depend on the choice of model(s), architecture(s), and the training process leite2005predicting; zamir2018taskonomy; achille2019task2vec; khodak2019adaptive; nguyen2020leep; venkitaraman2020task; gao2021information. For example, zamir2018taskonomy; venkitaraman2020task; gao2021information use pre-trained task-specific models to measure a notion of forward transfer and define it as task similarity. achille2019task2vec embed tasks into a vector space, relying on a partially trained network. khodak2019adaptive use the optimal parameters as a proxy for each task, and leite2005predicting use the learning curves of a pre-specified model to measure task similarities. Besides being model-dependent, these approaches are often computationally expensive, as they involve training deep models (or require pre-trained models).

Model-agnostic task similarity measures provide a fundamentally different approach to quantifying task relationships ben2006analysis; alvarez2020geometric; tran2019transferability; tan2021otce. These methods often measure the similarity between tasks as a function of the similarity between the joint or conditional input/output distributions, sometimes also taking the loss function into account. The classic theoretical results for such similarity measures ben2006analysis; batu2000testing focus on information-theoretic divergences between the source and target distributions. More recently, Optimal Transport (OT) based approaches alvarez2020geometric; tan2021otce; xu2022selecting have shown promise in modeling task similarities. Notably, alvarez2020geometric measure task similarities through the lens of hierarchical OT yurochkin2019hierarchical: they solve an inner OT problem to calculate the label distance between the class-conditional distributions of two supervised learning tasks. The label distance is then incorporated into the transportation cost of an outer OT problem, resulting in a distance between two datasets that integrates both sample and label discrepancies. tan2021otce treat the optimal transport plan between the input distributions of two tasks as a joint probability distribution and use conditional entropy to measure the difference between the two tasks. One major shortcoming of these OT-based approaches is their computational complexity. These methods require the pairwise calculation of OT (or entropy-regularized OT) between different tasks, which can be prohibitively expensive in applications requiring frequent evaluations of task similarities, e.g., in continual learning.

We propose a novel OT-based task embedding for supervised learning problems that is model-agnostic and computationally efficient. On the one hand, our proposed approach is similar to (achille2019task2vec) and (peng2020domain2vec), which embed datasets into a vector space in which one can easily measure the difference between tasks, e.g., via the Euclidean distance between embedded vectors. On the other hand, our approach is inspired by the Optimal Transport Dataset Distance (OTDD) alvarez2020geometric framework, and it essentially provides a Euclidean embedding for a hierarchical OT-based distance between tasks. To calculate such a task embedding, we use the Wasserstein embedding framework wang2013linear; kolouri2020wasserstein. Importantly, our approach alleviates the need for pairwise calculation of OT problems between tasks, turning it into a more desirable solution than previously proposed methods.

Contributions. We propose a computationally efficient and model-agnostic task embedding, denoted as Wasserstein Task Embedding (WTE), in which the Euclidean distance between embedded vectors approximates a hierarchical OT distance between the tasks. We provide extensive numerical experiments and demonstrate that: 1) our calculated distances between embedded tasks are highly correlated with the OTDD distance alvarez2020geometric, 2) our proposed embedding and similarity calculation is significantly faster than the OTDD distance, and 3) our proposed similarity measure provides strong and statistically significant correlation with both forward and backward transfer.

Related work

Model-based task similarity. Most existing approaches to measuring task similarity are model-dependent and use forward transferability as a proxy for similarity. For example, zamir2018taskonomy use pre-trained models on source tasks and measure their performance on a target task to obtain an asymmetric notion of similarity between source and target tasks. Following zamir2018taskonomy’s work, dwivedi2019representation measure the transferability in a more efficient manner by applying the Representation Similarity Analysis (RSA) between the trained models (e.g. DNNs) from different tasks. Similarly, nguyen2020leep assume the source and target tasks share the same set of inputs but have different sets of labels, and estimate the transferability by the empirical conditional distribution of target labels given the inputs computed by a pre-trained model on the source task.

Another class of approaches embeds the tasks into a vector space and then defines the (dis)similarity on the embedded vector representations. achille2019task2vec process data (images) through a partially trained “probe network” and obtain a vector embedding by computing the Fisher information matrix (FIM). The (dis)similarity of two tasks is then computed from the difference between the FIMs. Similarly, peng2020domain2vec propose a domain (labeled dataset) to vector technique. In particular, given a domain, they feed the data to a pre-trained CNN to compute the Gram matrices of the activations of the hidden convolutional layers, and apply feature disentanglement to extract the domain-specific features. Concatenating the diagonal entries of the Gram matrices with the domain-specific features gives the final domain embedding. These methods, however, rely heavily on pre-trained models and the training process, and lack theoretical guarantees. At the opposite end of the spectrum are methods that directly measure the discrepancy between domains.

Discrepancy measures of domains. Over the years, numerous notions of discrepancy to measure the (dis)similarity of datasets (domains) have been proposed, including the $L_1$-norm (batu2000testing), the generalized Kolmogorov-Smirnov distance (devroye1996parametric; kifer2004detecting), and the loss-oriented discrepancy distance (mansour2009domain). In the context of domain adaptation, the generalized Kolmogorov-Smirnov distance (later known as the $\mathcal{A}$-distance) is a principled notion of discrepancy, which is a relaxation of total variation. ben2006analysis show that the target performance (generalization error) is controlled by the empirical estimate of the source domain error and the $\mathcal{A}$-distance between source and target domains. Another widely used distance is the Maximum Mean Discrepancy (MMD) gretton2006kernel, which captures the (dis)similarity of the embeddings of distributions in a reproducing kernel Hilbert space. pan2010domain propose to learn transfer components across domains in a reproducing kernel Hilbert space using MMD, and show that the subspace spanned by these transfer components preserves data properties. Such domain discrepancy methods, however, cannot take labels into account, and thus may not be enough to reflect the similarity of tasks.

Optimal transport based task similarity. In recent years, metrics rooted in the optimal transport problem, e.g., the “Wasserstein distance” villani2009optimal; villani2021topics (or the “earth mover’s distance” rubner2000earth; solomon2014earth), have attracted growing interest in the machine learning community. Wasserstein distance is a rigorous metric of probability measures endowed with desired statistical convergence behavior, in contrast to other classical discrepancies (e.g. KL-divergence, total variation, JS-divergence, Hellinger distance, Maximum mean discrepancy, etc). OT based metrics are widely used in generative modeling (arjovsky2017wasserstein; liu2019wasserstein), domain adaptation (courty2017joint; courty2014domain; alvarez2018gromov), graph embedding kolouri2020wasserstein; xu2019gromov, and neural architecture search (kandasamy2018neural).

alvarez2020geometric propose a notion of distance between two datasets in a supervised learning setting. They introduce the Optimal Transport Dataset Distance (OTDD) based on OT theory, which can be thought of as a hierarchical OT distance where the transportation cost measures the distance between samples as well as labels. Under the assumption that the label-induced distributions can be approximated by Gaussians, the distance between labels is defined as the Bures-Wasserstein distance.

tan2021otce introduce another OT-based method to measure transferability, named the OTCE (Optimal Transport Conditional Entropy) score. In particular, they first use entropic optimal transport to estimate domain differences and then use the optimal coupling between the source and target distributions to compute the conditional entropy of the target task given the source task. The OTCE score is defined as a linear combination of the OT distance and the conditional entropy. Both OTDD and OTCE were shown to be effectively aligned with forward transfer. However, the computation of the pairwise Wasserstein distances among an increasing number of datasets remains expensive. This hinders the application of these methods to problems where one needs to perform nearest dataset retrieval frequently (e.g., memory replay approaches in continual learning).

Computation Cost of OT Distance. Calculating the Wasserstein distance involves solving a linear program, and the computational cost is $\mathcal{O}(n^3 \log n)$ for a pair of empirical distributions of size $n$ (pele2009fast). To facilitate the computation, one common method is adding entropic regularization (cuturi2013sinkhorn; peyre2017computational), by which the original linear programming problem is converted into a strictly convex problem. By applying the Sinkhorn-Knopp algorithm (peyre2017computational; chizat2018scaling) to find an $\varepsilon$-accurate solution, the computational complexity reduces to $\tilde{\mathcal{O}}(n^2)$, with a polynomial dependence on $1/\varepsilon$ (altschuler2017near). However, this technique suffers from a stability-accuracy trade-off. When the regularization coefficient is large, the objective is biased toward the entropy term; when it is small, the Sinkhorn algorithm is no longer numerically stable.
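To make the entropic regularization concrete, the following is a minimal NumPy sketch of Sinkhorn-Knopp scaling iterations. The function name `sinkhorn`, the fixed iteration count, the cost normalization, and the toy point clouds are illustrative assumptions; they are not taken from the paper, whose experiments rely on an exact solver.

```python
import numpy as np

def sinkhorn(a, b, C, reg, n_iters=500):
    """Entropy-regularized OT between histograms a, b for a cost matrix C."""
    K = np.exp(-C / reg)                 # Gibbs kernel; entries underflow to 0 if reg is tiny
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)                # rescale so that columns match b
        u = a / (K @ v)                  # rescale so that rows match a
    P = u[:, None] * K * v[None, :]      # regularized transport plan
    return P, float(np.sum(P * C))       # plan and its transport cost

# Toy example (synthetic points, illustrative only).
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
y = rng.normal(size=(5, 2)) + 1.0
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
C = C / C.max()                          # normalize costs to [0, 1]
a = np.full(5, 1 / 5)
b = np.full(5, 1 / 5)
P, cost = sinkhorn(a, b, C, reg=0.5)
```

A large `reg` makes the plan drift toward the independent coupling, while a very small `reg` drives the entries of `K` toward zero in floating point, which is the stability issue described above.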

Figure 1: Wasserstein Task Embedding framework. Given tasks $T_1$ and $T_2$ with input space $\mathcal{X} \subseteq \mathbb{R}^d$, WTE first maps them into $\mathbb{R}^{d+n}$ as probability distributions $\mu_1$ and $\mu_2$ by MDS, then applies WE to obtain vectors $\phi(\mu_1), \phi(\mu_2) \in \mathbb{R}^{N(d+n)}$ with respect to a fixed reference measure $\mu_0$. Here $N$ is the size of the reference set.

Preliminaries

Multidimensional scaling (MDS)

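Multidimensional scaling (MDS) cox2008multidimensional is a non-linear dimensionality reduction approach that embeds samples into an $n$-dimensional Euclidean space while preserving their pairwise distances. Given a set of high-dimensional data $\{x_i\}_{i=1}^{N} \subset \mathcal{X}$ and the proximity matrix $D = [d_{ij}] \in \mathbb{R}^{N \times N}$, where $d_{ij} = d_{\mathcal{X}}(x_i, x_j)$ and $d_{\mathcal{X}}$ denotes the metric in $\mathcal{X}$, the goal of MDS is to construct a distance-preserving map from $\mathcal{X}$ to a lower-dimensional Euclidean space $\mathbb{R}^n$. Depending on the objective and inputs, MDS can be classified into metric MDS and non-metric MDS. Specifically, metric MDS aims to find a map $\psi: \mathcal{X} \to \mathbb{R}^n$ such that

$$\psi \in \arg\min_{\psi'} \sum_{i<j} \big( d_{ij} - \|\psi'(x_i) - \psi'(x_j)\|_2 \big)^2, \tag{1}$$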

which can be solved by Algo. 1.

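procedure MDS($D$, $n$)
     $B = -\frac{1}{2} J D^{(2)} J$, with $J = I - \frac{1}{N}\mathbf{1}\mathbf{1}^{\top}$ and $D^{(2)} = [d_{ij}^2]$
     Eigen-decomposition $B = V \Lambda V^{\top}$
     Rearrange $\Lambda$ into $\Lambda'$ with descending order of variances
     Rearrange $V$ into $V'$ in correspondence with $\Lambda'$
     $Y = V'_{:,1:n} \, (\Lambda'_{1:n,1:n})^{1/2}$; return $Y$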
end procedure
Algorithm 1 Multidimensional Scaling
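As an illustration of the eigen-decomposition route in Algo. 1, here is a minimal NumPy sketch of classical (Torgerson) MDS. The function name `classical_mds`, the double-centering step, and the clipping of negative eigenvalues are assumptions of this sketch; the experiments in the paper use the MDS toolkit in scikit-learn instead.

```python
import numpy as np

def classical_mds(D, n_components):
    """Classical (Torgerson) MDS from a symmetric matrix of pairwise distances D."""
    N = D.shape[0]
    J = np.eye(N) - np.ones((N, N)) / N           # centering matrix
    B = -0.5 * J @ (D ** 2) @ J                   # double-centered Gram matrix
    eigvals, eigvecs = np.linalg.eigh(B)          # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:n_components]
    lam = np.clip(eigvals[order], 0.0, None)      # guard against negative eigenvalues
    return eigvecs[:, order] * np.sqrt(lam)       # rows are the embedded points

# Sanity check on synthetic Euclidean data: distances are recovered exactly.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 3))
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
Y = classical_mds(D, n_components=3)
D_hat = np.linalg.norm(Y[:, None, :] - Y[None, :, :], axis=-1)
```

When `D` is not Euclidean (for example, a matrix of Wasserstein distances), some eigenvalues can be negative and the embedding is only approximate, which is why the sketch clips them at zero.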

Note that MDS not only works for Euclidean distances, but also for other dissimilarities such as Wasserstein distances 5609205; hamm2022wassmap.

Wasserstein Distances

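Let $\mu_i, \mu_j$ be Borel probability measures on $\mathbb{R}^d$ with finite second moment, and let the corresponding probability density functions be $p_i$ and $p_j$, i.e., $d\mu_i(x) = p_i(x)\,dx$, $d\mu_j(y) = p_j(y)\,dy$. The 2-Wasserstein distance between $\mu_i$ and $\mu_j$ is defined as villani2009optimal:

$$\mathcal{W}_2(\mu_i, \mu_j) = \left( \inf_{\gamma \in \Gamma(\mu_i, \mu_j)} \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x - y\|^2 \, d\gamma(x, y) \right)^{1/2}, \tag{2}$$

where $\Gamma(\mu_i, \mu_j)$ is the set of all transport plans between $\mu_i$ and $\mu_j$, i.e., probability measures on $\mathbb{R}^d \times \mathbb{R}^d$ with marginals $\mu_i$ and $\mu_j$. We also note that by Brenier's theorem Brenier1991PolarFA, given two absolutely continuous probability measures $\mu_i, \mu_j$ on $\mathbb{R}^d$ with densities $p_i, p_j$, there exists a convex function $\varphi$ such that $f = \nabla \varphi$ is a transport map sending $\mu_i$ to $\mu_j$. Moreover, it is the optimal map in the Monge-Kantorovich optimal transport problem with quadratic cost:

$$f = \arg\min_{g \in MP(\mu_i, \mu_j)} \int_{\mathbb{R}^d} \|x - g(x)\|^2 \, d\mu_i(x), \tag{3}$$

where $MP(\mu_i, \mu_j)$ is the set of maps $g$ such that $g$ pushes $\mu_i$ to $\mu_j$, denoted by $g_{\#}\mu_i = \mu_j$.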
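For intuition, the 2-Wasserstein distance can be computed exactly in a simple discrete case. The sketch below assumes two uniform empirical measures with the same number of points, in which case an optimal transport plan is a one-to-one matching and SciPy's `linear_sum_assignment` suffices. The function name `w2_uniform` and the synthetic data are illustrative assumptions.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def w2_uniform(X, Y):
    """2-Wasserstein distance between uniform empirical measures on X and Y.

    Assumes X and Y contain the same number of points, so that an optimal
    plan is a permutation and can be found as an assignment problem.
    """
    assert X.shape == Y.shape
    C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)   # squared Euclidean costs
    rows, cols = linear_sum_assignment(C)                 # optimal one-to-one matching
    return float(np.sqrt(C[rows, cols].mean())), cols

# A translation is the gradient of a convex function, hence an optimal map:
# the distance between a point cloud and its translate is the shift length.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
shift = np.array([3.0, -1.0])
dist, match = w2_uniform(X, X + shift)
```

Here `dist` equals the Euclidean norm of `shift`, and `match` pairs each point with its own translate, a discrete instance of the Brenier map described above.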

Wasserstein Embedding (WE)

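Wasserstein Embedding (wang2013linear; kolouri2016continuous; courty2017learning; kolouri2021wasserstein) provides a Hilbertian embedding for probability distributions such that the Euclidean distance between the embedded vectors approximates the 2-Wasserstein distance between the two distributions. Let $\{\mu_i\}_{i=1}^{M}$ be a set of probability distributions over $\mathbb{R}^d$ with densities $\{p_i\}_{i=1}^{M}$. We fix $\mu_0$, with density $p_0$, as the reference measure. Assuming $f_i$ is the optimal transport map that pushes $\mu_0$ to $\mu_i$, the Wasserstein embedding of $\mu_i$ is obtained through a function $\phi$ defined as

$$\phi(\mu_i) := (f_i - \mathrm{id}) \sqrt{p_0}, \tag{4}$$

where $\mathrm{id}$ is the identity function, i.e., $\mathrm{id}(x) = x$. $\phi$ admits nice properties, including but not limited to kolouri2021wasserstein:

  1. $\|\phi(\mu_i) - \phi(\mu_j)\|_2$ is a true metric between $\mu_i$ and $\mu_j$; moreover, it approximates the 2-Wasserstein distance: $\|\phi(\mu_i) - \phi(\mu_j)\|_2 \approx \mathcal{W}_2(\mu_i, \mu_j)$.

  2. In particular, $\|\phi(\mu_i) - \phi(\mu_0)\|_2 = \|\phi(\mu_i)\|_2 = \mathcal{W}_2(\mu_i, \mu_0)$. Here we leveraged the fact that $\phi(\mu_0) = 0$.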

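Although these hold true for both continuous and discrete measures, we focus on the (uniformly distributed) discrete setting in this paper and provide the following numerical computation details. Let $\mu_i = \frac{1}{N_i} \sum_{k=1}^{N_i} \delta_{x_k^i}$, where $\delta_x$ is the Dirac delta function centered at $x$ and $Z_i = [x_1^i, \dots, x_{N_i}^i]^{\top} \in \mathbb{R}^{N_i \times d}$ is the set of locations of non-negative mass for $\mu_i$; the reference $\mu_0$ is written in the same way with $N$ points $Z_0 = [x_1^0, \dots, x_N^0]^{\top}$. Then the Kantorovich problem with quadratic cost between $\mu_i$ and $\mu_0$ can be formulated as

$$\pi_i^{*} = \arg\min_{\pi \in \Pi_i} \sum_{j=1}^{N} \sum_{k=1}^{N_i} \pi_{jk} \, \|x_j^0 - x_k^i\|^2, \tag{5}$$

where the feasible set is

$$\Pi_i = \Big\{ \pi \in \mathbb{R}^{N \times N_i} \;:\; N_i \sum_{j=1}^{N} \pi_{jk} = N \sum_{k=1}^{N_i} \pi_{jk} = 1, \ \pi_{jk} \ge 0 \ \ \forall j, k \Big\}. \tag{6}$$

The optimal transport plan $\pi_i^{*}$ is the minimizer of the above optimization problem, which is solved by a linear program at cost $\mathcal{O}(N^3 \log N)$, $N$ being the number of input samples. To avoid mass splitting, the barycentric projection wang2013linear assigns each $x_j^0$ in the reference distribution to the center of mass it is sent to and thus outputs an approximated Monge map $F_i = N (\pi_i^{*} Z_i) \in \mathbb{R}^{N \times d}$. Then the Wasserstein embedding for input $Z_i$ is calculated by

$$\phi(Z_i) = (F_i - Z_0) / \sqrt{N} \in \mathbb{R}^{N \times d}. \tag{7}$$

One of the motivations behind Wasserstein embedding is to reduce the need for computing pairwise Wasserstein distances. Given $M$ datasets, computing Wasserstein distances across all $M(M-1)/2$ distinct pairs is impractically expensive, especially when $M$ is large. By leveraging Wasserstein embedding, it suffices to calculate only $M$ Wasserstein distances and the pairwise Euclidean distances between the embedded distributions.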
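The discrete computation described above can be sketched as follows: solve the Kantorovich linear program against a reference point set, apply the barycentric projection, and flatten the displacement into a vector. This sketch uses SciPy's general-purpose `linprog` rather than the POT solver used in the paper, and the names `ot_plan` and `wasserstein_embedding` are hypothetical. It is meant for small point sets only, since the constraint matrix is built densely.

```python
import numpy as np
from scipy.optimize import linprog

def ot_plan(X0, Xi):
    """Exact OT plan between uniform measures on X0 (N points) and Xi (Ni points)."""
    N, Ni = len(X0), len(Xi)
    C = ((X0[:, None, :] - Xi[None, :, :]) ** 2).sum(-1)   # quadratic cost
    # Equality constraints: each row of the plan sums to 1/N, each column to 1/Ni.
    A_rows = np.kron(np.eye(N), np.ones((1, Ni)))
    A_cols = np.kron(np.ones((1, N)), np.eye(Ni))
    A_eq = np.vstack([A_rows, A_cols])
    b_eq = np.concatenate([np.full(N, 1.0 / N), np.full(Ni, 1.0 / Ni)])
    res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    return res.x.reshape(N, Ni)

def wasserstein_embedding(X0, Xi):
    """Embed the point set Xi with respect to the reference point set X0."""
    N = len(X0)
    plan = ot_plan(X0, Xi)
    F = N * plan @ Xi                        # barycentric projection (approximate Monge map)
    return ((F - X0) / np.sqrt(N)).ravel()   # flattened displacement field

# Synthetic example (illustrative data only).
rng = np.random.default_rng(0)
X0 = rng.normal(size=(6, 2))                        # reference point set
shift = np.array([2.0, 1.0])
phi_ref = wasserstein_embedding(X0, X0)             # embedding of the reference itself
phi_shift = wasserstein_embedding(X0, X0 + shift)   # translated copy of the reference
phi_other = wasserstein_embedding(X0, rng.normal(size=(9, 2)) + 3.0)
```

In this example `phi_ref` is the zero vector and the norm of `phi_shift` equals the length of `shift`, matching the two properties listed for the embedding. Point sets of different sizes, such as the 9-point set above, are all mapped to vectors of the same length, determined by the reference.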

Method

In this section, we specify the problem setting, review the OTDD framework, and then propose our Wasserstein task embedding (WTE).

Problem Setting

In supervised classification problems, tasks are represented by input-label pairs and can be denoted as $T = (X, Y)$, where $X$ is the data/inputs and $Y$ is the labels. We aim to define a similarity/dissimilarity measure for tasks that enables task clustering and allows for better control over the forward and backward transfer.

Figure 2: Label-to-label Bures-Wasserstein distance (left) and label MDS embedding Euclidean distances (middle) between MNIST and USPS datasets, squared error is provided on the right.

Optimal Transport Dataset Distance (OTDD)

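Let $\mathcal{X}$ be the input set with labels (classes) $\mathcal{Y}$. For each $y \in \mathcal{Y}$, let $\mathcal{X}_y \subseteq \mathcal{X}$ denote the class with label $y$. Following the OTDD framework, let $\mathcal{Z} = \mathcal{X} \times \mathcal{Y}$ denote the set of data-label pairs. OTDD encodes the label $y$ as a distribution $\alpha_y$, where $\alpha_y(X) = P(X \mid Y = y)$. The ground distance in $\mathcal{Z}$ is then defined by combining the Euclidean distance between the data points and the 2-Wasserstein distance between label distributions:

$$d_{\mathcal{Z}}\big((x, y), (x', y')\big) = \Big( \|x - x'\|^2 + \mathcal{W}_2^2(\alpha_y, \alpha_{y'}) \Big)^{1/2}. \tag{8}$$

Based on this metric, the OT distance between two distributions $\mu$ and $\nu$ on $\mathcal{Z}$ is

$$d_{OT}(\mu, \nu) = \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{\mathcal{Z} \times \mathcal{Z}} d_{\mathcal{Z}}(z, z')^2 \, d\gamma(z, z'), \tag{9}$$

where $\Gamma(\mu, \nu)$ denotes the set of transport plans between $\mu$ and $\nu$. Note that Eq. 9 is a hierarchical transport problem, as the transportation cost itself depends on the calculation of the Wasserstein distance. To avoid the computational cost of a hierarchical optimal transport problem, alvarez2020geometric replace the Wasserstein distance in Eq. 8 with the Bures-Wasserstein distance (malago2018wasserstein; bhatia2019bures), which assumes that the $\alpha_y$s are Gaussian distributions. Throughout the paper, we consider only the exact-OTDD, as opposed to the entropy-regularized and other variants.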

Wasserstein Task Embedding

We define a task-2-vec framework (Fig. 1) using Wasserstein embedding (WE) such that the (squared) Euclidean distance between two vectors approximates the OTDD between the original tasks, and denote this embedding by WTE. We later show in the experiment section that the Euclidean distance between the embedded task vectors is not only highly predictive of forward transferability, but also significantly correlates with the backward transferability (catastrophic forgetting).

Label Embedding via MDS. The combination of the optimal transport metric with the MDS technique was first introduced as an approach to characterize and contrast the distribution of nuclear structure in different tissue classes (normal, benign, cancer, etc.) 5609205, and further studied in image manifold learning hamm2022wassmap. In short, it seeks to isometrically map probability distributions to vectors in a relatively low-dimensional space. We leverage the prior work and define an approximate isometry on the label distributions by 1) calculating the pairwise Wasserstein distances and 2) applying MDS to obtain embedded vectors. We adopt the same simplification as in OTDD, that is, assuming the label distributions $\alpha_y$ are Gaussians in order to replace Wasserstein distances with the closed-form Bures-Wasserstein distance:

$$\mathcal{W}_2^2(\alpha_y, \alpha_{y'}) = \|m_y - m_{y'}\|^2 + \mathrm{tr}\Big( \Sigma_y + \Sigma_{y'} - 2 \big( \Sigma_y^{1/2} \Sigma_{y'} \Sigma_y^{1/2} \big)^{1/2} \Big), \tag{10}$$

where $m_y$ and $\Sigma_y$ denote the mean and covariance matrix of the Gaussian distribution $\alpha_y$. Consistent with the previous notation, let us denote the label (MDS) embedding operator by $\psi$; then

$$\|\psi(y) - \psi(y')\|_2 \approx \mathcal{W}_2(\alpha_y, \alpha_{y'}), \tag{11}$$

where $\psi(y), \psi(y') \in \mathbb{R}^n$ are vectors whose dimension $n$ is selected to balance the trade-off between accuracy and computation cost tenenbaum2000global. Having both inputs $x \in \mathbb{R}^d$ and labels represented as vectors, we concatenate these two components and map the data-label pairs to $\mathbb{R}^{d+n}$ such that

$$d_{\mathcal{Z}}\big((x, y), (x', y')\big) \approx \big\| [x; \psi(y)] - [x'; \psi(y')] \big\|_2, \tag{12}$$

where $[\cdot \, ; \cdot]$ denotes the concatenation operator and the domain $\mathbb{R}^{d+n}$ is equipped with the $\ell_2$ norm. Fig. 2 shows this approximation performance among labels in MNIST LeCun2005TheMD and USPS 291440 datasets. MDS embeddings can capture the pairwise relationships with a maximum of error by -dimensional vectors.
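The closed-form Bures-Wasserstein distance between two Gaussians (Eq. 10), which serves as the label-to-label distance, can be sketched with SciPy's matrix square root. The function name `bures_wasserstein_sq` is hypothetical, and the sketch assumes well-conditioned covariance matrices; taking real parts only discards round-off from `sqrtm`.

```python
import numpy as np
from scipy.linalg import sqrtm

def bures_wasserstein_sq(m1, S1, m2, S2):
    """Squared 2-Wasserstein distance between Gaussians N(m1, S1) and N(m2, S2)."""
    root1 = np.real(sqrtm(S1))                     # S1^{1/2}
    cross = np.real(sqrtm(root1 @ S2 @ root1))     # (S1^{1/2} S2 S1^{1/2})^{1/2}
    bures = np.trace(S1 + S2 - 2.0 * cross)        # covariance (Bures) term
    return float(np.sum((m1 - m2) ** 2) + max(bures, 0.0))

# Hand-checkable cases.
m = np.zeros(2)
S = np.array([[2.0, 0.5], [0.5, 1.0]])
same = bures_wasserstein_sq(m, S, m, S)                         # identical Gaussians
iso = bures_wasserstein_sq(m, np.eye(2), np.array([3.0, 4.0]), np.eye(2))
diag = bures_wasserstein_sq(m, np.diag([1.0, 4.0]), m, np.diag([9.0, 16.0]))
```

For identical Gaussians the distance vanishes; for equal covariances it reduces to the squared distance between the means (25 here); and for commuting (diagonal) covariances the trace term is the squared difference of the standard deviations, $(1-3)^2 + (2-4)^2 = 8$.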

Figure 3: (Top row) Forward transfer error drop and (bottom row) catastrophic forgetting against WTE distance on *NIST (left), Split-CIFAR100 (middle) and Split-Tiny ImageNet (right) over five runs. Pearson's correlation coefficient and the corresponding $p$-value are reported on top of each experiment setting.

Wasserstein Embedding. By Eq. 12 and Eq. 9, OTDD can be approximated by the squared 2-Wasserstein distance between the distributions over input-(label MDS embedding) pairs. We then leverage the Wasserstein embedding framework to embed the updated task distributions into a Hilbert space, with the goal of reducing the cost of computing pairwise Wasserstein distances. Again, we emphasize that this brings the cost down from quadratic to linear in the number of task distributions.

The WTE algorithm is summarized in Algo. 2. The outputs are the vector representations of input tasks with respect to a pre-determined MDS dimension and WE reference distribution.

procedure WTE($\{T_i = (X_i, Y_i)\}_{i=1}^{M}$, $n$, $Z_0$)
     Calculate label-to-label distance matrix W (Eq. 10)
     Calculate $\psi(y)$ for all distinct labels $y$ (Algo. 1)
     Stack each input with its label vector: $[x; \psi(y)]$
     Calculate the WE $\phi$ (Eq. 7); return $\phi$
end procedure
Algorithm 2 Wasserstein Task Embedding
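The steps of Algo. 2 can be sketched end to end on toy data. This is a simplified illustration under stated assumptions, not the authors' implementation: all tasks and the reference contain the same number of points (so the optimal plan is a matching and no barycentric projection is needed), labels of all tasks are embedded jointly, classical MDS replaces the scikit-learn toolkit, SciPy's `linear_sum_assignment` replaces the POT solver, and a small ridge term keeps covariances well-conditioned. All function names and the synthetic tasks are hypothetical.

```python
import numpy as np
from scipy.linalg import sqrtm
from scipy.optimize import linear_sum_assignment

def bw_dist(m1, S1, m2, S2):
    """Bures-Wasserstein distance between Gaussians N(m1, S1) and N(m2, S2)."""
    r = np.real(sqrtm(S1))
    cross = np.real(sqrtm(r @ S2 @ r))
    sq = np.sum((m1 - m2) ** 2) + np.trace(S1 + S2 - 2.0 * cross)
    return np.sqrt(max(sq, 0.0))

def classical_mds(D, n):
    """Classical MDS: rows of the output are n-dimensional embeddings."""
    J = np.eye(len(D)) - 1.0 / len(D)            # centering matrix
    w, V = np.linalg.eigh(-0.5 * J @ D**2 @ J)
    top = np.argsort(w)[::-1][:n]
    return V[:, top] * np.sqrt(np.clip(w[top], 0.0, None))

def wte(tasks, reference, n=2, ridge=1e-6):
    """tasks: list of (X, y), X of shape (N, d); reference: array of shape (N, d + n)."""
    stats, keys = [], []
    for t, (X, y) in enumerate(tasks):           # Gaussian fit of every class
        for c in np.unique(y):
            Xc = X[y == c]
            stats.append((Xc.mean(0), np.cov(Xc.T) + ridge * np.eye(X.shape[1])))
            keys.append((t, int(c)))
    L = len(stats)
    W = np.zeros((L, L))                         # label-to-label distance matrix
    for i in range(L):
        for j in range(i + 1, L):
            W[i, j] = W[j, i] = bw_dist(*stats[i], *stats[j])
    label_vec = dict(zip(keys, classical_mds(W, n)))
    embeddings = []
    for t, (X, y) in enumerate(tasks):
        # Stack each input with its label vector.
        Z = np.hstack([X, np.stack([label_vec[(t, int(c))] for c in y])])
        assert Z.shape == reference.shape        # assumption: equal sizes
        C = ((reference[:, None, :] - Z[None, :, :]) ** 2).sum(-1)
        _, cols = linear_sum_assignment(C)       # optimal matching to the reference
        embeddings.append(((Z[cols] - reference) / np.sqrt(len(reference))).ravel())
    return np.stack(embeddings)

def make_task(rng, offset, N=40):
    """Synthetic two-class task in 2-D (illustrative data only)."""
    y = np.repeat([0, 1], N // 2)
    centers = np.array([[0.0, 0.0], [3.0, 0.0]]) + offset
    return centers[y] + rng.normal(size=(N, 2)), y

rng = np.random.default_rng(0)
tasks = [make_task(rng, 0.0), make_task(rng, 0.5), make_task(rng, 20.0)]
reference = rng.normal(size=(40, 4))
E = wte(tasks, reference, n=2)
d_near = np.linalg.norm(E[0] - E[1])   # tasks 0 and 1 are close by construction
d_far = np.linalg.norm(E[0] - E[2])    # task 2 is shifted far away
```

Each task is embedded with a single OT solve against the reference, after which any pair of tasks is compared with a plain Euclidean distance; in this toy setup the nearby pair ends up closer than the distant pair.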

Experiments

To assess the effectiveness of our WTE framework, we empirically validate the correlation between WTE distance and forward/backward transferability on several datasets. Moreover, we provide both qualitative and quantitative comparisons with OTDD, and show that the WTE distance is well aligned with OTDD while being notably faster to compute. We use the MDS toolkit in scikit-learn and the exact linear programming solver in the Python Optimal Transport (POT) library flamary2021pot for implementing WE. We carry out the distance calculations on CPU and all the model training experiments on a 24GB NVIDIA RTX A5000 GPU.

Figure 4: Pairwise OTDD (left) and WTE distances (middle) on the *NIST task group, and their correlation diagram (right). Notice that OTDD (Eq. 9) is the squared , we report the squared WTE distances and adjust to the same scale according to the cost function. Adjusted WTE distance is strongly correlated with OTDD, with correlation coefficient .

Datasets

We conduct experiments on the following three task groups:

*NIST task group consists of the handwritten digits dataset MNIST LeCun2005TheMD and its extensions EMNIST cohen2017emnist, FashionMNIST xiao2017fashion, KMNIST Clanuwat2018DeepLF along with USPS 291440. We choose the mnist split for the EMNIST dataset, and thus all tasks contain 10 classes of gray-scale images. All datasets have a training set of 60,000 samples and a test set of 10,000 samples, except USPS, which has a total of 9,298 samples. We resize the USPS images to $28 \times 28$ pixels to match the others.

Split-CIFAR100 task group is generated by randomly splitting the CIFAR-100 Krizhevsky2009LearningML dataset with 100 image categories into 10 smaller tasks, each of which is a classification problem with 10 classes. Each class has 600 color images, 500 in the training set and 100 in the test set.

Split-Tiny ImageNet task group follows the same splitting scheme as in Split-CIFAR100. We randomly divide the Tiny ImageNet Le2015TinyIV into 10 disjoint tasks with 20 classes. Each class contains 500 training images, 50 validating images and 50 test images. For better model performance, we first rescale each sample to and then perform a center crop to get pixel images.

Results

To study the transfer behaviors against the WTE distances, we fix a model architecture for each task group. Specifically, we use ResNet18 he2016deep on both *NIST and Split-Tiny ImageNet, and ResNet34 he2016deep on Split-CIFAR100. In the forward transfer setting, for each source-target task pair, we first train the head (i.e., the classifier) of a randomly initialized backbone on the target task, and use the test performance as the baseline. Next, we adapt from a model pre-trained on the source task and finetune the head on the target task. We define the forward transferability of the source-target pair as the performance gain, i.e. error drop when adapting from the source task. To analyse backward transfer, all source tasks are trained jointly during the first phase to avoid task bias, then in the second phase the model learns only the target task and suffers from “forgetting” the previous tasks. We use the catastrophic forgetting, i.e., negative backward transfer as a measure of backward (in)transferability. In implementations of WE, the reference distribution is fixed for each task group, and is randomly generated by upsampling random images at a lower spatial resolution to entail some smooth structure.

Figure 5: Wall-clock computation time comparison on the *NIST (left) and Split-CIFAR100 (right) task groups.

Fig. 3 summarizes the correlation diagrams between our proposed WTE distance and the forward/backward transferability on the aforementioned three task groups. WTE distance is negatively correlated with forward transferability, and positively correlated with catastrophic forgetting. In all scenarios, the correlation is strong and statistically significant, which confirms the efficacy of WTE distance as a measure of task similarities. We also visualize the comparison between WTE distance and OTDD on the *NIST task group in Fig. 4, showing strong correlation between the two distances.

Computation Complexity

As mentioned before, OTDD suffers from a prohibitive computational cost as the number of tasks grows large. The pairwise OTDD calculation for a set of $M$ tasks requires $\mathcal{O}(M^2 N^3 \log N)$ time in the worst case, where $N$ is the largest number of samples among the tasks. WTE distance requires solving only $M$ optimal transport problems, leading to $\mathcal{O}(M N^3 \log N)$ complexity. To better demonstrate the efficiency of WTE distance, we report the wall-clock time comparison on the *NIST and Split-CIFAR100 task groups in Fig. 5.
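The difference in scaling can be illustrated by counting optimal transport problems. The helper below is a hypothetical illustration; it counts only outer OT solves and ignores the label-distance computations and the inexpensive pairwise Euclidean distances between embedded vectors.

```python
def num_ot_problems(num_tasks):
    """Number of OT problems needed to obtain all pairwise task distances."""
    pairwise = num_tasks * (num_tasks - 1) // 2   # one OT problem per distinct task pair
    embedding = num_tasks                         # one OT problem per task against the reference
    return pairwise, embedding

for m in (5, 10, 100):
    print(m, num_ot_problems(m))
# 5 (10, 5)
# 10 (45, 10)
# 100 (4950, 100)
```

With 100 tasks, a pairwise approach solves 4950 OT problems, while the embedding approach solves 100.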

Conclusion

In this paper, we propose Wasserstein task embedding (WTE), a model-agnostic task embedding framework for measuring task (dis)similarities in supervised classification problems. We perform a label embedding through multi-dimensional scaling and leverage the 2-Wasserstein embedding framework to embed tasks into a vector space, in which the Euclidean distance between the embedded points approximates the 2-Wasserstein distance between tasks. We demonstrate that our proposed task embedding distance is correlated with forward and backward transfer on *NIST, Split-CIFAR100 and Split-Tiny ImageNet task groups while being significantly faster than existing methods. In particular, we show statistically significant negative correlation between the WTE distances and the forward transfer, and positive correlation with the catastrophic forgetting (i.e. negative backward transfer). Lastly, we show the alignment of WTE distance with OTDD, but with a significant computational advantage as the number of tasks grows.

References