CLAF: Contrastive Learning with Augmented Features
for Imbalanced Semi-Supervised Learning
Abstract
Because semi-supervised learning can leverage unlabeled data and contrastive learning can learn meaningful representations, the two have been increasingly combined to improve performance in applications with few labeled samples and abundant unlabeled data. A common approach assigns pseudo-labels to unlabeled samples and then selects positive and negative samples from the pseudo-labeled samples for contrastive learning. However, real-world data are often imbalanced, which biases pseudo-labels toward the majority classes and in turn undermines the effectiveness of contrastive learning. To address this challenge, we propose Contrastive Learning with Augmented Features (CLAF). We design a class-dependent feature augmentation module to alleviate the scarcity of minority class samples in contrastive learning. For each pseudo-labeled sample, we select positive and negative samples from labeled data instead of unlabeled data to compute the contrastive loss. Experiments on imbalanced image classification datasets demonstrate the effectiveness of CLAF for imbalanced semi-supervised learning.
Index Terms— imbalance, semi-supervised learning, contrastive learning, feature augmentation
1 Introduction
Semi-supervised learning (SSL) has attracted much attention in recent years owing to its potential to reduce the demand for labeled data by leveraging unlabeled data. The primary challenge in SSL lies in extracting useful information from a large amount of unlabeled data. Representation learning helps capture rich information from labeled data, which makes unlabeled data easier to exploit. Contrastive learning is an effective way to learn strong visual representations without supervision and has been extended to supervised learning [1], making it a promising candidate for integration into SSL. A general pipeline for incorporating contrastive learning into SSL produces pseudo-labels for unlabeled data and uses them for pseudo-label-based contrastive learning (PCL). For a pseudo-labeled sample, PCL selects unlabeled samples sharing the same pseudo-label as positive samples and treats unlabeled samples with different pseudo-labels as negative samples. The central idea of PCL is to pull an anchor and its positive samples together while pushing negative samples away. SSL algorithms that integrate PCL have achieved strong performance [2, 3, 4].
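To make the PCL selection rule concrete, here is a minimal Python sketch that only builds index lists for one unlabeled batch (no loss is computed). The function name `pcl_pairs` and the toy batch are hypothetical. The point it illustrates is that positives and negatives are decided by pseudo-labels alone: if sample 3 in the example were really a minority-class image mislabeled as class 0, it would silently become a wrong positive for samples 0 and 1.

```python
from typing import List, Tuple


def pcl_pairs(pseudo_labels: List[int]) -> List[Tuple[List[int], List[int]]]:
    """Select PCL positives/negatives inside one unlabeled batch.

    For each anchor i, positives are the other unlabeled samples sharing its
    pseudo-label; negatives are unlabeled samples with a different pseudo-label.
    True labels are never consulted.
    """
    pairs = []
    for i, y_i in enumerate(pseudo_labels):
        positives = [j for j, y_j in enumerate(pseudo_labels) if j != i and y_j == y_i]
        negatives = [j for j, y_j in enumerate(pseudo_labels) if y_j != y_i]
        pairs.append((positives, negatives))
    return pairs


# Hypothetical batch: samples 0, 1 and 3 were pseudo-labeled as class 0 (a majority class).
print(pcl_pairs([0, 0, 1, 0]))
# [([1, 3], [2]), ([0, 3], [2]), ([], [0, 1, 3]), ([0, 1], [2])]
```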
![Fig. 1](x1.png)
Although contrastive learning has proven effective for learning strong representations in SSL, these algorithms often assume class-balanced data, whereas many real-world datasets exhibit imbalanced distributions. Under imbalanced SSL, contrastive learning faces two risks: biased pseudo-labels and a scarcity of minority class samples. With class-imbalanced data, the class distribution of pseudo-labels on unlabeled data tends to be skewed toward the majority classes due to confirmation bias [5]. Consequently, many unlabeled samples that do not genuinely belong to majority classes are assigned majority-class pseudo-labels. Methods incorporating PCL then pull together instances that share the same majority-class pseudo-label, which may contradict the actual relationships among the unlabeled data. In addition, the scarcity of minority class samples leads to relatively poor representations of minority classes. These problems significantly constrain the representation learning capacity of contrastive learning in imbalanced SSL. In essence, the imbalanced data distribution produces inaccurate pseudo-labels, which in turn makes the selection of positive and negative samples unreliable.
In this paper, we propose a method called Contrastive Learning with Augmented Features (CLAF) devised to tackle the aforementioned challenges. First, we design a class-dependent feature augmentation module to alleviate the scarcity of labeled data in minority classes. Second, in contrast to conventional PCL that exclusively selects sample pairs from unlabeled data, CLAF selects both positive and negative samples from labeled data for each pseudo-labeled sample to reduce the influence of biased pseudo-labels as shown in Fig. 1.
2 Related Works
Semi-supervised learning (SSL): SSL learns from labeled data together with a large amount of unlabeled data. Pseudo-labeling is a widely used SSL technique that uses the model's predictions to label data and retrains the model with these artificial labels [5]. FixMatch [6] combines consistency regularization and pseudo-labeling, aligning the predictions on weakly and strongly augmented versions of the same unlabeled image.
Contrastive learning under SSL: Most earlier contrastive SSL methods are two-stage. SelfMatch [7] uses contrastive learning to pre-train a backbone and then fine-tunes it with augmentation consistency regularization. Most existing SSL methods built upon FixMatch use pseudo-labels for contrastive learning [3]. To exploit both the features learned by different loss functions and class-specific priors, SsCL [4] combines pseudo-labeling with a cross-entropy loss and instance discrimination with a contrastive loss, jointly optimizing the two losses with a shared backbone in an end-to-end manner. To address the confirmation bias caused by noise in pseudo-labels, CCSSL [2] introduces a class-aware contrastive module that focuses learning on unlabeled samples with pseudo-labels.
3 Preliminary
3.1 Problem Setup
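For a $K$-class semi-supervised image classification task, we are given labeled data $\mathcal{X}=\{(x_n, y_n)\}_{n=1}^{N}$ and unlabeled data $\mathcal{U}=\{u_m\}_{m=1}^{M}$ to train a model comprising a feature encoder $f_\theta$ followed by a linear classifier $g_\phi$, where $\theta$ and $\phi$ are the parameters of $f$ and $g$, respectively. For labeled data, the prediction $\mathbf{p}=\sigma(g_\phi(f_\theta(x)))$ of an image $x$, where $\sigma$ denotes the softmax function, is learned through a supervised loss $\mathcal{L}_{cls}$ (e.g., cross-entropy) with its label $y$. For unlabeled data, a pseudo-label $\hat{\mathbf{q}}$ is used in the unsupervised loss $\mathcal{L}_u$, where $\mathcal{L}_u$ can be implemented via entropy minimization [8] or consistency regularization [9], depending on the SSL method adopted.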
3.2 DASO
DASO [11] is a framework for imbalanced SSL that blends linear and semantic pseudo-labels in a distribution-aware manner. The linear pseudo-label $\hat{\mathbf{p}}$ and the semantic pseudo-label $\mathbf{q}^{s}$ are generated by passing the feature $f_\theta(u)$ of an unlabeled image $u$ through the linear classifier and a similarity-based classifier, respectively. The final pseudo-label $\hat{\mathbf{q}}$ is then obtained by fusing $\hat{\mathbf{p}}$ and $\mathbf{q}^{s}$ and serves as the target in $\mathcal{L}_u$.
The linear pseudo-label is obtained by applying the softmax function $\sigma$ to the output of the linear classifier: $\hat{\mathbf{p}}=\sigma(g_\phi(f_\theta(u)))$. The semantic pseudo-label is derived from a similarity-based classifier. Specifically, DASO constructs a set of class prototypes $\mathcal{C}=\{\mathbf{c}_k\}_{k=1}^{K}$ from labeled data using a queue $\mathcal{Q}=\{Q_k\}_{k=1}^{K}$, where $Q_k$ denotes the feature queue for class $k$ with a fixed size $|Q_k|$. Each class prototype $\mathbf{c}_k$ is simply the average of the feature points in $Q_k$. DASO then measures the per-class similarity between a feature point $\mathbf{z}=f_\theta(u)$ and the class prototypes:

$$q^{s}_{k}=\frac{\exp\big(\mathrm{sim}(\mathbf{z},\mathbf{c}_k)/T_{proto}\big)}{\sum_{k'=1}^{K}\exp\big(\mathrm{sim}(\mathbf{z},\mathbf{c}_{k'})/T_{proto}\big)} \qquad (2)$$

where $\mathrm{sim}(\cdot,\cdot)$ denotes cosine similarity and $T_{proto}$ is a temperature hyper-parameter. To prevent imbalanced prototypes arising from class-imbalanced labeled data, DASO fixes the size of $Q_k$ to the same value for all classes; older minority class features therefore remain in the queue longer, which compensates the minority class prototypes. To stabilize the movement of class prototypes in feature space during training, DASO employs a momentum encoder $f_{\theta'}$ with the same architecture as $f_\theta$, where $\theta'$ is the exponential moving average (EMA) of $\theta$ with momentum ratio $\rho$: $\theta' \leftarrow \rho\,\theta' + (1-\rho)\,\theta$.
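The prototype-based classifier and the momentum update can be sketched in a few lines of plain Python. This is an illustrative sketch, not DASO's implementation: features and parameters are plain lists instead of tensors, the helper names are hypothetical, `t_proto=0.05` is an arbitrary example value (not a claim about DASO's setting), and the distribution-aware blending into the final pseudo-label is not shown.

```python
import math
from typing import Dict, List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity sim(a, b)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def class_prototypes(queues: Dict[int, List[Vector]]) -> Dict[int, Vector]:
    """c_k = mean of the labeled features stored in the class-k queue Q_k."""
    protos = {}
    for k, feats in queues.items():
        dim = len(feats[0])
        protos[k] = [sum(f[d] for f in feats) / len(feats) for d in range(dim)]
    return protos


def semantic_pseudo_label(z: Vector, protos: Dict[int, Vector], t_proto: float) -> Dict[int, float]:
    """Softmax over cosine similarities to the prototypes, scaled by 1 / T_proto."""
    logits = {k: cosine(z, c) / t_proto for k, c in protos.items()}
    peak = max(logits.values())  # subtract the max logit for numerical stability
    exps = {k: math.exp(v - peak) for k, v in logits.items()}
    total = sum(exps.values())
    return {k: v / total for k, v in exps.items()}


def ema_update(theta_m: Vector, theta: Vector, rho: float) -> Vector:
    """Momentum encoder update: theta' <- rho * theta' + (1 - rho) * theta."""
    return [rho * tm + (1.0 - rho) * t for tm, t in zip(theta_m, theta)]


# Toy example: two classes whose queues each hold two labeled features.
queues = {0: [[1.0, 0.0], [1.0, 0.2]], 1: [[0.0, 1.0], [0.2, 1.0]]}
q_s = semantic_pseudo_label([0.9, 0.1], class_prototypes(queues), t_proto=0.05)
print(q_s)  # almost all probability mass on class 0
```

Because every class queue has the same length, each prototype averages the same number of features regardless of how many labeled images the class has; the EMA update is what keeps those queued features from drifting too quickly between steps.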
4 Method
4.1 Class-dependent Feature Augmentation
DASO maintains a balanced queue so that minority and majority classes contribute equally many features. However, many of the minority class features in the queue are generated from the same few labeled samples. To increase data diversity and alleviate the scarcity of labeled data in minority classes, we apply feature augmentation (FA) within each batch. Inspired by [12, 13, 14], FA blends labeled features with unlabeled features while keeping the label of the original labeled sample, which increases the number of labeled features for minority classes. The augmented feature is generated as:
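$$\tilde{\mathbf{z}}=\lambda\,\mathbf{z}^{l}+(1-\lambda)\,\mathbf{z}^{u} \qquad (3)$$

where $\mathbf{z}^{l}$ is the feature of a labeled image $x$ and $\mathbf{z}^{u}$ is the feature of an unlabeled image $u$ from the same batch. $\lambda$ is the mixture coefficient sampled from a Beta distribution, $\lambda\sim\mathrm{Beta}(\alpha,\alpha)$. To ensure that the label of the augmented feature remains valid, we constrain $\lambda$ to be at least 0.5: $\lambda=\max(\lambda,1-\lambda)$. FA is applied with a probability that depends on the number of labeled samples in each class, so the more labeled data a class has, the fewer augmented features are synthesized for it. Formally, given a labeled sample from class $k$, we apply FA with probability

$$P_k=1-\frac{N_k}{N_{\max}} \qquad (4)$$

where $N_k$ is the number of labeled samples of class $k$ and $N_{\max}$ is the number of labeled samples of the class with the most labeled data. This class-dependent probability yields more augmented features for minority classes.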
We update $Q_k$ for all classes concurrently by pushing the new labeled features and augmented features of each batch into the queues and removing the oldest ones when $Q_k$ is full.
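A minimal plain-Python sketch of class-dependent FA and the queue update follows. It rests on stated assumptions: the probability takes the form $P_k = 1 - N_k/N_{\max}$, the Beta distribution is symmetric with parameter `alpha`, and $\lambda$ is folded to $\max(\lambda, 1-\lambda)$. The helper names, `alpha=0.75`, the class counts, and the queue length are all hypothetical illustration values, and features are lists rather than tensors.

```python
import random
from collections import deque
from typing import Deque, Dict, List, Optional, Tuple

Vector = List[float]


def fa_probability(n_k: int, n_max: int) -> float:
    """Assumed class-dependent FA probability: P_k = 1 - N_k / N_max."""
    return 1.0 - n_k / n_max


def mix_features(z_l: Vector, z_u: Vector, alpha: float, rng: random.Random) -> Tuple[Vector, float]:
    """Blend a labeled feature with an unlabeled one, keeping the labeled feature dominant."""
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1.0 - lam)  # lambda >= 0.5, so the labeled sample's label stays valid
    z_aug = [lam * a + (1.0 - lam) * b for a, b in zip(z_l, z_u)]
    return z_aug, lam


def maybe_augment(z_l: Vector, y: int, z_u: Vector, class_counts: List[int],
                  alpha: float, rng: random.Random) -> Optional[Tuple[Vector, int, float]]:
    """Apply FA to labeled feature z_l of class y with the class-dependent probability."""
    if rng.random() < fa_probability(class_counts[y], max(class_counts)):
        z_aug, lam = mix_features(z_l, z_u, alpha, rng)
        return z_aug, y, lam  # the augmented feature inherits label y
    return None


def push(queues: Dict[int, Deque[Vector]], y: int, z: Vector) -> None:
    """Enqueue a raw or augmented labeled feature; deque(maxlen=...) drops the oldest one."""
    queues[y].append(z)


rng = random.Random(0)
counts = [500, 5]                                   # hypothetical: head class 0, tail class 1
queues = {k: deque(maxlen=4) for k in range(2)}     # hypothetical queue length
out = maybe_augment([0.0, 1.0], 1, [1.0, 0.0], counts, alpha=0.75, rng=rng)
if out is not None:
    z_aug, y, lam = out
    push(queues, y, z_aug)
```

With these assumed counts the head class is never augmented ($P_0 = 0$), while the tail class is augmented with probability 0.99. This is the intended effect: synthetic features go almost entirely to the classes whose queue slots would otherwise be filled by repeats of the same few labeled images.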
4.2 Contrastive Learning with Augmented Features
To reduce the impact of biased pseudo-labels while still utilizing unlabeled data, we apply contrastive learning between unlabeled and labeled data. For an unlabeled sample with a pseudo-label, we pull it toward labeled samples whose label matches the pseudo-label and push it away from labeled samples with other labels.
Following common practice in contrastive learning [15], we adopt an encoder-projection head structure. Both raw features $\mathbf{z}$ and augmented features $\tilde{\mathbf{z}}$ are passed through a projection head $h$ to obtain the corresponding embeddings $\mathbf{e}$. We construct an additional embedding queue $\mathcal{E}=\{E_k\}_{k=1}^{K}$ to store the embeddings of labeled (raw or augmented) features; it is updated simultaneously with the feature queue $\mathcal{Q}$. For unlabeled samples, we build a confidence vector $\mathbf{w}^{u}$ from the confidence scores of the model's predictions. Each element of $\mathbf{w}^{u}$ is defined as:

$$w^{u}_{i}=\max_{k}\hat{q}_{i,k} \qquad (5)$$

where $i$ is the index of the unlabeled sample and $\hat{\mathbf{q}}_i$ is its predicted class distribution (pseudo-label). Because the embedding queue also contains embeddings of augmented features, we construct a label confidence vector $\mathbf{w}^{l}$ based on the mixture coefficient:

$$w^{l}_{j}=\begin{cases}\lambda_j, & \text{if } \mathbf{e}_j \text{ comes from an augmented feature}\\ 1, & \text{otherwise}\end{cases} \qquad (6)$$

where $j$ is the index of the embedding in the embedding queue. To weight the positive pairs in the contrastive loss, we form a weight matrix $W$ as the outer product of $\mathbf{w}^{u}$ and $\mathbf{w}^{l}$, i.e., $W_{ij}=w^{u}_{i}\,w^{l}_{j}$, where $i$ indexes the unlabeled samples in a batch and $j$ indexes the embeddings in the embedding queue of the pseudo-label class. The contrastive loss is defined as:
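$$\mathcal{L}_{con}=\frac{1}{B}\sum_{i=1}^{B}\ell_i \qquad (7)$$

where $B$ is the batch size of unlabeled samples, and $\ell_i$ takes the following form:

$$\ell_i=-\frac{1}{|E_{\hat{y}_i}|}\sum_{\mathbf{e}_j\in E_{\hat{y}_i}} W_{ij}\,\log\frac{\exp(\mathbf{e}_i\cdot\mathbf{e}_j/\tau)}{\sum_{\mathbf{e}_a\in\mathcal{E}}\exp(\mathbf{e}_i\cdot\mathbf{e}_a/\tau)} \qquad (8)$$

where $\hat{y}_i$ is the pseudo-label of the $i$-th unlabeled sample, $E_{\hat{y}_i}$ denotes the embedding queue of the pseudo-label class, $|E_{\hat{y}_i}|$ is its capacity, and $\mathcal{E}$ is the embedding queue over all classes. $\mathbf{e}_i$ and $\mathbf{e}_j$ are embeddings from the unlabeled batch and from $E_{\hat{y}_i}$, respectively, $W_{ij}$ is the weight of the positive pair $(i,j)$, and $\tau$ is the temperature hyper-parameter. We compute the total loss as a weighted sum of the supervised loss $\mathcal{L}_{cls}$, the semi-supervised loss $\mathcal{L}_u$, the semantic alignment loss $\mathcal{L}_{align}$, and the contrastive loss $\mathcal{L}_{con}$. The final CLAF objective is:

$$\mathcal{L}=\mathcal{L}_{cls}+\lambda_u\mathcal{L}_u+\lambda_{align}\mathcal{L}_{align}+\lambda_{con}\mathcal{L}_{con} \qquad (9)$$

where $\mathcal{L}_{cls}$ and $\mathcal{L}_u$ with weight $\lambda_u$ come from the base SSL learner, $\mathcal{L}_{align}$ with weight $\lambda_{align}$ is taken from DASO, and $\lambda_{con}$ is the weight of the contrastive loss.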
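The weighted contrastive loss of Eqs. (7) and (8) can be sketched in plain Python as follows. This sketch assumes a SupCon-style form in which the positives are the queued labeled embeddings of the pseudo-label class, the denominator runs over every embedding in the labeled queue (all classes), and each positive term is scaled by $W_{ij}=w^u_i w^l_j$. The function and argument names are hypothetical. The code divides by the current number of queued positives, which equals the queue capacity once the queue is full.

```python
import math
from typing import List

Vector = List[float]


def normalize(v: Vector) -> Vector:
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def claf_contrastive_loss(
    unl_emb: List[Vector],     # e_i: embeddings of the B unlabeled samples
    pseudo_labels: List[int],  # pseudo-label of each unlabeled sample
    w_u: List[float],          # w^u_i: prediction confidence of each unlabeled sample
    queue_emb: List[Vector],   # all embeddings in the labeled embedding queue
    queue_labels: List[int],   # class of each queued embedding
    w_l: List[float],          # w^l_j: lambda_j for augmented entries, 1.0 otherwise
    tau: float,
) -> float:
    """Confidence-weighted contrast of unlabeled anchors against labeled embeddings."""
    queue_emb = [normalize(e) for e in queue_emb]
    total = 0.0
    for e_i, y_i, wu_i in zip(unl_emb, pseudo_labels, w_u):
        e_i = normalize(e_i)
        logits = [dot(e_i, e_a) / tau for e_a in queue_emb]
        peak = max(logits)
        # log of the denominator: sum over every queued embedding, all classes
        log_denom = peak + math.log(sum(math.exp(l - peak) for l in logits))
        positives = [j for j, y in enumerate(queue_labels) if y == y_i]
        if not positives:  # empty class queue: this anchor contributes nothing
            continue
        weighted = sum(wu_i * w_l[j] * (logits[j] - log_denom) for j in positives)
        total += -weighted / len(positives)
    return total / len(unl_emb)


loss = claf_contrastive_loss(
    unl_emb=[[1.0, 0.0]], pseudo_labels=[0], w_u=[1.0],
    queue_emb=[[1.0, 0.0], [0.0, 1.0]], queue_labels=[0, 1], w_l=[1.0, 1.0], tau=1.0,
)
print(loss)  # log(1 + e^-1), about 0.3133
```

Since only labeled embeddings act as positives and negatives, a wrong pseudo-label can misplace one anchor but cannot turn another unlabeled sample into a false positive. Low-confidence anchors and heavily mixed augmented entries also receive proportionally smaller weight.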
5 Experiments
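Table 1: Top-1 accuracy (%) on CIFAR10-LT and CIFAR100-LT (mean ± standard deviation over three runs).

| Algorithm | CIFAR10-LT $\gamma=100$, $N_1=500$, $M_1=4000$ | CIFAR10-LT $\gamma=100$, $N_1=1500$, $M_1=3000$ | CIFAR10-LT $\gamma=150$, $N_1=500$, $M_1=4000$ | CIFAR10-LT $\gamma=150$, $N_1=1500$, $M_1=3000$ | CIFAR100-LT $\gamma=10$, $N_1=50$, $M_1=400$ | CIFAR100-LT $\gamma=10$, $N_1=150$, $M_1=300$ | CIFAR100-LT $\gamma=20$, $N_1=50$, $M_1=400$ | CIFAR100-LT $\gamma=20$, $N_1=150$, $M_1=300$ |
|---|---|---|---|---|---|---|---|---|
| Supervised | 47.3±0.95 | 61.9±0.41 | 44.2±0.33 | 58.2±0.29 | 29.6±0.57 | 46.9±0.22 | 25.1±1.14 | 41.2±0.15 |
| w/ LA [16] | 53.3±0.44 | 70.6±0.21 | 49.5±0.40 | 67.1±0.78 | 30.2±0.44 | 48.7±0.89 | 26.5±1.31 | 44.1±0.42 |
| FixMatch [6] | 67.8±1.13 | 77.5±1.32 | 62.9±0.36 | 72.4±1.03 | 45.2±0.55 | 56.5±0.06 | 40.0±0.96 | 50.7±0.25 |
| w/ DARP [17] | 74.5±0.78 | 77.8±0.63 | 67.2±0.32 | 73.6±0.73 | 49.4±0.20 | 58.1±0.44 | 43.4±0.87 | 52.2±0.66 |
| w/ CReST+ [18] | 76.3±0.86 | 78.1±0.42 | 67.5±0.45 | 73.7±0.34 | 44.5±0.94 | 57.4±0.18 | 40.1±1.28 | 52.1±0.21 |
| w/ DASO [11] | 76.0±0.37 | 79.1±0.75 | 70.1±1.81 | 75.1±0.77 | 49.8±0.24 | 59.2±0.35 | 43.6±0.09 | 52.9±0.42 |
| w/ CLAF (Ours) | 76.4±0.46 | 80.6±0.65 | 72.0±0.74 | 75.9±0.29 | 50.9±0.11 | 59.8±0.29 | 44.5±0.83 | 54.1±0.28 |
| FixMatch+LA [16] | 75.3±2.45 | 82.0±0.36 | 67.0±2.49 | 78.0±0.91 | 47.3±0.42 | 58.6±0.36 | 41.4±0.93 | 53.4±0.32 |
| w/ DASO [11] | 77.9±0.88 | 82.5±0.08 | 70.1±1.68 | 79.0±2.23 | 50.7±0.51 | 60.6±0.71 | 44.1±0.61 | 55.1±0.72 |
| w/ CLAF (Ours) | 78.8±0.59 | 83.1±0.32 | 72.8±1.39 | 79.3±0.33 | 51.1±0.25 | 60.9±0.22 | 46.1±0.19 | 55.6±0.51 |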
5.1 Experimental Setup
5.1.1 Datasets
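Following common practice [11], we construct CIFAR10-LT and CIFAR100-LT for imbalanced SSL by exponentially decreasing the number of images from the head class to the tail class. We denote the head class size as $N_1$ for labeled data and $M_1$ for unlabeled data, and the imbalance ratio as $\gamma$. We set $N_k=N_1\cdot\gamma^{-\frac{k-1}{K-1}}$ for labeled data and $M_k=M_1\cdot\gamma^{-\frac{k-1}{K-1}}$ for unlabeled data. Following the common settings [11], we use $N_1=500, M_1=4000$ and $N_1=1500, M_1=3000$ for CIFAR10-LT, and $N_1=50, M_1=400$ and $N_1=150, M_1=300$ for CIFAR100-LT. We report results for imbalance ratios $\gamma=100$ and $\gamma=150$ on CIFAR10-LT and $\gamma=10$ and $\gamma=20$ on CIFAR100-LT.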
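The exponential per-class profile is easy to compute directly. The following sketch uses a hypothetical helper `long_tail_counts` that applies $n_k = n_1\cdot\gamma^{-(k-1)/(K-1)}$ for $k=1,\dots,K$. Rounding to the nearest integer is an assumption; actual codebases may truncate instead, which can shift intermediate classes by one image. The CIFAR10-LT setting $\gamma=100$, $N_1=500$, $M_1=4000$ serves as the example.

```python
from typing import List


def long_tail_counts(head: int, gamma: float, num_classes: int) -> List[int]:
    """Per-class sizes n_k = head * gamma^(-(k-1)/(K-1)), k = 1..K, rounded to int."""
    return [
        round(head * gamma ** (-(k - 1) / (num_classes - 1)))
        for k in range(1, num_classes + 1)
    ]


# CIFAR10-LT example: gamma = 100, N_1 = 500 labeled and M_1 = 4000 unlabeled head-class images.
labeled = long_tail_counts(500, 100, 10)
unlabeled = long_tail_counts(4000, 100, 10)
print(labeled[0], labeled[-1], unlabeled[0], unlabeled[-1])  # 500 5 4000 40
```

The head-to-tail ratio always equals $\gamma$ by construction, so the tail class in this example has only 5 labeled images. This is the regime where repeated queue entries for minority classes become a concern.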
5.1.2 Training and evaluation
We conduct all experiments with the same codebase as DASO [11] for a fair comparison. We adopt Wide ResNet-28-2 [19] as the backbone on both CIFAR10-LT and CIFAR100-LT. We apply FA in the last 20% of iterations and set to 0.8 to meet the requirement of FA for a structured representation space. $\lambda_{con}$ and $\tau$ are set to 1.0 and 0.07, respectively, for all experiments. All other hyper-parameters and training details follow DASO [11]. For evaluation, we use the EMA network whose parameters are updated at every training step [11]. We measure the top-1 accuracy on the test images every 500 iterations and report the median accuracy over the last 20 evaluations. We report the mean and standard deviation over three independent runs.
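The reporting protocol (median of the last 20 evaluations per run, then mean and standard deviation across runs) can be written with the standard `statistics` module. The helper names are hypothetical, and the use of the sample (n − 1) standard deviation is an assumption, because the text does not specify the estimator.

```python
import statistics
from typing import List, Tuple


def run_score(eval_accuracies: List[float], last_n: int = 20) -> float:
    """Median top-1 accuracy over the last `last_n` evaluations of one run
    (evaluations assumed to be logged every 500 iterations)."""
    return statistics.median(eval_accuracies[-last_n:])


def summarize(run_scores: List[float]) -> Tuple[float, float]:
    """Mean and sample standard deviation across independent runs (estimator assumed)."""
    return statistics.mean(run_scores), statistics.stdev(run_scores)


print(run_score(list(range(100))))  # median of 80..99 -> 89.5
print(summarize([1.0, 2.0, 3.0]))   # (2.0, 1.0)
```

Taking the median over the final evaluations, rather than reporting the best checkpoint, makes each run's score less sensitive to late-training fluctuations.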
5.2 Results on CIFAR10/CIFAR100-LT
Table 1 reports the results of CLAF on CIFAR10-LT and CIFAR100-LT under various settings. We compare CLAF with DARP [17], CReST+ [18], and DASO [11] on top of FixMatch, and CLAF achieves higher accuracy than these baselines on all benchmarks. The results on FixMatch re-balanced with LA [16] show that CLAF also benefits from debiased pseudo-labels. Notably, CLAF outperforms DASO in every setting, which verifies the effectiveness of CLAF for representation learning under imbalanced SSL.
5.3 Ablation Study
We perform ablation studies on CIFAR10-LT to investigate the impact of FA. Table 2 reports the results of CLAF with and without FA. As discussed above, FA mainly augments features for minority classes and supplies minority class features for contrastive learning. Without FA, contrastive learning alone improves on DASO only marginally (by 0.1 to 0.8 points), whereas adding FA yields further gains in every setting (up to 1.2 points, 72.0 vs. 70.8). This indicates that FA is beneficial for contrastive learning in imbalanced SSL.
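Table 2: Ablation of FA on CIFAR10-LT (top-1 accuracy, %, mean ± standard deviation).

| Algorithm | $\gamma=100$, $N_1=500$, $M_1=4000$ | $\gamma=100$, $N_1=1500$, $M_1=3000$ | $\gamma=150$, $N_1=500$, $M_1=4000$ | $\gamma=150$, $N_1=1500$, $M_1=3000$ |
|---|---|---|---|---|
| CLAF | 76.4±0.46 | 80.6±0.65 | 72.0±0.74 | 75.9±0.29 |
| w/o FA | 76.1±0.25 | 79.9±0.24 | 70.8±2.15 | 75.5±0.41 |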
5.4 Analysis
To assess the representation learning capacity of contrastive learning, we use t-SNE [20] to visualize the features of CIFAR10-LT test data obtained from DASO and CLAF. As shown in Fig. 2, tail class features form well-separated clusters in CLAF, whereas in DASO they lie close to majority class features. CLAF achieves the accuracy of for the three least common classes, which is better than in DASO. These results suggest that CLAF learns better representations for minority classes than DASO.
![Fig. 2](x2.png)
6 Conclusion
We propose Contrastive Learning with Augmented Features (CLAF) to apply contrastive learning to imbalanced SSL. We design a class-dependent feature augmentation module to alleviate the scarcity of minority class samples. Unlike conventional PCL, we select positive and negative samples from labeled data to reduce the impact of biased pseudo-labels. Experimental results show that CLAF outperforms the baselines on imbalanced image datasets under various settings, indicating that CLAF strengthens representation learning in imbalanced SSL.
References
- [1] Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan, “Supervised contrastive learning,” in Advances in Neural Information Processing Systems, 2020.
- [2] Fan Yang, Kai Wu, Shuyi Zhang, Guannan Jiang, Yong Liu, Feng Zheng, Wei Zhang, Chengjie Wang, and Long Zeng, “Class-aware contrastive semi-supervised learning,” in Conference on Computer Vision and Pattern Recognition, 2022, pp. 14401–14410.
- [3] Junnan Li, Caiming Xiong, and Steven C. H. Hoi, “Comatch: Semi-supervised learning with contrastive graph regularization,” in International Conference on Computer Vision, 2021, pp. 9455–9464.
- [4] Yuhang Zhang, Xiaopeng Zhang, Jie Li, Robert C. Qiu, Haohang Xu, and Qi Tian, “Semi-supervised contrastive learning with similarity co-calibration,” IEEE Trans. Multim., vol. 25, pp. 1749–1759, 2023.
- [5] Eric Arazo, Diego Ortego, Paul Albert, Noel E. O’Connor, and Kevin McGuinness, “Pseudo-labeling and confirmation bias in deep semi-supervised learning,” in International Joint Conference on Neural Networks, 2020, pp. 1–8.
- [6] Kihyuk Sohn, David Berthelot, Nicholas Carlini, Zizhao Zhang, Han Zhang, Colin Raffel, Ekin Dogus Cubuk, Alexey Kurakin, and Chun-Liang Li, “Fixmatch: Simplifying semi-supervised learning with consistency and confidence,” in Advances in Neural Information Processing Systems, 2020.
- [7] Byoungjip Kim, Jinho Choo, Yeong-Dae Kwon, Seongho Joe, Seungjai Min, and Youngjune Gwon, "Selfmatch: Combining contrastive self-supervision and consistency for semi-supervised learning," CoRR, vol. abs/2101.06480, 2021.
- [8] Yves Grandvalet and Yoshua Bengio, “Semi-supervised learning by entropy minimization,” in Advances in Neural Information Processing Systems, 2004, pp. 529–536.
- [9] Antti Tarvainen and Harri Valpola, “Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results,” in Advances in Neural Information Processing Systems, 2017, pp. 1195–1204.
- [10] Ekin Dogus Cubuk, Barret Zoph, Jonathon Shlens, and Quoc Le, “Randaugment: Practical automated data augmentation with a reduced search space,” in Advances in Neural Information Processing Systems, 2020.
- [11] Youngtaek Oh, Dong-Jin Kim, and In So Kweon, "DASO: distribution-aware semantics-oriented pseudo-label for imbalanced semi-supervised learning," in Conference on Computer Vision and Pattern Recognition, 2022, pp. 9776–9786.
- [12] Yue Fan, Dengxin Dai, Anna Kukleva, and Bernt Schiele, “Cossl: Co-learning of representation and classifier for imbalanced semi-supervised learning,” in Conference on Computer Vision and Pattern Recognition, 2022, pp. 14554–14564.
- [13] Hongyi Zhang, Moustapha Cissé, Yann N. Dauphin, and David Lopez-Paz, “mixup: Beyond empirical risk minimization,” in International Conference on Learning Representations, 2018.
- [14] Han-Jia Ye, De-Chuan Zhan, and Wei-Lun Chao, “Procrustean training for imbalanced deep learning,” in International Conference on Computer Vision, 2021, pp. 92–102.
- [15] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton, “A simple framework for contrastive learning of visual representations,” in Proceedings of the 37th International Conference on Machine Learning, 2020, vol. 119, pp. 1597–1607.
- [16] Aditya Krishna Menon, Sadeep Jayasumana, Ankit Singh Rawat, Himanshu Jain, Andreas Veit, and Sanjiv Kumar, “Long-tail learning via logit adjustment,” in International Conference on Learning Representations, 2021.
- [17] Jaehyung Kim, Youngbum Hur, Sejun Park, Eunho Yang, Sung Ju Hwang, and Jinwoo Shin, "Distribution aligning refinery of pseudo-label for imbalanced semi-supervised learning," in Advances in Neural Information Processing Systems, 2020.
- [18] Chen Wei, Kihyuk Sohn, Clayton Mellina, Alan L. Yuille, and Fan Yang, “Crest: A class-rebalancing self-training framework for imbalanced semi-supervised learning,” in Conference on Computer Vision and Pattern Recognition, 2021, pp. 10857–10866.
- [19] Sergey Zagoruyko and Nikos Komodakis, “Wide residual networks,” in Proceedings of the British Machine Vision Conference, 2016.
- [20] Laurens Van der Maaten and Geoffrey Hinton, "Visualizing data using t-SNE," Journal of Machine Learning Research, vol. 9, no. 11, 2008.