Fix-A-Step: Effective Semi-supervised Learning from Uncurated Unlabeled Sets

Zhe Huang1, Mary-Joy Sidhom1, Benjamin S. Wessler2, Michael C. Hughes1
Abstract

Semi-supervised learning (SSL) promises gains in accuracy compared to training classifiers on small labeled datasets by also training on many unlabeled images. In realistic applications like medical imaging, unlabeled sets will be collected for expediency and thus uncurated: possibly different from the labeled set in represented classes or class frequencies. Unfortunately, modern deep SSL often makes accuracy worse when given uncurated unlabeled sets. Recent remedies suggest filtering approaches that detect out-of-distribution unlabeled examples and then discard or downweight them. Instead, we view all unlabeled examples as potentially helpful. We introduce a procedure called Fix-A-Step that can improve heldout accuracy of common deep SSL methods despite lack of curation. The key innovations are augmentations of the labeled set inspired by all unlabeled data and a modification of gradient descent updates to prevent following the multi-task SSL loss from hurting labeled-set accuracy. Though our method is simpler than alternatives, we show consistent accuracy gains on CIFAR-10 and CIFAR-100 benchmarks across all tested levels of artificial contamination for the unlabeled sets. We further suggest a real medical benchmark for SSL: recognizing the view type of ultrasound images of the heart. Our method can learn from 353,500 truly uncurated unlabeled images to deliver gains that generalize across hospitals.

1 Dept. of Computer Science, Tufts University, Medford, MA
2 Division of Cardiology, Tufts Medical Center, Boston, MA

1 Introduction

A key roadblock to applying supervised learning to real applications is the need to assemble a large-enough labeled dataset for the intended task. Modern deep learning pipelines are especially data-hungry. In many cases, the acquisition of a large dataset of unlabeled features is rather affordable. However, providing reliable labels for each example is cost-prohibitive, requiring expensive, time-consuming work, often from human experts. This tradeoff is especially apt in our motivating application: classifying medical images. Images are collected in the course of routine care and easily available by querying a hospital’s electronic records. However, labeling these images well often requires clinical staff with years of training to spend minutes per image.

If only a tiny labeled set is available but access to a huge unlabeled set of images is possible, one promising approach is semi-supervised learning (SSL) (Zhu, 2005; van Engelen and Hoos, 2020). Recent years have seen amazing progress on standard benchmarks such as recognizing digits from photos of address numbers on houses (SVHN; Netzer et al. (2011)). With only 100 labeled examples per digit class, a supervised neural net’s error rate is roughly 12%. Using a large unlabeled set, the FixMatch SSL method (Sohn et al., 2020) delivers error below 2.5%, while even more recent work has pushed below 2% (Xu et al., 2021; Han et al., 2020).

Unfortunately, common benchmarks like SVHN may be too optimistic. In real tasks, unlabeled sets will be collected automatically at scale for convenience, and thus may differ from the labeled set in terms of represented classes or class frequencies, among other differences. To maintain the efficiency that motivates SSL, the effort required to curate the unlabeled set – that is, apply labels and assess differences – remains prohibitive. To be effective in practice, SSL must reliably improve accuracy despite uncurated data.

Off-the-shelf SSL using mismatched unlabeled sets often predicts worse than ignoring the unlabeled set altogether (Oliver et al., 2018). Many recent methods do try to be robust to uncurated unlabeled data (see Tab. 1). Broadly, most follow the same intuitive direction: learn to identify examples in the unlabeled set that are out-of-distribution (OOD), and then remove or downweight these. However, we find existing work delivers at best modest gains in accuracy.

Larger, more reliable gains are needed to enable SSL in real uncurated applications. This study makes three contributions toward this goal, backed by reproducible experiments (code: https://github.com/tufts-ml/fix-a-step).

First, we challenge the dominant paradigm that handles uncurated unlabeled sets by filtering out OOD examples. Our experiments suggest that even perfect OOD filtering (which is unrealistic in practice) does not perform well. Instead, we argue for a new paradigm: OOD images from uncurated unlabeled sets are potentially helpful.

Second, following this paradigm we introduce a new training procedure called Fix-A-Step that delivers accuracy gains from uncurated unlabeled sets. When applied to repair several deep SSL methods across a full range of labeled-unlabeled mismatch levels, our Fix-A-Step improves predictions better than alternative methods while being substantially simpler too.

Finally, we offer a new SSL benchmark using real uncurated medical images that can assess cross-hospital generalization. For too long, advances in SSL methods have been evaluated only on the same repurposed generic object recognition datasets. Using three inter-operable open-access datasets – TMED (Huang et al., 2021b), CAMUS (Leclerc et al., 2019), and Unity (Howard et al., 2021) – we pursue a clinically-relevant problem: recognizing the view type of an echocardiogram image of the heart. Future methods for learning from limited data can follow our reproducible protocol. We hope this benchmark enables authentic SSL applications in medicine and ultimately improves the efficiency and quality of care for patients with heart disease.

2 Background and Related Work

We pursue semi-supervised learning (SSL) (Zhu, 2005; van Engelen and Hoos, 2020) for the specific problem of image classification with deep neural networks (Oliver et al., 2018). We can observe images $x$ (represented as $D$-dimensional vectors) as well as corresponding labels $y \in \{1, \ldots, C\}$ for $C$ classes of interest. The goal is to train predictors from both a labeled dataset $\mathcal{D}^L$ of feature-label pairs and an unlabeled dataset $\mathcal{D}^U$ containing feature vectors only.

Training for Deep SSL.

While many SSL paradigms have been tried (Kingma et al., 2014; Kumar, Sattigeri, and Fletcher, 2017; Nalisnick et al., 2019), the dominant approaches for semi-supervised training of deep image classifiers today continue to modify standard objectives for discriminative neural nets by adding a regularization term using unlabeled data (Miyato et al., 2019; Sohn et al., 2020). This approach trains a neural net probabilistic classifier $f_w$ with weights $w$ by solving the optimization problem:

$$\min_{w} \; \sum_{(x, y) \in \mathcal{D}^L} \ell^L(y, f_w(x)) \; + \; \lambda \sum_{x \in \mathcal{D}^U} \ell^U(x, f_w) \tag{1}$$

Here, the first term sums a labeled-set-only cross entropy loss $\ell^L$ over the labeled set $\mathcal{D}^L$, and the second term sums a method-specific unlabeled-set loss $\ell^U$ over the unlabeled set $\mathcal{D}^U$. A key hyperparameter is the unlabeled-loss weight $\lambda > 0$, which balances the two terms. Approaches such as the Pi-model (Laine and Aila, 2017), Pseudo-Label (Lee, 2013), Mean-Teacher (Tarvainen and Valpola, 2017), Virtual Adversarial training (VAT) (Miyato et al., 2019), and FixMatch (Sohn et al., 2020) all fit this objective, with variations in (1) the choice of function for $\ell^U$; (2) how each image $x$ may be altered via standard data augmentation techniques; (3) procedures within the computation of $\ell^U$ that produce a perturbed image $x'$ that should be consistent with $x$; and (4) optimization routines to solve for $w$.
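To make the two-term structure concrete, the following is a minimal sketch in plain Python rather than the authors' implementation. It assumes a Pi-model-style squared-difference consistency term as the unlabeled loss; the function names and the list-of-probabilities inputs are hypothetical, chosen only for illustration.

```python
import math

def cross_entropy(probs, label):
    """Labeled loss for one example: negative log-probability of the true class."""
    return -math.log(probs[label])

def consistency_loss(probs_a, probs_b):
    """One possible unlabeled loss (assumed here, Pi-model style): squared difference
    between predictions for two perturbed versions of the same unlabeled image."""
    return sum((a - b) ** 2 for a, b in zip(probs_a, probs_b))

def ssl_objective(labeled_preds, labels, unlabeled_pred_pairs, unlabeled_weight):
    """Two-term SSL objective: summed labeled loss plus weighted summed unlabeled loss."""
    labeled_term = sum(cross_entropy(p, y) for p, y in zip(labeled_preds, labels))
    unlabeled_term = sum(consistency_loss(a, b) for a, b in unlabeled_pred_pairs)
    return labeled_term + unlabeled_weight * unlabeled_term
```

Setting `unlabeled_weight` to zero recovers the labeled-only baseline, which is why the weight acts as the dial between supervised training and consistency regularization.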

Uncurated SSL. Unlabeled sets collected automatically at scale are by construction uncurated, meaning their contents (features and true labels) are intended to be similar to the target labeled set but not carefully validated. When the unlabeled set contains images from classes other than the classes represented in the labeled set, others call this “open-set” SSL (Yu et al., 2020a). More formally, if we were to apply labels to the unlabeled set $\mathcal{D}^U$, the set of such labels may include an unknown number of classes beyond the known classes in the labeled set, and in the worst case may not even include any examples from some (or all) known classes. Open-set SSL is a special case of uncurated SSL, because a real uncurated dataset may also differ somewhat in feature distribution from the labeled set. Our CIFAR evaluation focuses on differences in class labels; later, in Sec. 5, we use truly uncurated unlabeled sets from medical applications.

Seminal work by Oliver et al. (2018) shows that off-the-shelf SSL performance deteriorates when the unlabeled contents differ in label composition from the labeled set. The left panel of Fig. 1 shows our results for how existing SSL methods (like VAT or Pi-Model) behave when evaluated on CIFAR-10 images where we purposefully build the unlabeled set to be mismatched from the labeled set. Under modest mismatch, many SSL methods perform worse than a labeled-only baseline that ignores the unlabeled set.

Several approaches have been proposed to remedy the issue of effectively learning from abundant unlabeled data in the wild despite its lack of curation. Such methods are often called “safe” SSL, because their goal is to perform no worse than labeled-set-only methods for any unlabeled set. Tab. 1 summarizes several previous works. Broadly, these existing safe SSL approaches do not perform as well as needed (see Fig. 1 far right). Furthermore, to our knowledge these safe methods have not yet been evaluated on authentic uncurated data from potential SSL application domains like medicine.

| Method | Acc. | Paradigm | Applicability | Extra Complexity | Realistic Eval. |
| --- | --- | --- | --- | --- | --- |
| Fix-A-Step (ours) | 80.2 | OOD helpful | any | none | Heart2Heart |
| TOOR (Huang, Yang, and Gong, 2022) | 78.5 | OOD helpful | any | Separate NN for OOD discrimination | none |
| DS3L (Guo et al., 2020) | 74.7 | OOD harmful | any | 3x train time due to bilevel optimization | none |
| MTCF (Yu et al., 2020a) | 77.0 | OOD harmful | any | Extra OOD head, curriculum learning | none |
| UASD (Chen et al., 2020) | 78.2 | OOD harmful | fixed | none | none |
| OpenMatch (Saito, Kim, and Saenko, 2021) | n/a | OOD harmful | FixMatch only | Extra one-vs-all OOD detector / class | none |
Table 1: Comparison of related work on open-set/safe SSL. Acc.: accuracy on the CIFAR-10 6-animal task (defined in Sec. 4) with 400 labeled examples / class and an open-set unlabeled set (50% labeled-unlabeled mismatch). We use a Mean-Teacher base model (Tarvainen and Valpola, 2017) when applicable and numbers come from our implementation except if marked * (copied from cited paper). OpenMatch result unavailable; our task differs from their experiments (our unlabeled set has greater mismatch). Paradigm: how each method treats out-of-distribution (OOD) images in the unlabeled set. Applicability: whether the method might be possible with different unlabeled losses $\ell^U$. Extra Complexity: additional neural networks, layers, or runtime concerns that exceed a standard SSL deep classifier like MixMatch. Realistic Eval.: evaluation beyond creating “artificial” unlabeled sets from datasets like CIFAR, TinyImageNet, ImageNet, or Large-scale Scene Understanding (LSUN).

Related work: Open-set SSL that filters out OOD. Among various previous attempts to handle realistic open-set unlabeled sets, most methods focus on detecting and eliminating OOD samples, under the assumption that OOD samples can only harm the ultimate accuracy of an SSL classifier. Chen et al. (2020)’s UASD ensembles model predictions temporally to produce probability predictions for unlabeled samples, and uses confidence-based thresholding to filter out OOD samples. Yu et al. (2020a) propose a multi-task curriculum learning framework (MTCF) that alternately updates the network parameters and the anomaly scores of unlabeled data to detect OOD samples. Guo et al. (2020)’s Deep Safe Semi-supervised Learning (DS3L) employs meta-learning ideas to automatically down-weight the OOD samples, in order to reduce the harm caused by OOD samples in the unlabeled set. Saito, Kim, and Saenko (2021)’s OpenMatch unifies FixMatch with novelty detection to learn representations of inliers while rejecting outliers. These methods have made significant contributions toward solving the class mismatch problem. However, they all neglect the potential value in OOD samples, assuming OOD samples will harm SSL training. We will show later how OOD samples can be useful in SSL training.

Related work: Open-set SSL beyond filtering. In recent work in parallel to ours, Huang, Yang, and Gong (2022) suggest that OOD unlabeled data may not be “completely useless.” They develop an approach called TOOR that first trains a model to recognize in- versus out-of-distribution examples, and then, viewing the OOD examples as from a related domain, pursue adversarial domain adaptation to try to “recycle” OOD examples. Similarly, Luo et al. (2021) try to reduce the distribution gap between ID and OOD samples by using a style transfer model. The transformed OOD samples are then used as if they were ID samples to enforce consistency against their own augmented versions. Finally, Huang et al. (2021a) utilize the OOD unlabeled samples in the first stage where they use all the ID and OOD samples in an auxiliary pre-text task to pretrain the network for better initial feature representation learning. However, they still filter out OOD samples in the second stage as they assume OOD samples would harm classification accuracy of ID samples.

Some distantly related work explores different problem settings. Huang et al. (2021c) consider SSL for the setting where both class and feature distribution are mismatched. Cao, Brbic, and Leskovec (2022) focus on transductive learning for SSL in “open worlds” where novel classes appear in the unlabeled test set.

Background on SSL benchmarks. SSL methods continue to focus on a few datasets intended for fully-supervised image classification, such as SVHN, CIFAR-10, CIFAR-100, and ImageNet. This is a problem because these data are repurposed post hoc for SSL, dropping known labels to create unlabeled sets in artificial fashion. The resulting unlabeled sets are too curated: images usually come from the same classes as the labeled set with similar frequencies. As argued above, the real applications that motivate SSL require an easy-to-acquire unlabeled set that is uncurated.

Recent research has further identified problems with CIFAR and ImageNet. First, 3% of CIFAR-10 and 10% of CIFAR-100 test images have perceptually-indistinguishable duplicates in the train set (Barz and Denzler, 2020). This raises the question of whether high-scoring methods are memorizing rather than truly generalizing. Second, a notable fraction (5%) of the labels in the test sets of CIFAR-100 and ImageNet data are incorrect (Northcutt, Athalye, and Mueller, 2021). More generally, overuse of the same benchmarks over decades may lead to over-optimistic assessments of heldout error rates (Yadav and Bottou, 2019) and may privilege methods that exploit shortcuts or biases in the available data that hurt true generalization (Tsipras et al., 2020; Geirhos et al., 2020). Given this background, we argue that new SSL benchmarks motivated by intended applications are sorely needed to help ensure the next generation of SSL methods delivers on its promise of generalization.

3 Methods

We have designed a training procedure for deep SSL classifiers that we call Fix-A-Step, which is short for Fix via Augmentation and Step direction modification. Fix-A-Step can be applied to any SSL method that minimizes an SSL objective matching Eq. (1) via gradient descent. The primary goal of Fix-A-Step is to make any SSL classifier far more robust to uncurated unlabeled data.

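Input: Labeled set $\mathcal{D}^L$, Unlabeled set $\mathcal{D}^U$ (uncurated)
Output: Trained weights $w$
Hyperparameters ($\star$: unique to Fix-A-Step)

  • $\star$ Sharpening temperature $\tau > 0$ for SoftPseudoLabel

  • $\star$ Shape $\alpha > 0$ of Beta dist. for MixMatchAug

  • Max. iterations $I$, Step size $\epsilon > 0$, Initial weights $w_0$

  • Unlabeled-loss weight $\lambda_i \geq 0$ per iter

1:  for iter $i \in \{1, 2, \ldots, I\}$ until converged do
2:     draw minibatches $(x^L, y^L)$ from $\mathcal{D}^L$ and $x^U$ from $\mathcal{D}^U$
3:     $\hat{y}^U \leftarrow$ SoftPseudoLabel($x^U$; $w$, $\tau$)
4:     $(\tilde{x}^L, \tilde{y}^L) \leftarrow$ MixMatchAug($x^L$, $y^L$, $x^U$, $\hat{y}^U$; $\alpha$)
5:     $g^L \leftarrow \nabla_w \ell^L(\tilde{y}^L, f_w(\tilde{x}^L))$
6:     $g^U \leftarrow \nabla_w \ell^U(x^U, f_w)$
7:     if $\langle g^L, g^U \rangle > 0$ then
8:        $w \leftarrow w - \epsilon \, (g^L + \lambda_i g^U)$
9:     else
10:        $w \leftarrow w - \epsilon \, g^L$
11:     end if
12:  end for
13:  return $w$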
Algorithm 1 Fix-A-Step Training

The overall Fix-A-Step training algorithm is provided in Alg. 1. Two phases of Fix-A-Step happen when visiting each minibatch during gradient descent. First, in the augmentation phase (lines 3-4), our insight is that the unlabeled set might be helpful for creating useful augmentations, even when uncurated, by injecting realistic diversity (for a motivating experiment, see App. A.2). Inspired by MixMatch (Berthelot et al., 2019), we transform each labeled pair $(x, y)$ using another pair $(x', y')$ drawn either from the labeled set or the unlabeled set (where only $x'$ is known, so we use soft pseudo-label predictions for $y'$; see Alg. C.1). Given $(x, y)$ and $(x', y')$, we build a new labeled pair $(\tilde{x}, \tilde{y})$ via MixUp (Zhang et al., 2017) interpolation (see Alg. C.2), then use that pair to compute the labeled loss.
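The augmentation phase can be sketched as follows. Algs. C.1 and C.2 are not reproduced in this section, so this sketch assumes the MixMatch conventions: temperature sharpening of predicted probabilities, and a Beta-distributed mixing coefficient clipped to at least 0.5 so the result stays closer to the original labeled pair. The function names `sharpen` and `mixup` are hypothetical, and inputs are plain Python lists instead of image tensors.

```python
import random

def sharpen(probs, temperature):
    """Temperature sharpening: raise each probability to 1/temperature, renormalize.
    Temperatures below 1 push the soft pseudo-label toward its most likely class."""
    powered = [p ** (1.0 / temperature) for p in probs]
    total = sum(powered)
    return [p / total for p in powered]

def mixup(x, y, x_other, y_other, alpha, rng=random):
    """Interpolate (x, y) with (x_other, y_other) using a Beta(alpha, alpha) coefficient."""
    coef = rng.betavariate(alpha, alpha)
    coef = max(coef, 1.0 - coef)  # assumed MixMatch convention: stay closer to (x, y)
    x_new = [coef * a + (1.0 - coef) * b for a, b in zip(x, x_other)]
    y_new = [coef * a + (1.0 - coef) * b for a, b in zip(y, y_other)]
    return x_new, y_new
```

When the second pair comes from the unlabeled set, `y_other` would be the sharpened model prediction for `x_other`, so the mixed label remains a valid probability vector.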

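Second, in the step direction phase (lines 5-10), we prioritize the labeled loss in parameter updates, only using the unlabeled loss if it improves labeled-set performance. At each batch, we compute two gradient vectors, one for each term in the loss: let $g^L = \nabla_w \ell^L$ be the gradient of the labeled loss and let $g^U = \nabla_w \ell^U$ be the gradient of the unlabeled loss. With unlabeled-loss weight $\lambda$ as in Eq. (1), the update for weights $w$ using step size $\epsilon > 0$ is then

$$w \leftarrow \begin{cases} w - \epsilon \, (g^L + \lambda g^U) & \text{if } \langle g^L, g^U \rangle > 0, \\ w - \epsilon \, g^L & \text{otherwise.} \end{cases} \tag{2}$$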

In the top case, we do the standard steepest descent update that minimizes the two-term SSL objective in Eq. (1). In the bottom case, we perform an alternative update, using only the labeled-term gradient. This two-case construction tries to ensure that SSL learning does not harm labeled set performance. Formally, we can show that each case of the gradient updates in Eq. (2) adjusts weights in a descent direction for the labeled set loss at the current minibatch.
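The two-case update can be sketched in a few lines of plain Python. This is an illustrative version that treats weights and gradients as flat lists of floats; the function and argument names are hypothetical, and a real implementation would operate on the flattened gradient tensors of the network.

```python
def inner(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def fix_a_step_update(weights, grad_labeled, grad_unlabeled, unlabeled_weight, step_size):
    """One parameter update following the two-case rule described above."""
    if inner(grad_labeled, grad_unlabeled) > 0:
        # Gradients agree: take the standard step on the full two-term objective.
        direction = [gl + unlabeled_weight * gu
                     for gl, gu in zip(grad_labeled, grad_unlabeled)]
    else:
        # Gradients conflict (or are orthogonal): follow the labeled gradient only.
        direction = list(grad_labeled)
    return [w - step_size * d for w, d in zip(weights, direction)]
```

The only extra work beyond standard SSL training is one inner product per minibatch, since both gradients are already needed for the combined step.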

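Definition 1: Descent direction of loss $\ell$. For any loss function $\ell(w)$ parameterized by weight vector $w$, a vector $v$ is a descent direction of $\ell$ at $w$ if it satisfies $\langle v, \nabla_w \ell(w) \rangle < 0$ (Boyd and Vandenberghe, 2004).

Lemma 1: The update in Eq. (2) steps in a descent direction of the labeled loss $\ell^L$ at the current minibatch. Proof: We prove the claim for each of the two cases in Eq. (2), writing $g^L = \nabla_w \ell^L$ and $g^U = \nabla_w \ell^U$. Top case: Here by assumption the inner product $\langle g^L, g^U \rangle$ is positive. This implies that the step direction $-(g^L + \lambda g^U)$ is a descent direction, because $\lambda > 0$ and thus

$$\langle -(g^L + \lambda g^U), \, g^L \rangle = -\|g^L\|^2 - \lambda \langle g^U, g^L \rangle < 0. \tag{3}$$

Bottom case: $-g^L$ is a descent direction for $\ell^L$ by definition, since $\langle -g^L, g^L \rangle = -\|g^L\|^2 < 0$ whenever $g^L \neq 0$.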

While Lemma 1 provides a justification for our approach, we cannot formally guarantee the labeled-set loss will not increase after each step, for the same reasons that stochastic gradient descent (SGD) does not always decrease its loss after each minibatch update. First, a descent direction of a small minibatch may not be a descent direction of the entire dataset. Second, even though the direction of the step points locally downhill, the length of the step matters; if the step size is too large, the loss may increase. Nevertheless, with proper step size tuning SGD has been wildly successful despite following minibatch-specific descent directions without formal guarantees of non-increasing loss. Thus far, we find our approach also successful in practice.
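The step-length caveat can be seen on a toy one-dimensional loss. The sketch below uses a hypothetical quadratic loss chosen purely for illustration (it is not from the paper): the negative gradient is always a descent direction, yet an overly large step overshoots the minimum and raises the loss.

```python
def loss(w):
    """Toy quadratic loss with its minimum at w = 0."""
    return w * w

def grad(w):
    """Gradient of the toy loss."""
    return 2.0 * w

def gradient_step(w, step_size):
    """Move along the negative gradient, which is a descent direction."""
    return w - step_size * grad(w)

start = 1.0
small = gradient_step(start, 0.1)  # lands at w = 0.8, closer to the minimum
large = gradient_step(start, 1.5)  # lands at w = -2.0, overshooting the minimum
print(loss(start), loss(small), loss(large))
```

Both steps follow the same descent direction; only the step size decides whether the loss goes down.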

Geometric Intuition. Recall that two vectors $a$ and $b$ have positive inner product if and only if the angle between the vectors is below 90 degrees (less than perpendicular). Thus, we can motivate the alternative bottom case in Eq. (2) geometrically. We usually care most about classifier accuracy, so when the two gradients point in opposing directions (bottom case), we might be better off ignoring the unlabeled loss and updating parameters using only the labeled loss.

Inspiration from multi-task learning. Our design of the step direction modification in Eq. (2) is motivated by previously suggested step direction modifications for multi-task learning in the presence of a “main” task and an “auxiliary” task (Du et al., 2020). This is part of a broader thread of work on step modification for multi-task (Yu et al., 2020b) and continual learning (Lopez-Paz and Ranzato, 2017). To our knowledge, such ideas have not yet been suggested or validated for SSL, or been combined with augmentation.

Comparison to OOD filtering. Using the Fix-A-Step update, we know that each (stochastic) gradient step could beneficially reduce the labeled set loss, given a small-enough step size $\epsilon$. We argue our step direction modification is safer than simply learning to downweight individual terms in the unlabeled loss. Without care, the latter could still take steps in a problematic non-descent direction of the labeled loss.

Simplicity compared to related work. We emphasize that a key overall advantage of Fix-A-Step is extreme simplicity. Beyond the modest cost of MixMatch-like augmentation, we compute exactly the same losses and gradients as any standard deep SSL method solving Eq. (1) would. Each possible weight parameter update is also simple and straightforward. Determining which update to perform depends purely on computing the inner product between the two gradient vectors $g^L$ and $g^U$, which is efficient on modern hardware and adds negligible runtime cost. Table 1 suggests our approach compares favorably to other open-set or safe SSL approaches in this regard. There is no added complexity from extra backward passes through the classifier, no extra neural networks that must be trained for OOD discrimination, no need for curriculum learning, and no expensive bi-level optimization problem to solve.

4 Experiments on CIFAR

Figure 1: Accuracy on CIFAR-10 6 animal task. Accuracy on test images of animals (y-axis) as unlabeled set mismatch (percentage of non-animal classes represented, x-axis) increases. Column 1 (from left): Previous SSL methods trained in standard fashion. Col. 2: SSL methods trained with our Fix-A-Step. Col. 3: SSL methods with perfect OOD filtering of the unlabeled set (removing all non-animal images before training). Col. 4: Previous methods designed for open-set or safe SSL. UASD (marked *) taken from its publication, others from our experiments. Top row: 400 examples/class; Bottom: 50 examples/class.

Our open source code uses PyTorch (Paszke et al., 2019) and allows reproducing each experiment (see App. E for details). Following Oliver et al. (2018), for all methods we use the same Wide ResNet-28-2 architecture (Zagoruyko and Komodakis, 2016), apply standard augmentation (random crops, flips) on the labeled set, and regularize via weight decay.

SSL Baselines. We compare to 6 SSL methods (Pi-Model, Mean-Teacher, Pseudo-label, VAT, MixMatch, and FixMatch) as well as the baseline that minimizes labeled loss on the labeled set (“labeled only”). We also compare to 3 methods intended for open-set/safe SSL: UASD, DS3L, and MTCF. When possible, we use our own adapted implementations of baselines, to ensure that architectures, training, and hyperparameters are comparable and reproducible. Throughout, if a result is copied from another publication, we mark that via an asterisk (*).

Training. Following choices in original implementations, each method is trained using either Adam with fixed learning rate or SGD with a cosine-annealing schedule for learning rate (Sohn et al., 2020). We found cosine-annealing and a slow linear ramp-up schedule for the unlabeled-loss-weight particularly helpful for several baselines (see App. E). Each training run used one NVIDIA A100 GPU.
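For readers unfamiliar with these schedules, the following is a generic sketch. The exact forms and constants used in the experiments are specified in App. E and the original implementations, not here; the half-cosine decay to zero and the function names below are assumptions for illustration only.

```python
import math

def linear_rampup_weight(iteration, rampup_iters, max_weight):
    """Unlabeled-loss weight that grows linearly from 0 to max_weight, then stays flat."""
    if rampup_iters <= 0:
        return max_weight
    return max_weight * min(1.0, iteration / rampup_iters)

def cosine_annealed_lr(iteration, total_iters, base_lr):
    """Learning rate that decays from base_lr toward 0 along a half cosine (assumed form)."""
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * iteration / total_iters))
```

A slow ramp-up keeps the unlabeled term small early in training, when the classifier's predictions on unlabeled images are least reliable.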

4.1 CIFAR-10 Protocol and Results

6-animal task for CIFAR-10.

We pursue the “6-animal” task designed by Oliver et al. (2018) to artificially create unlabeled sets at different levels of mismatch with the labeled set. We build a labeled set of the 6 animal classes (dog, cat, horse, frog, deer, bird) in CIFAR-10, across two training set sizes: 50 labeled images per class and 400 per class. We form an unlabeled set from 4 selected classes, some animal and some non-animal (car, truck, ship, airplane), with an equal number of images per class. The percentage of non-animal classes is denoted by $\zeta$. If $\zeta = 0\%$, we recover the standard “closed-set” SSL setting. At $\zeta = 100\%$, the unlabeled set has no classes in common with the labeled set, and the OOD-filtering paradigm suggests that we should ignore the unlabeled set entirely. For details on the unlabeled set construction as $\zeta$ varies, see App. A.1.
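The relationship between the mismatch percentage and the unlabeled class list can be sketched as below. Which specific animal and non-animal classes are selected at each level is defined in App. A.1; taking the first entries of each list is a hypothetical choice made only for illustration, as is the function name.

```python
ANIMALS = ["bird", "cat", "deer", "dog", "frog", "horse"]
NON_ANIMALS = ["airplane", "car", "ship", "truck"]

def unlabeled_classes(mismatch_percent, n_unlabeled_classes=4):
    """Return the classes in the unlabeled set for a given mismatch percentage."""
    n_non_animal = n_unlabeled_classes * mismatch_percent / 100
    if n_non_animal != int(n_non_animal) or not 0 <= n_non_animal <= n_unlabeled_classes:
        raise ValueError("mismatch level must map to a whole number of classes")
    n_non_animal = int(n_non_animal)
    n_animal = n_unlabeled_classes - n_non_animal
    # Hypothetical selection rule: take the first classes from each list.
    return ANIMALS[:n_animal] + NON_ANIMALS[:n_non_animal]

for percent in (0, 25, 50, 75, 100):
    print(percent, unlabeled_classes(percent))
```

With 4 unlabeled classes, the reachable mismatch levels are multiples of 25%, from the closed-set case (all animal) to no overlap with the labeled classes (all non-animal).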

Hyperparameters. All baselines use hyperparameters suggested by previous work for the CIFAR-10 6 animal task (see App. E). In rare cases, if a baseline underperformed we retuned values to maximize validation set accuracy. We did no hyperparameter tuning for Fix-A-Step, fixing its own hyperparameters to a single setting throughout and inheriting other hyperparameters from the base SSL method.

Results on 6-animal. In Fig. 1, we compare the accuracy of different methods at recognizing the 6 animal classes in the test set, as the mismatch percentage increases. Across two different training set sizes (rows), we compare 4 different training scenarios (columns, best read left to right): methods trained in the standard way (“off-the-shelf”), methods trained using Fix-A-Step, methods trained in the standard fashion but with perfect OOD filtering applied to the unlabeled set before training so that only known-class samples remain, and methods intended for safe SSL. The perfect OOD filtering column essentially shows the best-possible case for methods under the OOD-is-harmful paradigm.

We highlight several findings from Fig. 1:

1. Existing “safe” SSL methods perform little better than labeled-set-only. In the right-most column, “safe” SSL methods (UASD, DS3L, MTCF) roughly match or fall below the dashed line at higher mismatch levels. FixMatch, which was not intended to be “safe”, dominates all 3 methods.

2. Fix-A-Step improves all SSL methods in almost all settings. Despite its relative simplicity, Fix-A-Step is quite effective, as seen in the raised accuracies from the first to the second column across almost all methods and mismatch levels. In Fig. A.2, we further demonstrate that Fix-A-Step’s gains are robust across multiple random train/test splits.

3. Perfect OOD filtering is not enough. The third column shows that perfect OOD filtering still leads to underwhelming accuracy across mismatch levels. The gains of our Fix-A-Step over this best-case filtering suggest that our OOD-is-helpful paradigm should be prioritized over filtering.

4.2 CIFAR-100 Protocol and Results

50-class task for CIFAR-100. Using the larger CIFAR-100 dataset, we follow the experimental design of Chen et al. (2020) to create a class distribution mismatch scenario by using classes 1-50 as labeled classes, and classes 25-75 as unlabeled classes. To assess a more extreme level of unlabeled set “contamination”, we further create a 100% class distribution mismatch scenario: classes 1-50 are labeled classes; classes 51-100 unlabeled.

Hyperparameters. All methods use the same hyperparameters as in CIFAR-10 experiments, without any retuning.

Figure 2: Accuracy on CIFAR-100 50-class task. Each bar represents the accuracy of a method with either off-the-shelf training (blue) or Fix-A-Step (orange). We try 2 scenarios: 50% labeled/unlabeled class mismatch (left panel) and 100% class mismatch (right). All numbers were produced by our implementation except those marked * (UASD).

Results on CIFAR-100 50-class. Fig. 2 compares “off-the-shelf” SSL methods using standard training (blue bars) and Fix-A-Step (orange). We see consistent gains in both 50% and 100% mismatch settings, even without tuning hyper-parameters. Other methods (UASD, DS3L, MTCF) underperform, perhaps due to untuned hyperparameters.

4.3 Ablations and Sensitivity Analysis

Ablations. Here we quantify how each of Fix-A-Step’s two key components (Augmentation and Gradient step modification) performs in isolation. Tab. 2 compares accuracy on the 6 animal task at 400 examples/class and a single fixed mismatch level. Gradient step modification alone increases accuracy by around 0.5 to 1.5 percentage points across five base SSL methods. Augmentation alone is more effective, increasing accuracy by roughly 2.5 to 4.5 percentage points. When combined, we consistently see the largest gains. For further results at other mismatch levels, see App. A.

Sensitivity analysis. There are two hyperparameters unique to Fix-A-Step: the sharpening temperature and the Beta shape. For simplicity, we hold both at a single fixed setting throughout. Since deep SSL is often sensitive to hyperparameters (Oliver et al., 2018; Sohn et al., 2020), we further analyse other possible choices for each. Fig. A.3 shows that Fix-A-Step delivers consistent and similar accuracy gains across all tested settings, and thus does not appear overly sensitive.

| | Pi-Model | MT | VAT | Pseudo | FixMatch |
| --- | --- | --- | --- | --- | --- |
| off-the-shelf | 73.90 | 73.75 | 73.87 | 75.45 | 78.45 |
| +G only | 74.50 | 74.33 | 75.35 | 75.92 | 79.73 |
| +A only | 77.25 | 78.38 | 77.87 | 77.88 | 81.53 |
| +A&G (ours) | 79.18 | 79.23 | 78.52 | 78.72 | 82.73 |

Table 2: Ablations for CIFAR-10 6 animal task, reporting accuracy for each SSL method (columns) if we only use our augmentation (+A), only use our gradient step modification (+G), or use the combination (+A&G) that defines Fix-A-Step. We bold the best result and all others within 1 percentage point. Setting: 400 examples/class, single fixed mismatch level.
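The gains quoted in the ablation discussion can be recomputed directly from the Table 2 entries. The short sketch below only subtracts the off-the-shelf row from each variant; the dictionary layout and function name are hypothetical conveniences.

```python
# Accuracies copied from Table 2; columns are Pi-Model, MT, VAT, Pseudo, FixMatch.
TABLE2 = {
    "off-the-shelf": [73.90, 73.75, 73.87, 75.45, 78.45],
    "+G only": [74.50, 74.33, 75.35, 75.92, 79.73],
    "+A only": [77.25, 78.38, 77.87, 77.88, 81.53],
    "+A&G": [79.18, 79.23, 78.52, 78.72, 82.73],
}

def gains(row):
    """Per-method accuracy gain (percentage points) of a variant over off-the-shelf."""
    return [round(v - base, 2) for v, base in zip(TABLE2[row], TABLE2["off-the-shelf"])]

for row in ("+G only", "+A only", "+A&G"):
    g = gains(row)
    print(row, "min gain:", min(g), "max gain:", max(g))
```

For every base method, the combined variant's gain exceeds the gain from either component alone, which is the pattern the ablation text describes.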

5 Experiments on Heart2Heart Benchmark

[Figure 3 panel titles, left to right: TMED-2 (Boston, USA; 4 views); Transfer to Unity (UK; 3 views); Transfer to CAMUS (France; 2 views)]
Figure 3: Balanced accuracy for echocardiogram view classification (Heart2Heart benchmark). Methods are trained on TMED-2 images to distinguish 4 view types: PLAX, PSAX, A2C, and A4C. TMED-2’s 353,500 image unlabeled set is uncurated, representing a superset of possible view types including the 4 known classes. Bar height gives mean balanced accuracy across 3 models trained on different splits of TMED-2 (error bars indicate min/max). Left: Evaluation on heldout TMED-2 images. Center: Evaluation of TMED-2-trained classifiers on PLAX, A2C, and A4C images from Unity dataset (17 sites in the UK). Right: Evaluation of TMED-2-trained classifiers on A2C and A4C views from CAMUS dataset (1 site in France).

In pursuit of realistic evaluation, we consider a reproducible, clinically-relevant SSL task that we call Heart2Heart. The key question is this: can we generalize classifiers of ultrasound images of the heart from one hospital system to heart images from different hospitals in other countries with no overlap in staff or ownership? For training, we use the Tufts Medical Echocardiogram Dataset 2 (TMED-2) (Huang et al., 2022, 2021b). TMED-2 contains a small labeled set of echocardiogram studies and a larger uncurated unlabeled set. Thanks to common device standards, these images are interoperable with two other open-access datasets of “echo” images: Unity from 17 hospitals in the UK (Howard et al., 2021) and CAMUS from a hospital in France (Leclerc et al., 2019). We emphasize that all datasets are deidentified and open to any researcher for non-commercial purposes. We will release code so this Heart2Heart evaluation of SSL can be readily reproduced for future SSL methods.

Classification task: View type of 2D TTE image. Trans-thoracic echocardiography (TTE) is a gold-standard way to non-invasively capture the heart’s 3-dimensional anatomy for measurement and diagnosis. A human sonographer wields a handheld transducer over the patient’s chest at different angles in order to provide clear views of each facet of the heart. A routine TTE scan of a patient, called a study, produces many images (median=68, 10-90th percentile range=27-97 in TMED-2), each showing a canonical 2D view of the heart. No view type annotation is recorded with any image. In later analysis, clinicians must manually search over all recorded images to find a desired view type. Automated interpretation of echocardiograms must also be able to pick out specific view types before any useful measurements or diagnosis can be made, making view classification a prediction task with immediate potential clinical impact (Madani et al., 2018a; Huang et al., 2021b).

TMED-2 provides a set of labeled images of four specific view types, known as PLAX, PSAX, A2C, and A4C, gathered from certified annotators. Reliably identifying these views would be particularly useful for key valve disease diagnostic tasks (Huang et al., 2021b). TMED-2 also contains an uncurated, truly unlabeled set of 353,500 images from routine TTEs from 5486 patient-studies. This is an authentic open-set SSL task, because at least 9 canonical view types frequently appear in routine TTEs (Mitchell et al., 2019).

Protocol. We train SSL methods on images from 56 labeled studies in TMED-2’s training set as well as all unlabeled studies (353,500 images). We report balanced accuracy on each split’s test set of 120 studies (2104 view-labeled images per split). We then assess generalization of these U.S.-hospital-trained classifiers to images from different European hospitals. We report balanced accuracy on 7231 available PLAX, A2C, and A4C images from the Unity dataset, as well as all 2000 images (A2C and A4C views only) in the CAMUS dataset.
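The balanced accuracy metric used in this protocol can be illustrated with a minimal sketch. We assume the standard definition (the unweighted mean of per-class recall); the function name and the toy labels below are hypothetical and are not drawn from TMED-2, Unity, or CAMUS.

```python
from collections import defaultdict


def balanced_accuracy(y_true, y_pred):
    """Unweighted mean of per-class recall (assumed definition)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    # Every class contributes equally, regardless of how many images it has.
    recalls = [correct[c] / total[c] for c in total]
    return sum(recalls) / len(recalls)


# Hypothetical toy example: 4 PLAX images (all correct), 2 A4C images (1 correct).
y_true = ["PLAX", "PLAX", "PLAX", "PLAX", "A4C", "A4C"]
y_pred = ["PLAX", "PLAX", "PLAX", "PLAX", "A4C", "A2C"]
print(balanced_accuracy(y_true, y_pred))  # per-class recalls are 1.0 and 0.5
```

Unlike plain accuracy, this metric does not reward a classifier for doing well only on the most frequent view type.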

Results on Heart2Heart. Fig. 3 shows classifier performance on held-out data from all 3 datasets. TMED-2 evaluations (first panel) show that our Fix-A-Step procedure yields gains across all tested SSL methods (Pi-Model, VAT, FixMatch). Only with Fix-A-Step do the methods convincingly outperform the labeled-only baseline.

External evaluations on Unity and CAMUS (Fig. 3 panels 2-3) show that these gains transfer to new hospitals. Each tested SSL method performs better with Fix-A-Step than with standard training. Across splits we see larger performance variation on Unity and CAMUS than on TMED-2, which highlights the difficulty of generalizing across hospitals as well as the importance of external validation. All methods perform worse on CAMUS than on other datasets; see App. B for further investigations. Overall, this Heart2Heart benchmark task shows the promise of Fix-A-Step to deliver gains from unlabeled data that generalize better than alternatives.

6 Discussion

In summary, this paper makes three contributions to deep SSL image classification. First, we argue that uncurated or OOD data in the unlabeled set can be quite helpful, and should not merely be filtered out; experiments in Fig. 1 show that even with perfect OOD filtering most SSL methods deliver underwhelming accuracy gains. Additional evidence for the helpfulness of OOD unlabeled data is provided in App. A.2. Second, building on insights from multi-task learning (Du et al., 2020), our Fix-A-Step SSL training procedure uses gradient step modifications that prioritize the labeled-set loss, leading to effective SSL that is substantially simpler than alternatives (no new loss terms or extra neural networks). Finally, we hope our new Heart2Heart benchmark for SSL evaluation inspires robust studies of clinical model transportability across global populations.

Limitations. Our experiments on artificial unlabeled sets focused exclusively on mismatch in the labels. We did not explore shifts in the features, though some modest shifts may be present in TMED-2 data due to its uncurated nature. Our work’s exclusive focus is image classification. Fix-A-Step could apply to other types of data whose feature vectors have fixed dimension (required by MixUp). More work is needed on alternative data types or multi-modal SSL.

Our clinical applications have shown only a proof-of-concept for generalization of view classifiers across hospitals. Far more work is needed to rigorously assess generalizability and to translate these methods into improved clinical workflow and care.

Outlook. Fix-A-Step is a promising new first-line approach to SSL that can unlock the promise of uncurated unlabeled sets. We hope future work explores both augmentation and step direction modification ideas further, while maintaining our focus on simplicity and reproducibility.

Ethics statement. Semi-supervised image classification has many positive applications. Indeed, work on SSL is often specifically motivated by the promise of improved efficiency in environments where labels are expensive and time-consuming as is the case in medical imaging (Huang et al., 2021b; Madani et al., 2018b). However, care must be taken to ensure that automated methods actually benefit patients and do not widen current disparities (Celi et al., 2022). Our present Heart2Heart evaluations are an important step beyond single-center evaluations though do not reflect the true geographic and racial diversity of many patient populations. While all images used in our Heart2Heart task are completely de-identified and come from public open-access datasets (and thus did not require ethics review), we stress the responsibility we carry as researchers to protect the best interests of the individuals who contributed data.

References

  • Barz and Denzler (2020) Barz, B.; and Denzler, J. 2020. Do We Train on Test Data? Purging CIFAR of Near-Duplicates. Journal of Imaging, 6(6): 41.
  • Berthelot et al. (2019) Berthelot, D.; Carlini, N.; Goodfellow, I.; Papernot, N.; Oliver, A.; and Raffel, C. 2019. MixMatch: A Holistic Approach to Semi-Supervised Learning. In Advances in Neural Information Processing Systems.
  • Boyd and Vandenberghe (2004) Boyd, S. P.; and Vandenberghe, L. 2004. Sec. 9.2: Descent Methods. In Convex Optimization. Cambridge, UK ; New York: Cambridge University Press.
  • Cao, Brbic, and Leskovec (2022) Cao, K.; Brbic, M.; and Leskovec, J. 2022. Open-World Semi-Supervised Learning. In International Conference on Learning Representations (ICLR). arXiv.
  • Celi et al. (2022) Celi, L. A.; Cellini, J.; Charpignon, M.-L.; Dee, E. C.; Dernoncourt, F.; Eber, R.; Mitchell, W. G.; Moukheiber, L.; Schirmer, J.; et al. 2022. Sources of Bias in Artificial Intelligence That Perpetuate Healthcare Disparities—A Global Review. PLOS Digital Health, 1(3): e0000022.
  • Chaudhry et al. (2018) Chaudhry, A.; Ranzato, M.; Rohrbach, M.; and Elhoseiny, M. 2018. Efficient lifelong learning with a-gem. arXiv preprint arXiv:1812.00420.
  • Chen et al. (2020) Chen, Y.; Zhu, X.; Li, W.; and Gong, S. 2020. Semi-Supervised Learning under Class Distribution Mismatch. Proceedings of the AAAI Conference on Artificial Intelligence, 34(04): 3569–3576.
  • Du et al. (2020) Du, Y.; Czarnecki, W. M.; Jayakumar, S. M.; Farajtabar, M.; Pascanu, R.; and Lakshminarayanan, B. 2020. Adapting Auxiliary Losses Using Gradient Similarity. arXiv:1812.02224.
  • Farajtabar et al. (2020) Farajtabar, M.; Azizan, N.; Mott, A.; and Li, A. 2020. Orthogonal gradient descent for continual learning. In International Conference on Artificial Intelligence and Statistics, 3762–3773. PMLR.
  • Geirhos et al. (2020) Geirhos, R.; Jacobsen, J.-H.; Michaelis, C.; Zemel, R.; Brendel, W.; Bethge, M.; and Wichmann, F. A. 2020. Shortcut Learning in Deep Neural Networks. Nature Machine Intelligence, 2(11): 665–673.
  • Gong et al. (2021) Gong, C.; Wang, D.; Li, M.; Chen, X.; Yan, Z.; Tian, Y.; Chandra, V.; et al. 2021. NASViT: Neural Architecture Search for Efficient Vision Transformers with Gradient Conflict aware Supernet Training. In International Conference on Learning Representations.
  • Guo et al. (2020) Guo, L.-Z.; Zhang, Z.-Y.; Jiang, Y.; Li, Y.-F.; and Zhou, Z.-H. 2020. Safe Deep Semi-Supervised Learning for Unseen-Class Unlabeled Data. In International Conference on Machine Learning, 10.
  • Han et al. (2020) Han, T.; Gao, J.; Yuan, Y.; and Wang, Q. 2020. Unsupervised Semantic Aggregation and Deformable Template Matching for Semi-Supervised Learning. In Advances in Neural Information Processing Systems (NeurIPS), 11.
  • Howard et al. (2021) Howard, J. P.; Stowell, C. C.; Cole, G. D.; Ananthan, K.; Demetrescu, C. D.; Pearce, K.; Rajani, R.; Sehmi, J.; Vimalesvaran, K.; et al. 2021. Automated Left Ventricular Dimension Assessment Using Artificial Intelligence Developed and Validated by a UK-Wide Collaborative. Circulation: Cardiovascular Imaging, 14(5): e011951.
  • Huang et al. (2021a) Huang, J.; Fang, C.; Chen, W.; Chai, Z.; Wei, X.; Wei, P.; Lin, L.; and Li, G. 2021a. Trash to Treasure: Harvesting OOD Data with Cross-Modal Matching for Open-Set Semi-Supervised Learning. In 2021 IEEE/CVF International Conference on Computer Vision (ICCV), 8290–8299. Montreal, QC, Canada: IEEE.
  • Huang et al. (2021b) Huang, Z.; Long, G.; Wessler, B.; and Hughes, M. C. 2021b. A New Semi-supervised Learning Benchmark for Classifying View and Diagnosing Aortic Stenosis from Echocardiograms. In Proceedings of the 6th Machine Learning for Healthcare Conference. PMLR.
  • Huang et al. (2022) Huang, Z.; Long, G.; Wessler, B. S.; and Hughes, M. C. 2022. TMED 2: A Dataset for Semi-Supervised Classification of Echocardiograms. In DataPerf: Benchmarking Data for Data-Centric AI Workshop.
  • Huang et al. (2021c) Huang, Z.; Xue, C.; Han, B.; Yang, J.; and Gong, C. 2021c. Universal Semi-Supervised Learning. In Advances in Neural Information Processing Systems.
  • Huang, Yang, and Gong (2022) Huang, Z.; Yang, J.; and Gong, C. 2022. They Are Not Completely Useless: Towards Recycling Transferable Unlabeled Data for Class-Mismatched Semi-Supervised Learning. IEEE Transactions on Multimedia, 1–1.
  • Kingma et al. (2014) Kingma, D. P.; Mohamed, S.; Rezende, D. J.; and Welling, M. 2014. Semi-Supervised Learning with Deep Generative Models. In Advances in Neural Information Processing Systems.
  • Kumar, Sattigeri, and Fletcher (2017) Kumar, A.; Sattigeri, P.; and Fletcher, T. 2017. Semi-Supervised Learning with GANs: Manifold Invariance with Improved Inference. In Advances in Neural Information Processing Systems.
  • Laine and Aila (2017) Laine, S.; and Aila, T. 2017. Temporal Ensembling for Semi-Supervised Learning. In International Conference on Learning Representations.
  • Leclerc et al. (2019) Leclerc, S.; Smistad, E.; Pedrosa, J.; Østvik, A.; Cervenansky, F.; Espinosa, F.; Espeland, T.; Berg, E. A. R.; Jodoin, P.-M.; et al. 2019. Deep Learning for Segmentation Using an Open Large-Scale Dataset in 2D Echocardiography. IEEE Transactions on Medical Imaging, 38(9): 2198–2210.
  • Lee (2013) Lee, D.-H. 2013. Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks. In Workshop on Challenges in Representation Learning at ICML, 2.
  • Lopez-Paz and Ranzato (2017) Lopez-Paz, D.; and Ranzato, M. 2017. Gradient Episodic Memory for Continual Learning. In Advances in Neural Information Processing Systems, 10.
  • Luo et al. (2021) Luo, H.; Cheng, H.; Meng, F.; Gao, Y.; Li, K.; Zhang, M.; and Sun, X. 2021. An Empirical Study and Analysis on Open-Set Semi-Supervised Learning. arXiv:2101.08237.
  • Madani et al. (2018a) Madani, A.; Arnaout, R.; Mofrad, M.; and Arnaout, R. 2018a. Fast and Accurate View Classification of Echocardiograms Using Deep Learning. npj Digital Medicine, 1(1): 1–8.
  • Madani et al. (2018b) Madani, A.; Ong, J. R.; Tibrewal, A.; and Mofrad, M. R. 2018b. Deep Echocardiography: Data-Efficient Supervised and Semi-Supervised Deep Learning towards Automated Diagnosis of Cardiac Disease. NPJ digital medicine, 1(1): 1–11.
  • Mitchell et al. (2019) Mitchell, C.; Rahko, P. S.; Blauwet, L. A.; Canaday, B.; Finstuen, J. A.; Foster, M. C.; Horton, K.; Ogunyankin, K. O.; Palma, R. A.; et al. 2019. Guidelines for Performing a Comprehensive Transthoracic Echocardiographic Examination in Adults: Recommendations from the American Society of Echocardiography. Journal of the American Society of Echocardiography, 32(1): 1–64.
  • Miyato et al. (2019) Miyato, T.; Maeda, S.-I.; Koyama, M.; and Ishii, S. 2019. Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(8): 1979–1993.
  • Nalisnick et al. (2019) Nalisnick, E.; Matsukawa, A.; Teh, Y. W.; Gorur, D.; and Lakshminarayanan, B. 2019. Hybrid Models with Deep and Invertible Features. In International Conference on Machine Learning, 4723–4732. PMLR.
  • Netzer et al. (2011) Netzer, Y.; Wang, T.; Coates, A.; Bissacco, A.; Wu, B.; and Ng, A. Y. 2011. Reading Digits in Natural Images with Unsupervised Feature Learning. In NeurIPS Workshop on Deep Learning and Unsupervised Feature Learning.
  • Northcutt, Athalye, and Mueller (2021) Northcutt, C. G.; Athalye, A.; and Mueller, J. 2021. Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks. In Advances in Neural Information Processing Systems Datasets and Benchmarks Track.
  • Oliver et al. (2018) Oliver, A.; Odena, A.; Raffel, C. A.; Cubuk, E. D.; and Goodfellow, I. 2018. Realistic Evaluation of Deep Semi-Supervised Learning Algorithms. In Advances in Neural Information Processing Systems.
  • Paszke et al. (2019) Paszke, A.; Gross, S.; Massa, F.; Lerer, A.; Bradbury, J.; Chanan, G.; Killeen, T.; Lin, Z.; Gimelshein, N.; Antiga, L.; et al. 2019. Pytorch: An imperative style, high-performance deep learning library. Advances in neural information processing systems, 32.
  • Riemer et al. (2018) Riemer, M.; Cases, I.; Ajemian, R.; Liu, M.; Rish, I.; Tu, Y.; and Tesauro, G. 2018. Learning to learn without forgetting by maximizing transfer and minimizing interference. arXiv preprint arXiv:1810.11910.
  • Saito, Kim, and Saenko (2021) Saito, K.; Kim, D.; and Saenko, K. 2021. OpenMatch: Open-set Consistency Regularization for Semi-supervised Learning with Outliers. In Advances in Neural Information Processing Systems, 12.
  • Shi et al. (2021) Shi, Y.; Seely, J.; Torr, P. H.; Siddharth, N.; Hannun, A.; Usunier, N.; and Synnaeve, G. 2021. Gradient matching for domain generalization. arXiv preprint arXiv:2104.09937.
  • Sohn et al. (2020) Sohn, K.; Berthelot, D.; Carlini, N.; Zhang, Z.; Zhang, H.; Raffel, C. A.; Cubuk, E. D.; Kurakin, A.; Li, C.-L.; et al. 2020. FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. In Advances in Neural Information Processing Systems.
  • Tarvainen and Valpola (2017) Tarvainen, A.; and Valpola, H. 2017. Mean Teachers Are Better Role Models: Weight-averaged Consistency Targets Improve Semi-Supervised Deep Learning Results. Advances in neural information processing systems, 30: 1195–1204.
  • Tsipras et al. (2020) Tsipras, D.; Santurkar, S.; Engstrom, L.; Ilyas, A.; and Madry, A. 2020. From ImageNet to Image Classification: Contextualizing Progress on Benchmarks. arXiv:2005.11295.
  • van Engelen and Hoos (2020) van Engelen, J. E.; and Hoos, H. H. 2020. A Survey on Semi-Supervised Learning. Machine Learning, 109(2): 373–440.
  • Wu et al. (2021) Wu, N.; Huang, Z.; Shen, Y.; Park, J.; Phang, J.; Makino, T.; Gene Kim, S.; Cho, K.; Heacock, L.; Moy, L.; et al. 2021. Reducing False-Positive Biopsies using Deep Neural Networks that Utilize both Local and Global Image Context of Screening Mammograms. Journal of Digital Imaging, 34(6): 1414–1423.
  • Xu et al. (2021) Xu, Y.; Ding, J.; Zhang, L.; and Zhou, S. 2021. DP-SSL: Towards Robust Semi-supervised Learning with A Few Labeled Samples. In Advances in Neural Information Processing Systems (NeurIPS), 13.
  • Yadav and Bottou (2019) Yadav, C.; and Bottou, L. 2019. Cold Case: The Lost MNIST Digits. In Advances in Neural Information Processing Systems, volume 32.
  • Yu et al. (2020a) Yu, Q.; Ikami, D.; Irie, G.; and Aizawa, K. 2020a. Multi-Task Curriculum Framework for Open-Set Semi-Supervised Learning. In European Conference on Computer Vision (ECCV). arXiv.
  • Yu et al. (2020b) Yu, T.; Kumar, S.; Gupta, A.; Levine, S.; Hausman, K.; and Finn, C. 2020b. Gradient Surgery for Multi-Task Learning. In Advances in Neural Information Processing Systems, 13.
  • Zagoruyko and Komodakis (2016) Zagoruyko, S.; and Komodakis, N. 2016. Wide residual networks. arXiv preprint arXiv:1605.07146.
  • Zeng et al. (2019) Zeng, G.; Chen, Y.; Cui, B.; and Yu, S. 2019. Continual learning of context-dependent processing in neural networks. Nature Machine Intelligence, 1(8): 364–372.
  • Zhang et al. (2017) Zhang, H.; Cisse, M.; Dauphin, Y. N.; and Lopez-Paz, D. 2017. Mixup: Beyond Empirical Risk Minimization. arXiv preprint arXiv:1710.09412.
  • Zhu (2005) Zhu, X. 2005. Semi-Supervised Learning Literature Survey. Technical Report 1530, Department of Computer Science, University of Wisconsin Madison.

Supplementary Material



In this supplement, we provide:

  • Sec. A: CIFAR Experiments: Further Details, Results, and Analysis

  • Sec. B: Heart2Heart Experiments: Further Details, Results, and Analysis

  • Sec. C: Methods Supplement: Algorithms for Aug+SoftLabel and MixMatchAug

  • Sec. D: Related Work Supplement: Further Discussion and Analysis

  • Sec. E: Reproducibility Supplement: Hyperparameters, Settings, etc.

Appendix A CIFAR Experiments: Details, Results, and Analysis

A.1 CIFAR-10 6-animal task mismatch description

In Table A.1 we define which classes form the labeled and unlabeled set at each level of mismatch for the CIFAR-10 6 animal task. This exactly follows Oliver et al. (2018) and creates more challenging scenarios than other “mismatch” tasks on CIFAR-10 tried previously (for example, Saito, Kim, and Saenko (2021) examine a case with all 10 classes in the unlabeled set).

Labeled set                          Unlabeled set
Bird, Cat, Deer, Dog, Frog, Horse    Deer, Dog, Frog, Horse
Bird, Cat, Deer, Dog, Frog, Horse    Airplane, Dog, Frog, Horse
Bird, Cat, Deer, Dog, Frog, Horse    Airplane, Car, Frog, Horse
Bird, Cat, Deer, Dog, Frog, Horse    Airplane, Car, Ship, Horse
Bird, Cat, Deer, Dog, Frog, Horse    Airplane, Car, Ship, Truck
Table A.1: Definition of labeled/unlabeled class mismatch scenario in CIFAR-10 6 animal task. We bolded the non-animal classes in unlabeled set that are not in the labeled set. All included classes are represented with equal frequency.
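As a quick check on the five scenarios in Table A.1, the sketch below computes a mismatch level for each row. It assumes that mismatch is the share of unlabeled classes that do not appear in the labeled set (with all classes equally frequent, as stated in the caption); the function name is hypothetical.

```python
LABELED = {"Bird", "Cat", "Deer", "Dog", "Frog", "Horse"}

# One entry per row of Table A.1, in the same order.
UNLABELED_SCENARIOS = [
    ["Deer", "Dog", "Frog", "Horse"],
    ["Airplane", "Dog", "Frog", "Horse"],
    ["Airplane", "Car", "Frog", "Horse"],
    ["Airplane", "Car", "Ship", "Horse"],
    ["Airplane", "Car", "Ship", "Truck"],
]


def mismatch_percent(labeled, unlabeled):
    """Percent of unlabeled classes absent from the labeled set (assumed definition)."""
    novel = [c for c in unlabeled if c not in labeled]
    return 100.0 * len(novel) / len(unlabeled)


levels = [mismatch_percent(LABELED, u) for u in UNLABELED_SCENARIOS]
print(levels)  # [0.0, 25.0, 50.0, 75.0, 100.0]
```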

A.2 Illustration: OOD Unlabeled Data can be Helpful

We motivate the hypothesis that unlabeled data even from out-of-distribution (OOD) classes could be useful with an experiment testing the off-the-shelf performance of MixMatch (Berthelot et al., 2019) on the CIFAR-10 6 animal task (with 400 images per class in the training set). We compare MixMatch with and without perfect OOD filtering under three mismatch levels. Results are shown in Fig. A.1. Counterintuitively, we see that perfect OOD filtering leads to clearly worse performance at all mismatch levels. This finding appears robust across 5 random train/test splits. This result suggests to us that unlabeled data, even with OOD classes, could be useful via MixMatch-style augmentation. Note that our suggested Fix-A-Step procedure provides further safeguards via the gradient step modification, which are not used in Fig. A.1.

Figure A.1: CIFAR-10 6-animal, 400 examples/class. Results average across 5 train/test splits (shaded area shows standard deviation). Both methods use the same hyper-parameters for fair comparison.

A.3 Ablation Study across different levels of contamination

Expanding on the ablation table in the main paper, in Tab. A.2 we show complete ablation results (comparing augmentation only, gradient step modification only, or both) across all tested values of the mismatch in labeled-vs-unlabeled class content.

Mismatch
Method         Base    +A only   +G only   +A&G (Fix-A-Step)
Pi-Model       78.32   81.90     78.48     82.83
Mean-Teacher   79.57   84.18     80.60     83.02
VAT            79.15   83.88     79.47     83.10
Pseudo-label   77.43   79.03     78.30     79.98
FixMatch       86.40   86.35     86.17     88.00

Mismatch
Method         Base    +A only   +G only   +A&G (Fix-A-Step)
Pi-Model       77.45   79.12     78.00     79.80
Mean-Teacher   77.70   81.35     78.28     87.77
VAT            76.65   82.27     78.35     82.63
Pseudo-label   76.58   79.28     77.25     79.60
FixMatch       83.20   84.44     83.85     86.63

Mismatch
Method         Base    +A only   +G only   +A&G (Fix-A-Step)
Pi-Model       76.03   78.70     76.57     78.97
Mean-Teacher   76.35   82.22     78.18     80.15
VAT            75.90   79.43     77.37     79.56
Pseudo-label   75.42   78.33     75.65     78.77
FixMatch       81.60   83.28     81.80     85.45

Mismatch
Method         Base    +A only   +G only   +A&G (Fix-A-Step)
Pi-Model       75.00   77.82     74.77     79.35
Mean-Teacher   74.33   79.63     74.52     81.08
VAT            74.87   80.82     75.20     81.20
Pseudo-label   76.85   78.38     77.15     78.83
FixMatch       81.05   83.03     81.33     84.75
Table A.2: Ablation analysis on CIFAR-10 6 animal task, examining how accuracy changes for each SSL method if we only use our augmentation (+A), only use our gradient step modification (+G), and use the combination (+A&G) which constitutes our Fix-A-Step. Each panel shows results for a fixed value of the mismatch percentage describing the overlap in classes between labeled and unlabeled set. For each method, we bold the best result. Setting: 400 examples/class.

A.4 Robustness of results across multiple train/test splits

In the main paper, we report results on the CIFAR-10 6 animal task across many baseline methods. For each baseline, we use only one train/test split due to the tremendous computation required to compare all baselines. In Fig. A.2, for a subset of methods we assess the robustness of the conclusions of that experiment to repeated evaluation using multiple separate training/test splits.

We train the labeled-set-only baseline, FixMatch with and without Fix-A-Step, and Mean-Teacher with and without Fix-A-Step for 5 random splits of the data, across two levels of mismatch. Results are in Fig. A.2. Broadly, our conclusion that Fix-A-Step delivers accuracy gains holds even across 5 splits: both FixMatch and Mean-Teacher show notable gains at both levels of mismatch. Notably, Mean-Teacher plus Fix-A-Step appears quite competitive with off-the-shelf FixMatch, and FixMatch plus Fix-A-Step is the best of all.

Figure A.2: CIFAR10, 400 examples/class. Results average across 5 random splits of the data. Error bars show the standard deviation across the 5 splits.

A.5 Sensitivity to hyperparameters

Since deep SSL methods can be sensitive to hyperparameters, we conduct a sensitivity analysis to see how Fix-A-Step behaves under different choices of the sharpening temperature and the Beta distribution shape. We analyzed the performance of Fix-A-Step with a Mean-Teacher base model for common choices of these two hyperparameters (4 combinations in total). Results in Fig. A.3 show that Fix-A-Step does not appear overly sensitive to these two additional hyperparameters.

Figure A.3: CIFAR10, 400 examples/class. Vanilla MT: baseline MT training. Fix-A-Step MT1 to MT4: Fix-A-Step with an MT base model under the 4 tested combinations of sharpening temperature and Beta shape. Results average across 5 random splits of the data. Error bars show the standard deviation across the 5 splits.

Appendix B Heart2Heart Experiments: Details, Results, and Analysis

B.1 Preprocessing TMED-2 data

We applied for access to the TMED-2 data via the form on the website (https://TMED.cs.tufts.edu), and downloaded the shared folder of data from the provided cloud-based link after approval. Images (as 112x112 PNG images) and associated view labels (in CSV files) are readily available in the provided shared folder for download.

Train/validation/test splits. To form our labeled sets for training, we used the provided train/test splits of the fully-labeled set with the smallest training set size (56 studies available for both training and validation). While larger labeled training sets are possible, we selected this smaller size as the most compelling use case for SSL. We wanted to answer the question: how well can we do with very little labeled data but a large pile of unlabeled data?

View label selection. Among available view labels, we chose PLAX, PSAX, A4C, and A2C as the 4 classes to focus on for our Heart2Heart view type classifier. The original TMED-2 labeled set, as described in Huang et al. (2022), contains an additional view type label that they called A2CorA4CorOther, which is a super-category that contains possible view types distinct from PLAX and PSAX (including A2C, A4C, and other possible classes like A5C). For simplicity, we excluded that class in our Heart2Heart experiments.

B.2 Preprocessing Unity data

We downloaded the Unity data from their website (https://data.unityimaging.net). On the website, go to the ’Latest Data Release’ section and download the images. For the view labels, go to https://data.unityimaging.net/additional.html and download the csv file under the ’View’ section.

In the Unity dataset, along with PLAX, A2C, and A4C views, there are also A3C and A5C. For the purposes of these experiments, we filtered out all A3C and A5C images.

Disclaimer: These view labels were assigned by a single human annotator, so there may be some errors in the labeling.

The raw Unity data came in .png format, so we first converted all the PNG files to TIFF format. We then converted them to gray-scale, padded the shorter axis to achieve a square aspect ratio, and resized them to 112 x 112 pixels.
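The gray-scale, pad, and resize steps described above could look roughly like the following sketch. The text does not specify the fill value, whether the padding is centered, or the interpolation method, so zero-fill, centered padding, and nearest-neighbor resizing are all assumptions here; the function names are hypothetical, and a real pipeline would more likely rely on an image library.

```python
import numpy as np


def to_gray(img):
    """Average the color channels if present (assumed gray-scale conversion)."""
    if img.ndim == 3:
        return img[..., :3].mean(axis=-1)
    return img.astype(float)


def pad_to_square(img, fill=0.0):
    """Pad the shorter axis, centering the image (assumption: zero fill, centered)."""
    h, w = img.shape
    side = max(h, w)
    out = np.full((side, side), fill, dtype=float)
    top, left = (side - h) // 2, (side - w) // 2
    out[top:top + h, left:left + w] = img
    return out


def resize_nearest(img, size=112):
    """Nearest-neighbor resize to size x size (assumption: interpolation method)."""
    h, w = img.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[np.ix_(rows, cols)]


def preprocess(img, size=112):
    return resize_nearest(pad_to_square(to_gray(img)), size)


# Hypothetical 300 x 400 RGB frame standing in for one Unity image.
frame = np.full((300, 400, 3), 255, dtype=np.uint8)
print(preprocess(frame).shape)  # (112, 112)
```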

B.3 Preprocessing CAMUS data

We acquired the CAMUS data from their website (http://camus.creatis.insa-lyon.fr/challenge/#challenges). On the website, click on the first link and register on that website; you will then be free to download the dataset.

In the CAMUS dataset, in addition to having view labels (’2CH’ in their dataset is ’A2C’ and likewise ’4CH’ is ’A4C’), they also label whether the view was taken in the end diastolic (ED) or end systolic (ES) portion of the cardiac cycle. We separated and took note of these labels, but we found no significant differences in the results.

The raw CAMUS data came in .mhd format, a file type used specifically for medical imaging. Through conversations with the data creators, we discovered that the resolution for these images was lower in the x direction than in the y direction, and that .mhd files compensate for the lower resolution by adjusting the spacing between pixels in that direction (indicated by the ’Element Spacing’ field). To convert to a standardized tiff file representation (where the spacing between pixels is uniform across width and height), we shrank the image in the y direction as:

$y' = y \cdot \frac{s_y}{s_x}$ (4)

where $y$ is the original location (number of pixels) in the y direction, $s_y$ and $s_x$ are the spacing of pixels in the y and x directions (as given in the Element Spacing metadata), and $y'$ is the new location in the y direction.
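A minimal sketch of this rescaling is given below. It assumes the new y location is the original location multiplied by the ratio of y-spacing to x-spacing, which is our reading of the description above; the function names and the example spacing values are hypothetical and do not come from the CAMUS metadata.

```python
def rescale_y(y, spacing_y, spacing_x):
    """Map an original row index to its location on a grid with uniform spacing.

    Assumption: the x spacing is kept, so rows are scaled by spacing_y / spacing_x.
    """
    return y * spacing_y / spacing_x


def new_height(height, spacing_y, spacing_x):
    """Number of rows after shrinking, rounded to the nearest integer."""
    return max(1, round(height * spacing_y / spacing_x))


# Hypothetical spacings: pixels are twice as far apart in x as in y.
print(rescale_y(100, spacing_y=0.5, spacing_x=1.0))   # 50.0
print(new_height(800, spacing_y=0.5, spacing_x=1.0))  # 400
```

When the x spacing is larger than the y spacing, the ratio is below one, so the image shrinks in the y direction as described.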

After this transformation, the images were converted to gray-scale, padded along the shorter axis to achieve a square aspect ratio, and resized to 112 x 112 pixels.

B.4 Further investigation of CAMUS performance

In our main paper’s Fig. 3, we assess how well our TMED-2-trained models generalize from the TMED-2 test set to other external datasets. The models generalized reasonably well to the Unity dataset; however, the same success was not seen on the CAMUS dataset.

Visualizing differences. To investigate, we visually compared images from TMED-2, Unity, and CAMUS. While Unity and TMED-2 looked similar, when comparing TMED-2 and CAMUS there are clear discrepancies in pixel intensity, likely from the use of a different ultrasound machine and different conventions for standard intensity values and normalization. Fig. B.1 below provides sample images of the two datasets and a summary histogram of pixel intensity (aggregated across all images).

Idea: Simple quantile transformation. To quickly try to remedy this discrepancy, we transformed the CAMUS images so that their pixel intensity distribution more closely resembles that of TMED-2. In this transformation, we first mapped each target (CAMUS) pixel to its empirical quantile (a value between 0 and 1), and then mapped that quantile to a pixel intensity in the source (TMED-2) images via the empirical inverse CDF. The effects of this transformation on the CAMUS images and on the pixel intensity histogram are shown in the right-most panel of Fig. B.1.
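The two-step mapping described above (pixel to empirical quantile, then quantile to source intensity) can be sketched as follows. This is an illustration under assumptions, not the exact code used for our experiments: the function name is hypothetical, the empirical distributions are built from whatever arrays are passed in, and linear interpolation is used for the inverse CDF.

```python
import numpy as np


def match_quantiles(target, source):
    """Map target intensities onto the source intensity distribution."""
    t = target.ravel()
    t_sorted = np.sort(t)
    # Step 1: empirical quantile of each target pixel, i.e. the fraction of
    # target pixels with intensity <= this pixel's intensity (in (0, 1]).
    q = np.searchsorted(t_sorted, t, side="right") / t.size
    # Step 2: empirical inverse CDF of the source, evaluated at those quantiles.
    mapped = np.quantile(source.ravel(), q)
    return mapped.reshape(target.shape)


# Hypothetical stand-ins: a dark "target" image and a brighter "source" image.
rng = np.random.default_rng(0)
target_img = rng.integers(0, 100, size=(16, 16)).astype(float)
source_img = rng.integers(100, 256, size=(32, 32)).astype(float)
transformed = match_quantiles(target_img, source_img)
print(transformed.min() >= source_img.min(), transformed.max() <= source_img.max())
```

The mapping is monotone, so the relative ordering of pixel intensities within the target image is preserved while its histogram is pulled toward that of the source.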

Figure B.1: A sample of images from the TMED-2 dataset, CAMUS dataset, and the same CAMUS pictures except under a pixel transformation to match the pixel intensity of TMED-2

Results after transform. The accuracy of the TMED-2-trained classifiers on both untransformed and transformed CAMUS data can be viewed in Fig. B.2. As noted in the main paper, Fix-A-Step clearly improves SSL models in classifying CAMUS view types for the untransformed data. However, while the transformation itself seems to help model performance overall, Fix-A-Step does not seem to help as much on the transformed dataset (some gains for VAT, but for both FixMatch and Pi-Model the before-after difference seems negligible). Importantly, Fix-A-Step is still competitive with its base method, just not notably superior to it. Much more work is needed here. In the future, we hope to explore other ways to improve performance on the CAMUS dataset so it reaches accuracy levels seen on the Unity dataset.

Figure B.2: Evaluation of the SSL methods from the paper on untransformed and transformed CAMUS images

Further investigation: Differences across splits. After investigating the results of the three data splits, we noticed that the first split seemed to significantly underperform on the CAMUS dataset, specifically on the A4C class. When we looked at the Unity data for this split, we also noticed that, while the discrepancy was not as drastic, the A4C class did underperform compared to the other classes. These results can be clearly seen in Tab. B.1. In all method-dataset pairs, A4C performs significantly worse than other classes.

One hypothesis is that this data split significantly underrepresents A4C, so the classifier is not able to predict it as well. TMED-2 and Unity do not significantly underperform on this split in terms of overall balanced accuracy because the other classes make up a significant portion of their test sets, so they are less affected by A4C underperforming; in contrast, 50% of CAMUS is A4C, so that dataset is affected to a higher degree. However, we are unsure why the A4C class accuracy on CAMUS is significantly worse than the A4C class accuracy on Unity. We will investigate this discrepancy further in the future. We think this open problem makes our Heart2Heart benchmark especially interesting.

Methods             CAMUS          Unity
                    A4C    A2C     A4C    A2C    PLAX
Labeled-Only        28.8   98.0    76.9   93.3   96.1
Pi-model            27.1   96.8    81.9   93.1   98.2
Pi-model w/ FAS     45.4   98.0    87.5   94.1   99.6
VAT                 26.6   98.3    77.0   93.1   97.7
VAT w/ FAS          27.5   99.4    84.6   96.2   95.5
Fix-Match           34.1   96.0    79.8   94.8   96.9
Fix-Match w/ FAS    36.3   97.9    83.5   95.2   99.0
Table B.1: Class accuracies for data split 1 across methods for the Unity dataset and untransformed CAMUS dataset. Bolded are the lowest class accuracies for each dataset-method pair.
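To make the effect of a weak A4C class concrete, the short worked example below averages the Labeled-Only row of Table B.1. It assumes that the per-class accuracies in the table act as per-class recalls and that balanced accuracy is their unweighted mean over the evaluated classes; variable names are hypothetical.

```python
def mean(values):
    values = list(values)
    return sum(values) / len(values)


# Labeled-Only row of Table B.1 (data split 1), per-class accuracy in percent.
camus = {"A4C": 28.8, "A2C": 98.0}
unity = {"A4C": 76.9, "A2C": 93.3, "PLAX": 96.1}

camus_balanced = mean(camus.values())
unity_balanced = mean(unity.values())

# Under this metric, A4C carries weight 1/2 on CAMUS but only 1/3 on Unity.
a4c_weight_camus = 1 / len(camus)
a4c_weight_unity = 1 / len(unity)

print(round(camus_balanced, 1), round(unity_balanced, 1))
```

Because CAMUS has only two evaluated classes, a drop in A4C accuracy moves its balanced accuracy more than the same drop would on a dataset with three or four classes.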

Appendix C Methods Supplement

Here, we provide implementation details of the two subprocedures in our Fix-A-Step training (Alg. 1). Both procedures were originally proposed by MixMatch (Berthelot et al., 2019); we provide them here in the same notation as the rest of our paper for clarity.

First, the algorithm Aug+SoftLabel is in Alg. C.1. This procedure consumes a batch of raw images from the unlabeled set and returns two transformed batches, with a common set of “sharpened” soft (probabilistic) labels.

Second, the algorithm MixMatchAug is in Alg. C.2. This procedure consumes a batch of raw labeled data, and produces a transformed batch of the same size.

Input: Unlabeled batch features
Output: Augmented features , Soft pseudo labels
Hyperparameters

  • Sharpening temperature

Procedure

1:  for each image in  do
2:     
3:     
4:              // Probability vector predicted by neural net
5:     
6:        // Non-negative vector, sharpened by element-wise power
7:     
8:            // Normalize to ‘‘soft’’ label (proba. vector)
9:     Add to
10:     Add to
11:     Add to
12:  end for
13:  return , ,
Algorithm C.1 Augment and Soft-Pseudo-Label
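The steps of Alg. C.1 can be sketched in plain Python as follows. This is a hedged illustration that follows the MixMatch recipe: averaging the predictions over the two augmented copies and the default temperature of 0.5 are assumptions, and `augment` and `predict_proba` are hypothetical callables standing in for the image augmentation and the neural network.

```python
def sharpen(probs, temperature):
    """Raise each probability to the power 1/temperature, then renormalize."""
    powered = [p ** (1.0 / temperature) for p in probs]
    total = sum(powered)
    return [v / total for v in powered]


def aug_soft_label(batch, augment, predict_proba, temperature=0.5):
    """Return two augmented batches and one shared batch of sharpened soft labels."""
    aug1, aug2, soft_labels = [], [], []
    for u in batch:
        u1, u2 = augment(u), augment(u)
        p1, p2 = predict_proba(u1), predict_proba(u2)
        # Assumption (as in MixMatch): average predictions over both augmentations.
        avg = [(a + b) / 2.0 for a, b in zip(p1, p2)]
        aug1.append(u1)
        aug2.append(u2)
        soft_labels.append(sharpen(avg, temperature))
    return aug1, aug2, soft_labels


# Sharpening pushes mass toward the most likely class.
print(sharpen([0.6, 0.3, 0.1], temperature=0.5))
```

A temperature below 1 makes the soft label more peaked; a temperature of exactly 1 leaves the averaged prediction unchanged.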

Input: Labeled batch , Unlabeled batch ,
Output: Transformed labeled batch
Hyperparameters

  • Shape of dist.

1:  for image-label pair in labeled batch  do
2:     
3:     
4:     
5:     
6:     
7:     Add to
8:     Add to
9:  end for
10:  return
Algorithm C.2 MixMatchAug : Transformation of Labeled Set
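A MixUp-style transformation in the spirit of Alg. C.2 is sketched below. Several details are assumptions rather than a transcription of the algorithm: the mixing partner is drawn uniformly from the unlabeled batch, the unlabeled batch is given as (features, soft label) pairs, the Beta shape defaults to 0.75, and features and labels are flat lists of floats of fixed length.

```python
import random


def mixmatch_aug(labeled, unlabeled, alpha=0.75, rng=random):
    """Mix each labeled (x, y) pair with a randomly chosen unlabeled (u, q) pair."""
    out_x, out_y = [], []
    for x, y in labeled:
        lam = rng.betavariate(alpha, alpha)
        # Keep the result closer to the labeled example than to its partner.
        lam = max(lam, 1.0 - lam)
        u, q = rng.choice(unlabeled)
        out_x.append([lam * a + (1.0 - lam) * b for a, b in zip(x, u)])
        out_y.append([lam * a + (1.0 - lam) * b for a, b in zip(y, q)])
    return out_x, out_y


# Hypothetical 2-dimensional features with 2 classes.
labeled_batch = [([1.0, 0.0], [1.0, 0.0])]
unlabeled_batch = [([0.0, 1.0], [0.2, 0.8])]
xs, ys = mixmatch_aug(labeled_batch, unlabeled_batch, rng=random.Random(0))
print(len(xs), len(ys))  # the output batch has the same size as the labeled batch
```

The convex combination requires labeled and unlabeled feature vectors to share the same dimension, which is why MixUp-style augmentation needs fixed-size inputs.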

Appendix D Related Work Supplement

Gradient step modifications.

Recently, across many sub-areas of ML that optimize a multi-task loss, modifying the direction of gradient descent updates during training has borne fruit.

The idea of gradient matching has been proposed to solve catastrophic forgetting problems in continual learning (Lopez-Paz and Ranzato, 2017; Chaudhry et al., 2018; Riemer et al., 2018; Zeng et al., 2019; Farajtabar et al., 2020). Lopez-Paz and Ranzato (2017) proposed a method called Gradient Episodic Memory (GEM), which uses a memory bank to store representative samples of previous tasks. While minimizing the loss on the current task, GEM uses the inner product of the gradients of the current and previous tasks as an inequality constraint. Chaudhry et al. (2018) proposed Averaged GEM (A-GEM) as an improved version of GEM. A-GEM ensures that at every training step the average episodic memory loss over the previous tasks does not increase. Riemer et al. (2018) formally proposed the transfer-interference trade-off perspective for looking at the application of gradient matching in continual learning, which defines whether helpful transfer or interference occurs between two labeled examples in terms of the inner product of gradients with respect to parameters evaluated at those examples. Zeng et al. (2019) developed the Orthogonal Weights Modification (OWM) method, which projects the weight updates onto the direction orthogonal to the subspace spanned by previously learned task inputs, while Farajtabar et al. (2020) project the new task’s gradient onto the direction perpendicular to the gradient space of previous tasks.

Similar ideas were later used in multi-task learning (Du et al., 2020; Yu et al., 2020b), domain generalization (Shi et al., 2021), and neural architecture search (Gong et al., 2021).
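To make the inner-product test concrete, here is a minimal sketch of an A-GEM-style projection: when the proposed gradient conflicts with a reference gradient (negative inner product), the conflicting component is removed. The function names and the plain-list representation of gradients are assumptions for illustration; this is not the update rule of any specific codebase cited above.

```python
def dot(a, b):
    """Inner product of two equal-length vectors given as lists."""
    return sum(x * y for x, y in zip(a, b))


def project_gradient(g, g_ref):
    """Return g unchanged if it agrees with g_ref (inner product >= 0).

    Otherwise remove the component of g along g_ref, so the returned
    vector has zero inner product with g_ref.
    """
    inner = dot(g, g_ref)
    if inner >= 0.0:
        return list(g)
    scale = inner / dot(g_ref, g_ref)
    return [gi - scale * ri for gi, ri in zip(g, g_ref)]
```

For example, with g = [1, -1] and g_ref = [0, 1] the inner product is negative, so the projection returns [1, 0], which no longer opposes g_ref.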

Appendix E Reproducibility Supplement: Hyperparameters, Settings, etc.

E.1 Codebase

Our work builds upon several public repositories that represent either official or well-designed third-party implementations of popular SSL methods.

Method | Code URL | Notes
--- | --- | ---
FixMatch | github.com/google-research/fixmatch | original
FixMatch | github.com/kekmodel/FixMatch-pytorch | PyTorch version
MixMatch | github.com/google-research/mixmatch | original
MixMatch | github.com/YU1ut/MixMatch-pytorch | PyTorch version
Realistic SSL Eval. | github.com/perrying/realistic-ssl-evaluation-pytorch |
Table E.1: Code repositories that we built upon to perform our experiments and verify the quality of results.

E.2 Hyperparameters for CIFAR-10/CIFAR-100

Table E.2 lists the experimental settings (dataset sizes, etc.) and hyperparameters used for all CIFAR-10/CIFAR-100 experiments. We did not tune any hyperparameters specifically for Fix-A-Step.

Basic settings (CIFAR-10): train labeled set size 2400/300; train unlabeled set size 16400/17800; validation set size 3000; test set size 6000.

Basic settings (CIFAR-100): train labeled set size 5000; train unlabeled set size 17500; validation set size 2500; test set size 5000.

Labeled only: labeled batch size 64; learning rate 3e-3; weight decay 2e-3.

VAT: labeled batch size 64; unlabeled batch size 64; learning rate 3e-2; weight decay 4e-5; max consistency coefficient 0.3; unlabeled loss warmup iterations 419430; unlabeled loss warmup schedule linear; VAT ξ 1e-6; VAT ε 6.

Pseudo-label: labeled batch size 64; unlabeled batch size 64; learning rate 3e-2; weight decay 5e-4; max consistency coefficient 1.0; unlabeled loss warmup iterations 419430; unlabeled loss warmup schedule linear; pseudo-label threshold 0.95.

Mean Teacher: labeled batch size 64; unlabeled batch size 64; learning rate 3e-2; weight decay 5e-4; max consistency coefficient 50.0; unlabeled loss warmup iterations 419430; unlabeled loss warmup schedule linear.

Pi-Model: labeled batch size 64; unlabeled batch size 64; learning rate 3e-2; weight decay 5e-4; max consistency coefficient 10.0; unlabeled loss warmup iterations 419430; unlabeled loss warmup schedule linear.

MixMatch: labeled batch size 64; unlabeled batch size 64; learning rate 3e-2; weight decay 4e-5; max consistency coefficient 75.0; unlabeled loss warmup iterations 1048576; unlabeled loss warmup schedule linear; sharpening temperature 0.5; Beta shape 0.75.

FixMatch: labeled batch size 64; unlabeled batch size 448; learning rate 3e-2; weight decay 5e-4; max consistency coefficient 1.0; unlabeled loss warmup iterations: no warmup; unlabeled loss warmup schedule: no warmup; sharpening temperature 1.0; pseudo-label threshold 0.95.
Table E.2: Hyperparameters used for CIFAR experiments. All settings represent the recommended defaults suggested in implementations by original authors for the 400 examples/class setting. We did not tune any hyperparameters specifically for Fix-A-Step.
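The "max consistency coefficient" and "unlabeled loss warmup" entries in Table E.2 interact: the weight on the unlabeled loss ramps up over the warmup iterations until it reaches its maximum. The sketch below assumes the linear schedule starts at zero and increases proportionally with the iteration count; the function and argument names are hypothetical and not taken from any of the listed repositories.

```python
def consistency_coefficient(iteration, max_coef, warmup_iters):
    """Weight on the unlabeled loss at a given training iteration.

    Assumes a linear ramp from 0 to max_coef over warmup_iters iterations,
    then a constant max_coef. Pass warmup_iters=None (or 0) for no warmup.
    """
    if warmup_iters is None or warmup_iters <= 0:
        return max_coef
    return max_coef * min(1.0, iteration / warmup_iters)
```

Under this assumption, the VAT setting (max coefficient 0.3, 419430 warmup iterations) would use a coefficient of 0.15 halfway through warmup, while the FixMatch setting (no warmup) would use 1.0 from the first iteration.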

E.3 Hyperparameters for Heart2Heart

Hyperparameters are tuned only for the supervised-only baseline and the non-Fix-A-Step versions of Pi-model, VAT, and FixMatch. We ran 100 trials of Tree-structured Parzen Estimator (TPE) based black-box optimization, using an open-source AutoML toolkit (https://github.com/microsoft/nni), for each algorithm and each data split. In practice, each trial trains for only 180 epochs to speed up hyperparameter selection. The chosen hyperparameters are then applied directly to Fix-A-Step without retuning. After hyperparameter selection, each algorithm is trained for 1000 epochs, and we report the balanced test accuracy at the point of maximum validation balanced accuracy.
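The reporting rule above can be sketched in a few lines. The code assumes balanced accuracy means the average of per-class recall and that validation and test scores are recorded at the same checkpoints; both function names are hypothetical.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall over the classes present in y_true."""
    classes = sorted(set(y_true))
    recalls = []
    for c in classes:
        idx = [i for i, y in enumerate(y_true) if y == c]
        recalls.append(sum(y_pred[i] == c for i in idx) / len(idx))
    return sum(recalls) / len(recalls)


def report_at_best_validation(val_scores, test_scores):
    """Return the test score at the checkpoint with the highest validation score.

    Ties are broken in favor of the earliest checkpoint.
    """
    best = max(range(len(val_scores)), key=lambda i: val_scores[i])
    return test_scores[best]
```

Selecting the checkpoint on validation data alone means the reported test score is never used to choose the model, even if a later checkpoint happens to score higher on the test set.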

Labeled-only:

We search learning rate in , weight decay in , optimizer in , learning rate schedule in . Batch size is set to 64.

Pi-model:

We search learning rate in , weight decay in , optimizer in , learning rate schedule in , Max consistency coefficient in , unlabeled loss warmup iterations in . Labeled batch size is set to 64 and unlabeled batch size is set to 64.

VAT:

We search learning rate in , weight decay in , optimizer in , learning rate schedule in , max consistency coefficient in , unlabeled loss warmup iterations in . Labeled batch size is set to 64 and unlabeled batch size is set to 64. ξ is set to 1e-6 and ε is set to 6.

FixMatch:

We search learning rate in , weight decay in , optimizer in , learning rate schedule in , max consistency coefficient in . Labeled batch size is set to 64 and unlabeled batch size is set to 320. We set the sharpening temperature to 1.0 and the pseudo-label threshold to 0.95 (as in the CIFAR experiments).

E.4 Labeled loss implementation: Weighted cross entropy

On many realistic SSL classification tasks, even the labeled set will have noticeably imbalanced class frequencies. For example, in the TMED-2 view labels, the four view types (PLAX, PSAX, A4C, A2C) differ in the number of available examples, with the rarest class (A2C) roughly 3x less common than the most common class (PLAX). To counteract the effect of class imbalance, we use weighted cross-entropy for labeled loss, following prior works (Huang et al., 2021b; Wu et al., 2021). Let integer index the classes in the labeled set, and let denote the number of images for class . Then when we compute the labeled loss , we assign a weight to the true class that is inversely proportional to the number of images of the class in the training set:

(5)

Here denotes the integer index of the true class corresponding to image , denotes the neural network weight parameters, and denotes the -th entry of the softmax output vector produced by the neural network classifier.
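A minimal sketch of the weighted loss described above, for a single labeled image. It assumes the class weights are inverse class counts normalized to sum to one; the normalization, the function names, and the example counts are illustrative assumptions, not necessarily the exact form of Eq. (5).

```python
import math


def inverse_frequency_weights(class_counts):
    """Weights inversely proportional to class counts, normalized to sum to 1."""
    inv = [1.0 / n for n in class_counts]
    total = sum(inv)
    return [v / total for v in inv]


def weighted_cross_entropy(probs, true_class, weights):
    """Cross entropy for one example, scaled by the weight of its true class.

    probs: softmax output vector; true_class: integer index of the true class.
    """
    return -weights[true_class] * math.log(probs[true_class])
```

With hypothetical counts of 300 and 100 images, the weights are 0.25 and 0.75, so a mistake on the rarer class costs three times as much as the same mistake on the more common class.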

E.5 Cosine-annealing of learning rate

We found that several baselines were notably improved by the cosine-annealing learning rate schedule suggested by Sohn et al. (2020). Cosine annealing sets the learning rate at iteration k to η cos(7πk / (16K)), where η is the initial learning rate and K is the total number of iterations.
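A minimal sketch of such a schedule, assuming the form used by Sohn et al. (2020), in which the learning rate at step k of K is η·cos(7πk / (16K)). The function name and the `cycles` parameter are hypothetical; the default of 7/16 reflects that assumed form.

```python
import math


def cosine_annealed_lr(step, total_steps, initial_lr, cycles=7.0 / 16.0):
    """Learning rate at a given step under cosine annealing.

    With cycles = 7/16, the rate decays smoothly from initial_lr at step 0
    to initial_lr * cos(7*pi/16) at the final step, without reaching zero.
    """
    return initial_lr * math.cos(math.pi * cycles * step / total_steps)
```

With an initial learning rate of 3e-2, this schedule ends at roughly 20% of the starting value, since cos(7π/16) is about 0.195.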

To be extra careful, we tried to allow all open-set/safe SSL baselines to also benefit from cosine annealing.

  • MTCF is trained using Adam, following the authors’ implementation (Yu et al., 2020a). Although the authors did not originally use a cosine learning rate schedule, we found that adding one substantially improves MTCF’s performance. We thus report the performance of MTCF with cosine annealing.

  • DS3L is trained using Adam, following the authors’ implementation (Guo et al., 2020). We tried adding a cosine learning rate schedule to DS3L, but this resulted in worse performance. We thus report the performance of DS3L without cosine annealing.