ProCo: Prototype-aware Contrastive Learning for Long-tailed Medical Image Classification

Zhixiong Yang*¹, Junwen Pan*¹, Yanzhan Yang¹, Xiaozhou Shi¹, Hong-Yu Zhou², Zhicheng Zhang¹, Cheng Bian¹

¹ Xiaohe Healthcare, ByteDance, Guangzhou, China
² Department of Computer Science, The University of Hong Kong, Pokfulam, Hong Kong
{biancheng, zc.zhang}@bytedance.com

Abstract

Medical image classification has been widely adopted in medical image analysis. However, because collecting and labeling data in the medical domain is difficult, medical image datasets are usually highly imbalanced. To address this problem, previous works used class sample counts as a prior for re-weighting or re-sampling, but the resulting feature representation is usually still not discriminative enough. In this paper, we adopt contrastive learning to tackle the long-tailed medical imbalance problem. Specifically, we first propose the category prototype and the adversarial proto-instance to generate representative contrastive pairs. Then, a prototype recalibration strategy is proposed to address the highly imbalanced data distribution. Finally, a unified proto-loss is designed to train our framework. The overall framework, named Prototype-aware Contrastive learning (ProCo), is a single-stage, end-to-end pipeline that alleviates the imbalance problem in medical image classification, in contrast to existing works, which follow the traditional two-stage pipeline. Extensive experiments on two highly imbalanced medical image classification datasets demonstrate that our method outperforms the existing state-of-the-art methods by a large margin. Our source code is available at https://github.com/skyz215/ProCo.

Keywords: Contrastive learning · Prototype · Imbalanced dataset

* Zhixiong Yang and Junwen Pan contributed equally to this work. Cheng Bian and Zhicheng Zhang are corresponding authors.

1 Introduction

Convolutional neural networks have proved successful in many visual tasks [11, 23, 20, 25, 12, 13]. Although impressive breakthroughs have been achieved, recent advances are still driven by balanced datasets in the corresponding tasks. In real-world medical practice, however, the acquired dataset often exhibits a long-tailed distribution [19], where the head categories account for most of the data, whereas the tailed categories have only a handful of samples. Such a skewed dataset commonly originates from the difficulty of collecting samples of rare diseases or from insufficient annotation by proficient experts. As a result, even a well-trained model is prone to bias its decisions towards head classes because of their numerical superiority, which weakens its performance on tailed classes.

To address the long-tailed imbalance, follow-up studies have been conducted in recent years. Common solutions include class-rebalancing [27, 8], information supplement [29, 32], and module modification [15, 34, 14]. Unfortunately, these works depend heavily on manual designs or prior knowledge, which limits the efficacy and generalization of the proposed models. To this end, we propose Prototype-aware Contrastive learning (ProCo) to address the long-tailed problem with strong performance and high efficiency.

Different from previous works, ProCo combines contrastive learning, the category prototype, and the proto-instance, and is trained with a dedicated loss function, the proto-loss, to tackle long-tailed medical image classification. Formally, our technical contributions are four-fold:

  • We propose a category prototype and an adversarial proto-instance for feature modeling. The category prototype models an arbitrary category distribution adaptively. The adversarial proto-instance is generated from the category prototype and a representative instance to enhance the robustness of contrastive learning over all classes in the long-tailed setting.

  • We present a prototype recalibration strategy that ensures the category prototypes of tailed classes are updated frequently enough and eliminates the prototype bias, i.e., an imbalanced updating process that results in an incorrect distribution prediction.

  • To unify contrastive learning with prototype-based supervised learning, we propose a proto-loss, which significantly boosts the efficiency of our end-to-end framework ProCo.

  • Extensive experiments on two long-tailed medical classification datasets show that ProCo outperforms state-of-the-art methods by a large margin, demonstrating the effectiveness of the proposed method.

2 Related Work

Existing class-rebalancing practices aim to adjust the distribution of training samples to achieve balanced data over all categories, e.g., over-sampling [17] tailed classes or under-sampling [16] head classes. While these techniques improve the performance on tailed classes, they sacrifice the performance on head classes [33]. In this regard, information supplement methods introduce prior expertise into the framework to overcome the performance degradation, e.g., transfer learning [30], model pre-training [31], knowledge distillation [28], and self-supervised learning [29]. Other studies explore module modification paradigms. For instance, Kang et al. [15] proposed a two-stage decoupled training method, in which the backbone and the classifier are trained separately to accommodate tailed classes. Dong et al. [7] proposed a metric learning approach for batch-incremental hard sample mining of minority attribute classes from imbalanced large-scale training data. However, one limitation of these methods is that the prior expertise relies on cumbersome manipulation, which is an obstacle to real-world application. In this paper, we address the long-tail problem in a contrastive way, which not only tackles the above challenges but also achieves promising performance with high efficiency.

3 Methodology

Figure 1: Our proposed Prototype-aware Contrastive Learning (ProCo) framework. The framework has two projectors and operates on the image feature and the query feature. Prototypes and the sample queue are used to generate adversarial proto-instances. All negatives and positives enter the proposed proto-loss for optimization. In addition, a prototype recalibration strategy adjusts the weight of each prototype in the proto-loss.

Fig. 1 shows the diagram of our proposed ProCo framework. It builds on MoCo [10], which consists of an online encoder and a momentum-updated encoder. Here we omit the momentum encoder for simplicity. As the two encoders share the same architecture, the following text describes only the structure of the online encoder.
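
As a minimal sketch of the momentum update that MoCo-style frameworks rely on, the momentum encoder can be viewed as an exponential moving average of the online encoder. The parameter dictionaries, the function name `momentum_update`, and the coefficient values are assumptions for illustration (0.999 is the MoCo default), not details taken from the ProCo code.

```python
import numpy as np

def momentum_update(online_params, momentum_params, m=0.999):
    """Move each momentum-encoder parameter a small step toward the online one."""
    return {
        name: m * momentum_params[name] + (1.0 - m) * online_params[name]
        for name in online_params
    }

# Toy parameters (hypothetical): one weight matrix per encoder.
online = {"w": np.ones((2, 2))}
momentum = {"w": np.zeros((2, 2))}
momentum = momentum_update(online, momentum, m=0.9)
```

Only the online encoder receives gradients; the momentum encoder changes slowly, which keeps the queued features consistent over time.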

3.0.1 Notation

The long-tailed classification training set consists of training samples paired with category labels. The network consists of a shared backbone and two projection heads. The feature produced by one projection head is used for classification, while the feature produced by the other is used for contrastive learning. A sample queue of instances is maintained in the momentum fashion [10].
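
The sample queue can be pictured as a fixed-size first-in-first-out buffer. The sketch below is an assumption-laden illustration: the class name `SampleQueue`, the stored labels, and the toy sizes are hypothetical, and the features would in practice come from the momentum encoder.

```python
import numpy as np

class SampleQueue:
    """Fixed-size FIFO buffer of features and labels (hypothetical helper)."""

    def __init__(self, size, dim):
        self.feats = np.zeros((size, dim))
        self.labels = np.full(size, -1)  # -1 marks an empty slot
        self.ptr = 0

    def enqueue(self, feats, labels):
        # Overwrite the oldest entries first, wrapping around at the end.
        for f, y in zip(feats, labels):
            self.feats[self.ptr] = f
            self.labels[self.ptr] = y
            self.ptr = (self.ptr + 1) % len(self.labels)

queue = SampleQueue(size=3, dim=2)
queue.enqueue(np.arange(8.0).reshape(4, 2), [0, 1, 2, 3])
```

Keeping the labels next to the features is what later allows the queue to be split into positive and negative instances for a given sample.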

3.1 Category Prototype and Adversarial Proto-instance

Figure 2: Illustration of the proposed adversarial proto-instance.

Classic contrastive training pairs (i.e., positive and negative pairs) are used to learn the representation of instances. However, in a long-tailed dataset, the head classes dominate most of the negative pairs under conventional contrastive methods, causing the tailed classes to be under-learned. Previous works [2, 24] reveal that not all negative pairs facilitate contrastive learning. Therefore, the key to improving performance on the long-tailed problem is to reduce the redundancy of contrastive negative pairs and to mine recognizable positive pairs. To this end, we propose a new concept named the category prototype. The prototype is a set of learnable parameters for the predefined categories and is optimized by our proto-loss, as described in Sec. 3.3. Then, we generate the adversarial proto-instance from the category prototype and a representative sample, i.e., a confusing sample from the sample queue, via linear interpolation in the feature space; the result is used to form a training pair in ProCo. Conceptually, the adversarial proto-instance is designed as a special outlier that encourages ProCo to rectify the decision boundaries of the tailed categories during contrastive learning.

Fig. 2 illustrates the adversarial proto-instance. First, for each instance, the queue is divided into two disjoint subsets that comprise the positive and negative instances, respectively. Similarly, the set of prototypes is split into a singleton set with exactly one prototype, namely the positive prototype corresponding to the current instance, and a negative set containing the remaining prototypes. Then, the adversarial positive proto-instances are derived from the positive instances and the negative prototypes, while the negative ones are derived from the negative instances and the positive prototype.
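
The partition described above can be sketched with index arrays. The function name `split_by_label` and the toy labels are hypothetical; the sketch assumes that "positive" means "same class label as the current instance".

```python
import numpy as np

def split_by_label(queue_labels, num_classes, y):
    """Partition queue entries and prototypes relative to class y."""
    pos_idx = np.flatnonzero(queue_labels == y)   # queue entries of class y
    neg_idx = np.flatnonzero(queue_labels != y)   # all other queue entries
    neg_protos = np.array([c for c in range(num_classes) if c != y])
    return pos_idx, neg_idx, y, neg_protos

labels = np.array([0, 1, 1, 2])
pos_idx, neg_idx, pos_proto, neg_protos = split_by_label(labels, num_classes=3, y=1)
```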

To synthesize the adversarial negative proto-instances, we prioritize the negative instances that are likely to be confused with the current instance. The distance between the current feature and a negative feature serves as the indicator and is given by:

(1)

where the superscript denotes the transpose operation. Then, we rank the negative instances in ascending order of their distance to the current feature and select the top-ranked ones to compose the adversarial proto-instance set:

(2)

where each element is taken from the sorted set of negative instances according to its rank. Finally, to obtain more challenging negative instances, we randomly perturb each selected element with the positive prototype:

(3)

where the random interpolation coefficient is drawn separately for each sample and its upper bound is a hyperparameter with a small value. The prototype therefore always contributes less than this upper bound to the generated proto-instance, which guarantees that the negative semantics of the perturbed instance are preserved.
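
The selection and perturbation steps above can be sketched as follows. This is an illustration under stated assumptions, not the authors' implementation: features are taken to be L2-normalized, the distance is assumed to be one minus the dot product, the coefficient is assumed to be drawn uniformly, and the names `adversarial_negatives`, `k`, and `delta` are hypothetical. The defaults 20 and 0.4 follow the values reported in Sec. 4.2.

```python
import numpy as np

def adversarial_negatives(z, negatives, pos_proto, k=20, delta=0.4, rng=None):
    """Pick the k negatives closest to z and nudge them toward the positive prototype."""
    rng = np.random.default_rng() if rng is None else rng
    dist = 1.0 - negatives @ z                 # assumed distance for unit-norm features
    order = np.argsort(dist)[:k]               # ascending: most confusing first
    hard = negatives[order]
    lam = rng.uniform(0.0, delta, size=(len(hard), 1))  # one coefficient per sample
    mixed = (1.0 - lam) * hard + lam * pos_proto        # linear interpolation
    return mixed, order

z = np.array([1.0, 0.0])
negatives = np.array([[0.0, 1.0], [1.0, 0.0], [-1.0, 0.0]])
pos_proto = np.array([0.6, 0.8])
mixed, order = adversarial_negatives(
    z, negatives, pos_proto, k=2, delta=0.4, rng=np.random.default_rng(0)
)
```

Because the coefficient stays below `delta`, each generated point remains closer to its original negative instance than to the prototype.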

As for the adversarial positive proto-instances, the strategy is to select the positive samples that are misclassified and combine them with the incorrectly assigned prototype, using the same interpolation as in Eq. 3.
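
A hedged sketch of the positive case is given below. It assumes that a sample counts as "misclassified" when its most similar prototype (by dot product) is not the prototype of its own class; the function name `adversarial_positives` and the toy arrays are hypothetical.

```python
import numpy as np

def adversarial_positives(positives, prototypes, y, delta=0.4, rng=None):
    """Mix misclassified positives of class y with the prototype they were assigned to."""
    rng = np.random.default_rng() if rng is None else rng
    pred = np.argmax(positives @ prototypes.T, axis=1)  # nearest prototype per sample
    wrong = pred != y                                   # misclassified positives
    hard = positives[wrong]
    lam = rng.uniform(0.0, delta, size=(len(hard), 1))
    mixed = (1.0 - lam) * hard + lam * prototypes[pred[wrong]]
    return mixed, wrong

prototypes = np.eye(2)                        # two toy class prototypes
positives = np.array([[0.9, 0.1], [0.2, 0.8]])  # both labelled as class 0
mixed_pos, wrong = adversarial_positives(
    positives, prototypes, y=0, rng=np.random.default_rng(0)
)
```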

3.2 Prototype Recalibration

Although the category prototype and the adversarial proto-instance address the construction of contrastive pairs, the learning of the category prototype can still be affected by class imbalance. We argue that the underlying reason is prototype bias: the generated proto-instances incline towards head classes, jeopardizing the performance on tailed classes. We therefore propose a prototype recalibration strategy, which estimates the representativeness of each category prototype and thus reflects the importance of tailed-class prototype features.

Specifically, motivated by the similarity between the projected features and the prototype, we introduce a rectified sigmoid function to achieve recalibration. Formally, the calibration factor for each category prototype is defined as:

(4)

Here, the definition involves the subset of samples belonging to the associated category and the total number of samples in that subset. To allow end-to-end calibration in batch-based training, we keep a running mean of the global calibration factor via an exponential moving average:

(5)

Here, the update uses the samples in the current batch that carry the corresponding label, together with a smoothing coefficient.

The calibration factor reflects the difficulty of each category and the representativeness of the corresponding prototype during model learning. Finally, we weight all prototypes by the calibration factors as:

(6)
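
The running-mean update can be sketched as below. This is only an illustration: the exact form of the rectified sigmoid is not reproduced here, so a plain logistic sigmoid of the mean feature-prototype similarity stands in for the per-batch statistic. That stand-in, the function name `update_calibration`, and the reading of 0.01 and 0.95 (Sec. 4.2) as the initial value and the smoothing coefficient are all assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def update_calibration(w, feats, labels, prototypes, beta=0.95):
    """EMA update of per-class calibration factors from one batch."""
    w = w.copy()
    for c in np.unique(labels):
        sims = feats[labels == c] @ prototypes[c]   # similarity to own prototype
        batch_stat = sigmoid(sims.mean())           # stand-in for the paper's statistic
        w[c] = beta * w[c] + (1.0 - beta) * batch_stat
    return w                                        # absent classes keep their value

prototypes = np.eye(2)
w0 = np.full(2, 0.01)                               # assumed initial value
w1 = update_calibration(w0, np.array([[1.0, 0.0]]), np.array([0]), prototypes)
```

Classes that do not appear in a batch keep their previous factor, so tailed classes are not pulled toward zero simply because they are sampled rarely.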

3.3 Proto-loss for Training

To integrate contrastive learning into our framework in an end-to-end manner, we build on InfoNCE [22]. However, InfoNCE supports only a single positive pair, which is incompatible with our setting. Therefore, inspired by the unified contrastive loss [6], we extend it to include prototypes and to involve both positive and negative adversarial proto-instances in training. This keeps the optimization of supervised and contrastive training consistent, so as to achieve competitive performance compared with former studies.

Formally, given a negative set and a positive set for an instance, the proto-loss is formulated as:

(7)

Note that the proposed proto-loss can also be applied to general contrastive settings.
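
To make the multi-positive idea concrete, the sketch below implements a generic InfoNCE-style loss with several positives, in the spirit of the unified contrastive loss. It is not the exact proto-loss: the calibration weights are omitted, and the function name `multi_positive_nce` and the temperature 0.2 are assumptions.

```python
import numpy as np

def logsumexp(x):
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

def multi_positive_nce(z, positives, negatives, tau=0.2):
    """-log( sum over positives / sum over positives and negatives ) of exp(sim / tau)."""
    pos_logits = positives @ z / tau
    neg_logits = negatives @ z / tau
    all_logits = np.concatenate([pos_logits, neg_logits])
    return logsumexp(all_logits) - logsumexp(pos_logits)

z = np.array([1.0, 0.0])
loss = multi_positive_nce(z, np.array([[1.0, 0.0]]), np.array([[1.0, 0.0]]))
```

With a single positive the expression reduces to the usual InfoNCE form; in this sketch, prototypes and proto-instances would simply be additional rows of `positives` or `negatives`.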

4 Experiments

4.1 Datasets and Evaluation

We conduct experiments on two publicly available datasets. ISIC2018 is from the Skin Lesion Analysis Toward Melanoma Detection 2018 challenge [5]. The other dataset is APTOS2019, which is provided by [1]. The details of the datasets are listed in Table 1. The imbalance ratio is defined from the number of samples in each class as the ratio of the largest class size to the smallest. We follow the protocol in [21] and randomly split each original dataset into training and test sets with a ratio of 7:3. All experimental results are reported in terms of accuracy and F1-score.
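
A minimal sketch of the evaluation protocol follows. The helper names are hypothetical, the fixed seed is only for reproducibility of the example, and macro averaging of the F1-score is an assumption because the averaging scheme is not stated.

```python
import numpy as np

def split_indices(n, train_frac=0.7, seed=0):
    """Random train/test split of n sample indices."""
    perm = np.random.default_rng(seed).permutation(n)
    cut = int(round(train_frac * n))
    return perm[:cut], perm[cut:]

def accuracy(y_true, y_pred):
    return float(np.mean(y_true == y_pred))

def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 (assumed averaging scheme)."""
    scores = []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return float(np.mean(scores))

train_idx, test_idx = split_indices(10)
y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])
acc, f1 = accuracy(y_true, y_pred), macro_f1(y_true, y_pred, num_classes=2)
```

Macro averaging gives every class the same weight, which is why F1 tends to be more sensitive than accuracy to errors on tailed classes.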

Dataset      # of classes   # of samples   Imbalance ratio
ISIC2018     7              10015          58
APTOS2019    5              3662           10
Table 1: The details of long-tailed medical datasets.
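
The imbalance ratio in Table 1 can be computed from per-class sample counts, assuming it is the largest class size divided by the smallest. The counts below are hypothetical and are not the real class counts of either dataset.

```python
def imbalance_ratio(class_counts):
    """Largest class size divided by smallest class size."""
    return max(class_counts) / min(class_counts)

toy_counts = [580, 120, 40, 10]   # hypothetical per-class sample counts
ratio = imbalance_ratio(toy_counts)
```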

We selected nine existing methods for comparison. Specifically, we use cross-entropy (CE) as the baseline. On top of CE, a balanced resampling strategy is used to address the imbalanced classification problem (CE+resample). Contrastive learning (CL) is treated as another comparison method. By integrating the balanced resampling strategy, “CL+resample” is proposed in [21]; it is a two-stage approach that first trains the backbone and then freezes it and trains the classifier with balanced sampling. Focal loss [18], originally designed to down-weight easy negatives in one-stage detectors, is also useful for classification. LDAM [3] focuses on the label-distribution-aware margin and is regarded as a simple but effective training strategy. OHEM [26] is a model-based hard example mining method. DANIL [9] explores distractors to learn better CNN features.
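
For reference, the focal loss baseline can be written in a few lines for a single multi-class sample. The sketch uses the standard form, minus (1 - p_t) to the power gamma times log p_t, with gamma = 2 as a commonly used default; the function name and the omission of the class-balancing weight are simplifications for illustration.

```python
import numpy as np

def focal_loss(logits, target, gamma=2.0):
    """Focal loss for one sample: -(1 - p_t)^gamma * log(p_t)."""
    shifted = logits - logits.max()                    # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    log_pt = log_probs[target]
    return float(-((1.0 - np.exp(log_pt)) ** gamma) * log_pt)

ce_like = focal_loss(np.zeros(2), target=0, gamma=0.0)  # gamma = 0 recovers CE
focal = focal_loss(np.zeros(2), target=0)
```

Confidently classified samples have p_t close to 1, so their contribution is scaled down and training focuses on harder examples.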

Methods         ISIC2018                APTOS2019
                Accuracy   F1-score     Accuracy   F1-score
CE              0.850      0.716        0.812      0.608
CE+resample     0.861      0.735        0.802      0.583
Focal loss      0.849      0.728        0.815      0.629
LDAM            0.857      0.734        0.813      0.620
OHEM            0.818      0.660        0.813      0.631
MTL             0.811      0.667        0.813      0.632
DANIL           0.825      0.674        0.825      0.660
CL              0.865      0.739        0.825      0.652
CL+resample     0.868      0.751        0.816      0.608
ProCo (ours)    0.887      0.763        0.837      0.674
Table 2: Comparison with the state-of-the-art methods

4.2 Implementation Details

The data augmentation policy and the update ratio used for contrastive learning are identical to those of MoCoV2 [4]. ResNet50 [11] is used as our backbone. One projector is implemented with two fully-connected layers, with a hidden layer size of 2048 followed by a ReLU activation function. The other projector is a single fully-connected layer with a ReLU activation function. For simplicity, the category prototype can be regarded as a classifier in our framework. The batch size is 128 and the default optimizer is SGD with a momentum of 0.9 and a weight decay of 0.0001. The initial learning rate is set to 0.05. The initial similarity value is set to 0.01 and the smoothing coefficient is set to 0.95. The number of selected negative instances and the interpolation upper bound are set to 20 and 0.4, respectively, via a grid search. Following [21], the numbers of training epochs for ISIC2018 and APTOS2019 are set to 1,000 and 2,000, respectively. In the test phase, only the components required for classification are used to obtain the prediction.

4.3 Comparison with the State-of-the-art

In this part, we compare the proposed ProCo with state-of-the-art methods on two publicly released datasets: ISIC2018 and APTOS2019. Table 2 presents the complete experimental results. The proposed method achieves the best performance on both datasets, demonstrating its generalization ability. In addition, the performance of Focal loss and LDAM is comparable to the baseline on both datasets. OHEM, MTL, and DANIL obtain comparable accuracy and F1-score on APTOS2019 but inferior performance on ISIC2018, indicating weak generalization. The two CL-based methods perform consistently and match or outperform the baseline. Note that the effect of the balanced resampling strategy is inconsistent across datasets.

Proto-loss   Proto-instance   Prototype recalibration   Accuracy   F1-score
✓            -                -                         0.857      0.742
✓            ✓                -                         0.875      0.751
✓            -                ✓                         0.862      0.754
✓            ✓                ✓                         0.887      0.763
Table 3: Effectiveness of each module in our ProCo framework

4.4 Ablation Study

The proposed framework has three fundamental modules. To validate the effectiveness of each module, we carried out ablation studies, as shown in Table 3. Four experiments were designed by combining these three modules: 1) we discarded the proto-instance and the prototype recalibration strategy and trained the model with the Proto-loss only; 2) on top of the first setting, we introduced the proto-instance module; 3) we integrated the prototype recalibration strategy into the first setting; 4) we employed all three modules, i.e., the full proposed framework. Table 3 shows that the full framework obtains the best accuracy and F1-score (last row). The model trained with only the Proto-loss is inferior to the other three settings. The model thus benefits from both the proto-instance module and the prototype recalibration strategy, as the results of experiments 2 and 3 show. In addition, to evaluate the Proto-loss against other commonly used loss functions, namely cross-entropy (CE) and InfoNCE, we re-trained the proposed network with each loss function. The resulting F1-scores were 0.716 (CE), 0.739 (InfoNCE), and 0.742 (Proto-loss). Using the Proto-loss thus improves the F1-score by a relative 3.63% over CE.
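
The 3.63% figure is a relative improvement, which can be checked directly from the reported F1-scores. The helper name `relative_gain` is hypothetical.

```python
def relative_gain(new, base):
    """Relative improvement of `new` over `base`, in percent."""
    return (new - base) / base * 100.0

gain_over_ce = relative_gain(0.742, 0.716)   # Proto-loss vs. CE
```

In absolute terms the same comparison is a gain of 0.026 in F1-score.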

5 Conclusion

This paper proposes a novel paradigm called ProCo, which addresses the long-tailed classification problem in a contrastive way. ProCo consists of three main components: i) the category prototype and the adversarial proto-instance; ii) a prototype recalibration strategy; and iii) a unified proto-loss. Extensive experiments on two publicly available datasets demonstrate the efficacy of the proposed components and show that the proposed framework outperforms existing state-of-the-art long-tailed methods by a large margin.

References

  • [1] APTOS 2019 blindness detection (2019). https://www.kaggle.com/c/aptos2019-blindness-detection/data
  • [2] T. T. Cai, J. Frankle, D. J. Schwab, and A. S. Morcos (2020) Are all negatives created equal in contrastive instance discrimination?. arXiv preprint arXiv:2010.06682.
  • [3] K. Cao, C. Wei, A. Gaidon, N. Arechiga, and T. Ma (2019) Learning imbalanced datasets with label-distribution-aware margin loss. NeurIPS 32.
  • [4] X. Chen, H. Fan, R. Girshick, and K. He (2020) Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297.
  • [5] N. C. Codella, D. Gutman, M. E. Celebi, B. Helba, M. A. Marchetti, S. W. Dusza, A. Kalloo, K. Liopyris, N. Mishra, H. Kittler, et al. (2018) Skin lesion analysis toward melanoma detection: a challenge at the 2017 International Symposium on Biomedical Imaging (ISBI), hosted by the International Skin Imaging Collaboration (ISIC). In ISBI, pp. 168–172.
  • [6] Z. Dai, B. Cai, Y. Lin, and J. Chen (2021) UniMoCo: unsupervised, semi-supervised and full-supervised visual representation learning. arXiv preprint arXiv:2103.10773.
  • [7] Q. Dong, S. Gong, and X. Zhu (2017) Class rectification hard mining for imbalanced deep learning. In ICCV, pp. 1851–1860.
  • [8] A. Estabrooks, T. Jo, and N. Japkowicz (2004) A multiple resampling method for learning from imbalanced data sets. Computational Intelligence 20 (1), pp. 18–36.
  • [9] L. Gong, K. Ma, and Y. Zheng (2020) Distractor-aware neuron intrinsic learning for generic 2D medical image classifications. In MICCAI, pp. 591–601.
  • [10] K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick (2020) Momentum contrast for unsupervised visual representation learning. In CVPR, pp. 9729–9738.
  • [11] K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In CVPR, pp. 770–778.
  • [12] F. Isensee, J. Petersen, A. Klein, D. Zimmerer, P. F. Jaeger, S. Kohl, J. Wasserthal, G. Koehler, T. Norajitra, S. Wirkert, et al. (2018) nnU-Net: self-adapting framework for U-Net-based medical image segmentation. arXiv preprint arXiv:1809.10486.
  • [13] W. Ji, S. Yu, J. Wu, K. Ma, C. Bian, Q. Bi, J. Li, H. Liu, L. Cheng, and Y. Zheng (2021) Learning calibrated medical image segmentation via multi-rater agreement modeling. In CVPR, pp. 12341–12351.
  • [14] B. Kang, Y. Li, S. Xie, Z. Yuan, and J. Feng (2020) Exploring balanced feature spaces for representation learning. In ICLR.
  • [15] B. Kang, S. Xie, M. Rohrbach, Z. Yan, A. Gordo, J. Feng, and Y. Kalantidis (2020) Decoupling representation and classifier for long-tailed recognition. In ICLR.
  • [16] M. Koziarski (2020) Radial-based undersampling for imbalanced data classification. PR 102, pp. 107262.
  • [17] F. Last, G. Douzas, and F. Bacao (2017) Oversampling for imbalanced learning based on k-means and SMOTE. arXiv preprint arXiv:1711.00837.
  • [18] T. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár (2017) Focal loss for dense object detection. In ICCV, pp. 2980–2988.
  • [19] Z. Liu, Z. Miao, X. Zhan, J. Wang, B. Gong, and S. X. Yu (2019) Large-scale long-tailed recognition in an open world. In CVPR, pp. 2537–2546.
  • [20] J. Long, E. Shelhamer, and T. Darrell (2015) Fully convolutional networks for semantic segmentation. In CVPR, pp. 3431–3440.
  • [21] Y. Marrakchi, O. Makansi, and T. Brox (2021) Fighting class imbalance with contrastive learning. In MICCAI, pp. 466–476.
  • [22] A. v. d. Oord, Y. Li, and O. Vinyals (2018) Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748.
  • [23] S. Ren, K. He, R. B. Girshick, and J. Sun (2017) Faster R-CNN: towards real-time object detection with region proposal networks. IEEE TPAMI 39 (6), pp. 1137–1149.
  • [24] J. D. Robinson, C. Chuang, S. Sra, and S. Jegelka (2021) Contrastive learning with hard negative samples. In ICLR.
  • [25] O. Ronneberger, P. Fischer, and T. Brox (2015) U-Net: convolutional networks for biomedical image segmentation. In MICCAI, pp. 234–241.
  • [26] A. Shrivastava, A. Gupta, and R. Girshick (2016) Training region-based object detectors with online hard example mining. In CVPR, pp. 761–769.
  • [27] T. Wang, Y. Li, B. Kang, J. Li, J. Liew, S. Tang, S. Hoi, and J. Feng (2020) The devil is in classification: a simple framework for long-tail instance segmentation. In ECCV, pp. 728–744.
  • [28] X. Wang, L. Lian, Z. Miao, Z. Liu, and S. X. Yu (2020) Long-tailed recognition by routing diverse distribution-aware experts. arXiv preprint arXiv:2010.01809.
  • [29] C. Wei, K. Sohn, C. Mellina, A. Yuille, and F. Yang (2021) CReST: a class-rebalancing self-training framework for imbalanced semi-supervised learning. In CVPR, pp. 10857–10866.
  • [30] Z. Weng, M. G. Ogut, S. Limonchik, and S. Yeung (2021) Unsupervised discovery of the long-tail in instance segmentation using hierarchical self-supervision. In CVPR, pp. 2603–2612.
  • [31] Y. Yang and Z. Xu (2020) Rethinking the value of labels for improving class-imbalanced learning. NeurIPS 33, pp. 19290–19301.
  • [32] Y. Zang, C. Huang, and C. C. Loy (2021) FASA: feature augmentation and sampling adaptation for long-tailed instance segmentation. In ICCV, pp. 3457–3466.
  • [33] Y. Zhang, B. Kang, B. Hooi, S. Yan, and J. Feng (2021) Deep long-tailed learning: a survey. arXiv preprint arXiv:2110.04596.
  • [34] B. Zhou, Q. Cui, X. Wei, and Z. Chen (2020) BBN: bilateral-branch network with cumulative learning for long-tailed visual recognition. In CVPR, pp. 9719–9728.