Email: {okerinde,lshamir,bhsu,theis,nnafi}@ksu.edu
eGAN: Unsupervised approach to class imbalance using transfer learning
Abstract
Class imbalance is an inherent problem in many machine learning classification tasks and often yields models that are of little practical use. In this study, we explore an unsupervised approach to class imbalance that leverages transfer learning from pre-trained image classification models. To this end, we propose an encoder-based Generative Adversarial Network (eGAN), which modifies the generator of a GAN by introducing an encoder module and adopts the GAN loss function to classify the majority and minority classes directly. To the best of our knowledge, this is the first work to tackle this problem with a GAN-based loss function rather than by augmenting the dataset with synthesized fake images. Our approach eliminates the epistemic uncertainty in the model predictions, as the scores for the majority and minority classes need not sum up to 1. We also explore the impact of transfer learning and of combining different pre-trained image classification models at the generator and discriminator level. The best result, an F1-score of 0.69, was obtained on the CIFAR-10 classification task with an enforced imbalance ratio of 1:2500. Our implementation code is available at:
https://github.com/demolakstate/eGAN_addressing_class_imbalance_with_transfer_learning_on_GAN.git.
Keywords: Class imbalance, Transfer Learning, GAN, Nash equilibrium
1 INTRODUCTION
A dataset is considered imbalanced when there is a significant, or in some cases extreme, disproportion between the numbers of samples of the different classes. The class or classes with a large number of samples are called the majority, while the class or classes with few examples are called the minority. In many cases, the machine learning model is required to correctly classify the minority class while minimizing the misclassification of the majority class. However, the skewness of the data often leads machine learning classification methods to favour the majority class.
The class imbalance problem in computer vision is normally approached either at the data level or at the algorithm level. Using data augmentation, a class with a small number of samples can be expanded into a class with a much larger number of samples. Early data augmentation simply transformed images by scaling, cropping, flipping, padding, rotation, and changes of brightness, contrast and saturation level [15]. Nowadays, synthetic images can also be generated using generative models such as VAEs [8] and GANs [4]. As a result, a very large image dataset can be created from the images of the minority class.
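The classical transforms mentioned above are simple array operations. The minimal sketch below, in plain Python, represents an image as a nested list of pixel rows (a hypothetical stand-in for a real image tensor) and shows horizontal flipping, cropping and zero-padding; a real pipeline would use an image library instead.

```python
from typing import List

Image = List[List[int]]  # hypothetical: a grayscale image as rows of pixel values


def hflip(img: Image) -> Image:
    """Mirror the image left-to-right."""
    return [row[::-1] for row in img]


def crop(img: Image, top: int, left: int, height: int, width: int) -> Image:
    """Cut out a height x width window whose top-left corner is (top, left)."""
    return [row[left:left + width] for row in img[top:top + height]]


def pad(img: Image, p: int, value: int = 0) -> Image:
    """Surround the image with a border of p pixels set to `value`."""
    width = len(img[0]) + 2 * p
    top = [[value] * width for _ in range(p)]
    body = [[value] * p + row + [value] * p for row in img]
    bottom = [[value] * width for _ in range(p)]
    return top + body + bottom


img = [[1, 2, 3],
       [4, 5, 6]]
print(hflip(img))             # [[3, 2, 1], [6, 5, 4]]
print(crop(img, 0, 1, 2, 2))  # [[2, 3], [5, 6]]
print(pad(img, 1))            # [[0, 0, 0, 0, 0], [0, 1, 2, 3, 0], [0, 4, 5, 6, 0], [0, 0, 0, 0, 0]]
```

Padding followed by a random crop back to the original size is a common way to obtain small random translations from these two primitives.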
At the algorithm level, the objective function is modified to penalize the network heavily for misclassifying the minority class [13][10]. The most popular such approach is cost-sensitive learning, in which the classifier incorporates a different misclassification penalty for each group of examples. Assigning a higher cost to the under-represented samples boosts their importance during training.
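A minimal sketch of the cost-sensitive idea, assuming a binary classifier that outputs P(minority) and per-class costs chosen by the user. The function name and the weighting heuristic in the usage note are illustrative and are not taken from [13] or [10].

```python
import math
from typing import Sequence


def weighted_bce(probs: Sequence[float], labels: Sequence[int],
                 w_minority: float, w_majority: float = 1.0) -> float:
    """Cost-sensitive binary cross-entropy (illustrative helper).

    labels: 1 = minority, 0 = majority; probs: predicted P(minority).
    Each sample's loss is multiplied by the cost of its true class.
    """
    eps = 1e-12  # keeps log() away from 0
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        if y == 1:
            total += -w_minority * math.log(p)
        else:
            total += -w_majority * math.log(1.0 - p)
    return total / len(labels)


# An uninformative prediction of 0.5 costs twice as much on the minority sample.
print(weighted_bce([0.5, 0.5], [1, 0], w_minority=2.0))
```

One common choice is to make the cost inversely proportional to class frequency, e.g. `w_minority = n_majority / n_minority`, which would be 500 for 10 minority and 5,000 majority samples.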
Transfer learning is known to improve the performance of machine learning models [14]. When only a varying number of layers of a pre-trained image classification model are fine-tuned, the rest of the pre-trained model serves as a feature extractor, while an added classifier head learns features specific to the current task.
In this work, we compare the performance of various pre-trained image classification models on unsupervised image classification with varying imbalance ratios. Our architecture, named eGAN, serves as the basis for this comparison. Building on the GAN framework [4], we recast the discriminator as a classifier that outputs a positive score for majority samples and a negative score for minority samples. We integrate an encoder module into the GAN that encodes the minority samples into a latent code from which the generator learns. While most GAN-based architectures focus on the output of the generator, we adapt the vanilla GAN and its loss function to classify the majority and minority classes directly, so we are primarily concerned with the performance of the discriminator.
2 RELATED WORK
There has been a great deal of work over the last few decades to address class imbalance. Earlier approaches include deliberate undersampling of the majority class or oversampling of the minority class by mere copying [3]. However, for image data, undersampling discards useful information, while oversampling causes overfitting [11]. Data augmentation via rotation, scaling, cropping, etc. can be considered a variant of oversampling that copies the same data with small modifications [15]. VAEs [8] and GANs [4] enable the generation of entirely new data. In recent years, GAN-based approaches have become more popular than the alternatives, and a number of variants of the vanilla GAN have been proposed to address class imbalance [12][16][1].
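The two classical resampling strategies above can be sketched in a few lines. In this illustration, strings are hypothetical stand-ins for image identifiers; the sketch makes the trade-off concrete: undersampling discards most of the majority data, while oversampling only repeats the same few minority images.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def random_undersample(majority: Sequence[T], n: int, seed: int = 0) -> List[T]:
    """Keep a random subset of n majority samples; the rest are discarded."""
    return random.Random(seed).sample(list(majority), n)


def random_oversample(minority: Sequence[T], n: int, seed: int = 0) -> List[T]:
    """Append copies of minority samples (drawn with replacement) until there are n."""
    rng = random.Random(seed)
    extra = [rng.choice(minority) for _ in range(n - len(minority))]
    return list(minority) + extra


majority = [f"car_{i}" for i in range(5000)]
minority = [f"plane_{i}" for i in range(10)]
print(len(random_undersample(majority, 10)))        # 10
oversampled = random_oversample(minority, 5000)
print(len(oversampled), len(set(oversampled)))      # 5000 10
```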
In [6], an ensemble method based on generative adversarial networks was proposed to generate new samples for the minority class and restore balance. In our view, the computational demand of such an approach is very high, and many low-income countries lack access to such computing power. Deep Cascading (DC) with a long sequence of decision trees can help to handle unbalanced data [2]. A DC is a sequence of n classifiers in which each sample x passes to the next classifier only if the current one classifies it as positive according to a high-sensitivity decision threshold. However, this approach suits foreground-background imbalance rather than the class imbalance of a classification task. Transfer learning with GANs was used to generate images from limited data in [14]. Their results showed that knowledge from pre-trained networks can speed up convergence and significantly improve the quality of generated images.
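The cascade rule described above can be sketched as follows. The scoring functions and thresholds are hypothetical toys on scalar inputs, standing in for the trained classifiers (decision trees in [2]); only the control flow, where a sample advances only while each stage accepts it, reflects the description.

```python
from typing import Callable, Sequence, Tuple

Stage = Tuple[Callable[[float], float], float]  # (scoring function, threshold)


def cascade_predict(x: float, stages: Sequence[Stage]) -> bool:
    """Return True (positive) only if every stage scores x at or above its threshold.

    A sample is passed to the next stage only while it keeps being classified
    as positive; the first stage that rejects it ends the cascade.
    """
    for score, threshold in stages:
        if score(x) < threshold:
            return False  # rejected: later stages never see this sample
    return True


# Hypothetical stages: permissive (high-sensitivity) first, stricter later.
stages = [
    (lambda v: v, 0.2),
    (lambda v: v ** 2, 0.25),
    (lambda v: v, 0.7),
]
print([cascade_predict(v, stages) for v in (0.1, 0.6, 0.8)])  # [False, False, True]
```

Because most negatives are rejected by the cheap early stages, later and more expensive stages see only a small fraction of the data.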
3 METHODOLOGY AND EXPERIMENTAL DESIGN
In this section, we discuss our proposed approach and the various testbeds that were used in our experiments.
3.1 ADDRESSING CLASS IMBALANCE WITH eGAN
The proposed architecture adapts the Deep Convolutional Generative Adversarial Network (DCGAN) [12] by incorporating an encoder module. This module encodes minority samples into the latent space from which the generator G produces samples capable of fooling the discriminator D. The discriminator D, in turn, is fed with samples drawn from the majority distribution and with the output of the generator G. D and G are optimized simultaneously through the two-player minimax game with value function V(G,D) given in Equation (1).
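$$\min_{G}\max_{D} V(G,D) = \mathbb{E}_{x \sim p_{maj}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_{min}(z)}\left[\log\left(1 - D(G(z))\right)\right] \qquad (1)$$
where $p_{maj}$ and $p_{min}$ are the majority and minority sample distributions, respectively.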
Over the course of training, the discriminator D learns to assign a negative score to samples from the minority distribution and a positive score to samples from the majority distribution. This enables D to act as a classifier.
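Since the discriminator's decision thresholds a real-valued score at zero, a natural reading of Equation (1) is that D(x) = sigmoid(s(x)) for an unbounded logit s(x). Under that assumption, both players' objectives reduce to binary cross-entropy on logits, with majority samples as target 1 and generated samples G(z) as target 0. The plain-Python sketch below computes them; the non-saturating generator loss is the usual practical substitute for minimizing log(1 - D(G(z))) and is our assumption, not something the paper states.

```python
import math
from typing import Sequence


def softplus(t: float) -> float:
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def discriminator_loss(maj_logits: Sequence[float],
                       fake_logits: Sequence[float]) -> float:
    """Negative of the D objective in Equation (1), assuming D(x) = sigmoid(logit).

    Uses -log(sigmoid(s)) = softplus(-s) for majority samples (target 1) and
    -log(1 - sigmoid(s)) = softplus(s) for generated samples G(z) (target 0).
    """
    real = sum(softplus(-s) for s in maj_logits) / len(maj_logits)
    fake = sum(softplus(s) for s in fake_logits) / len(fake_logits)
    return real + fake


def generator_loss(fake_logits: Sequence[float]) -> float:
    """Non-saturating generator loss -mean log D(G(z)) (assumed, not from the paper)."""
    return sum(softplus(-s) for s in fake_logits) / len(fake_logits)


# A discriminator that separates the classes: majority logits positive,
# logits of generated (minority-encoded) samples negative.
print(round(discriminator_loss([4.0, 5.0], [-4.0, -5.0]), 3))  # 0.025
print(round(generator_loss([-4.0, -5.0]), 3))                   # 4.512
```

Working on logits with `softplus` avoids computing `log(sigmoid(s))` directly, which underflows for large negative scores.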
3.1.1 Encoder-Generator module
Our latent space is a 128-unit vector. Rather than feeding the generator with random noise, as is typical of most GAN implementations, we add an encoder module that forces the generator to learn from a known distribution (the minority distribution). The encoder consists of a pre-trained DenseNet121 followed by a global average pooling layer and the 128-unit latent layer. The generator has two transposed convolutional layers. We use the LeakyReLU activation function with alpha set to 0.2, batch normalization, and a Sigmoid function at the final layer.
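The text gives the latent size (128) and the number of transposed convolutional layers (two), but not their kernel sizes or strides. Under a hypothetical layout of our own, where the latent vector is reshaped to 8 x 8 x 2 (= 128 values) and each layer uses kernel 4, stride 2 and padding 1, the generator reaches CIFAR's 32 x 32 resolution. The sketch checks that arithmetic with the output-size formula documented for PyTorch's `ConvTranspose2d`.

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int = 0,
                       output_padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a 2-D transposed convolution (PyTorch ConvTranspose2d formula)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)


# Hypothetical layout (not specified in the paper): 128-unit latent vector
# reshaped to 8 x 8 x 2, then two stride-2 transposed convolutions.
side = 8
for layer in range(2):
    side = conv_transpose_out(side, kernel=4, stride=2, padding=1)
    print(f"after transposed conv {layer + 1}: {side} x {side}")
# after transposed conv 1: 16 x 16
# after transposed conv 2: 32 x 32
```

Each stride-2 layer doubles the spatial size, so the starting resolution fixes how the 128 latent values must be reshaped; the final layer would also need 3 output channels for RGB images, followed by the Sigmoid.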
3.1.2 Discriminator module
The pre-trained discriminator has 7,038,529 parameters, of which only 39,937 are trainable. A global average pooling layer follows the pre-trained DenseNet121. We apply a dropout of 0.2, followed by a final single-unit dense layer. The overall architecture of our encoder-based generative adversarial network is shown in Figure 1.
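The parameter counts above are consistent with the following reconstruction, which is our assumption rather than a statement from the paper: the backbone is the Keras `DenseNet121(include_top=False)` model (1,024 output channels), and the five fine-tuned layers are its last five, `conv5_block16_1_relu`, `conv5_block16_2_conv`, `conv5_block16_concat`, `bn` and `relu`. Pooling and dropout add no weights.

```python
# Assumed: Keras DenseNet121(include_top=False) backbone, last five layers trainable,
# then GlobalAveragePooling -> Dropout(0.2) -> Dense(1).
backbone_params = 7_037_504   # DenseNet121 without its ImageNet classifier
head_params = 1024 * 1 + 1    # Dense(1) on the pooled 1024-d feature: weights + bias

# Trainable weights inside the five unfrozen backbone layers (assumed names above):
conv5_block16_2_conv = 3 * 3 * 128 * 32  # 3x3 conv, 128 -> 32 channels, no bias
final_bn = 2 * 1024                      # BatchNorm gamma and beta; moving statistics are not trainable
# conv5_block16_1_relu, conv5_block16_concat and relu have no weights.

total = backbone_params + head_params
trainable = conv5_block16_2_conv + final_bn + head_params
print(f"{total:,} total, {trainable:,} trainable")  # 7,038,529 total, 39,937 trainable
```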

3.2 Selection of pre-trained image classification weights
We experiment with the VGG16, VGG19, EfficientNetB2, ResNet101 and DenseNet121 classification models pre-trained on the ImageNet dataset, fine-tuning only the top five layers of each pre-trained model. Table 1 shows the maximum precision, recall and F1-score obtained on CIFAR-100 with imbalance ratio 1:50 for different combinations of pre-trained models.
Table 1: Maximum precision, recall and F1-score on CIFAR-100 (imbalance ratio 1:50) for different combinations of pre-trained models.
Pre-trained discriminator | Pre-trained generator | Precision | Recall | F1 |
---|---|---|---|---|
ResNet101 | VGG19 | 0.72 | 1.0 | 0.78 |
VGG19 | ResNet101 | 0.73 | 1.0 | 0.69 |
EfficientNetB2 | VGG19 | 1.0 | 0.22 | 0.32 |
VGG19 | EfficientNetB2 | 0.7 | 0.86 | 0.71 |
ResNet101 | VGG16 | 1.0 | 0.17 | 0.27 |
We use DenseNet121 [5] as the pre-trained backbone of our eGAN. After experimenting with different pre-trained architectures and different numbers of fine-tuned layers, we obtained the best result by fine-tuning only the top 5 of the 427 layers of DenseNet121.
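Fine-tuning only the top layers amounts to switching off training for every layer except the last few. The framework-free sketch below uses a hypothetical `Layer` class with a `trainable` flag; in Keras, the equivalent is setting the `trainable` attribute of each layer in `base_model.layers[:-5]` to `False`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Layer:
    """Hypothetical stand-in for a framework layer with a trainable flag."""
    name: str
    trainable: bool = True


def freeze_all_but_top(layers: List[Layer], k: int) -> None:
    """Freeze every layer except the last k (the ones closest to the output)."""
    for i, layer in enumerate(layers):
        layer.trainable = i >= len(layers) - k


backbone = [Layer(f"layer_{i}") for i in range(427)]  # 427 layers, as for DenseNet121 here
freeze_all_but_top(backbone, 5)
print(sum(l.trainable for l in backbone), sum(not l.trainable for l in backbone))  # 5 422
```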
3.3 Dataset
We use several common datasets in this study. To model the real-world scenario of heavy imbalance, we used only a few samples of the minority class as input to the encoder module. A detailed overview is shown in Table 2.
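Table 2: Number of training and test samples per class for each imbalance ratio on CIFAR-10, with CIFAR-100 in parentheses ("-" marks a setting not used; ∗1:1 denotes training on a single minority and a single majority sample).
Class imbalance ratio | # training minor | # training major | # testing minor | # testing major |
---|---|---|---|---|
1:2500 | 2(-) | 5000(-) | 1000(-) | 1000(-) |
1:1000 | 5(-) | 5000(-) | 1000(-) | 1000(-) |
1:500 | 10(1) | 5000(500) | 1000(100) | 1000(100) |
1:50 | 100(10) | 5000(500) | 1000(100) | 1000(100) |
1:1 | 500(500) | 500(500) | 1000(100) | 1000(100) |
∗1:1 | -(1) | -(1) | -(100) | -(100) |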
The CIFAR-10 dataset [9] consists of 60,000 32×32 colour images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images. Here we use airplane as the minority class and automobile as the majority class.
CIFAR-100 [9] is similar to CIFAR-10, except that it has 100 classes containing 600 images each. Each class has 500 training images and 100 testing images. The 100 classes in CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the class to which it belongs) and a "coarse" label (the superclass). We also use the pneumonia subset of the Stanford CheXpert dataset [7] to experiment on an inherently imbalanced dataset. This subset contains 4,576 minority and 167,407 majority samples.
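The deliberately imbalanced CIFAR-10 training subsets in Table 2 can be drawn by sampling indices per class. In the sketch below, the helper name is hypothetical and the label list is a synthetic stand-in for the real CIFAR-10 training labels, in which airplane has id 0 and automobile has id 1.

```python
import random
from typing import Dict, List, Sequence


def imbalanced_subset(labels: Sequence[int], minority: int, majority: int,
                      n_minority: int, n_majority: int, seed: int = 0) -> Dict[str, List[int]]:
    """Pick training indices for a binary, deliberately imbalanced subset (hypothetical helper)."""
    rng = random.Random(seed)
    min_idx = [i for i, y in enumerate(labels) if y == minority]
    maj_idx = [i for i, y in enumerate(labels) if y == majority]
    return {
        "minority": rng.sample(min_idx, n_minority),
        "majority": rng.sample(maj_idx, n_majority),
    }


# Synthetic stand-in for the CIFAR-10 training labels: 5,000 airplanes (0), 5,000 automobiles (1).
labels = [0] * 5000 + [1] * 5000
subset = imbalanced_subset(labels, minority=0, majority=1, n_minority=10, n_majority=5000)
print(len(subset["minority"]), len(subset["majority"]))  # 10 5000  (ratio 1:500)
```

Setting `n_minority` to 5 or 2 gives the 1:1000 and 1:2500 settings of Table 2.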
4 RESULTS AND DISCUSSION
A sample whose discriminator score is less than zero is classified as minority; otherwise it is classified as majority. Table 1 shows the results obtained by combining various pre-trained models. All models were trained for 100 epochs with the Adam optimizer and a learning rate of 1e-4.
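The decision rule above is a sign test on the discriminator score. The sketch applies it and checks that, if the score is a pre-sigmoid logit (an assumption), the zero threshold is the same as the usual 0.5 probability cut-off.

```python
import math
from typing import List, Sequence


def classify(scores: Sequence[float]) -> List[str]:
    """Label each discriminator score: below zero -> minority, otherwise majority."""
    return ["minority" if s < 0 else "majority" for s in scores]


def sigmoid(s: float) -> float:
    """Logistic function, written to avoid overflow for large |s|."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


scores = [-2.3, -0.1, 0.0, 1.7]
print(classify(scores))  # ['minority', 'minority', 'majority', 'majority']
# Assuming scores are logits, thresholding at 0 matches thresholding sigmoid(score) at 0.5:
print(all((s < 0) == (sigmoid(s) < 0.5) for s in scores))  # True
```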
As can be seen in Figure 2, the model reached a Nash equilibrium on the test data at around 10 epochs. Here, we perform inference on the test data at every epoch and plot the number of correctly classified samples. Five layers of DenseNet121 were fine-tuned in both the generator and the discriminator modules, while the weights of the remaining 422 layers were kept fixed. At the Nash equilibrium, the discriminator correctly classified roughly 700 of the 1,000 test samples of each of the minority and majority classes. This result suggests that transfer learning with a GAN can help overcome the challenge of a highly imbalanced dataset, given that we train with only 10 minority samples and 5,000 majority samples. Similar performance is observed on CIFAR-100 with imbalance ratio 1:50.




Without a pre-trained discriminator, the effect of the high imbalance in the training set becomes apparent: the discriminator is skewed towards the majority class of the training set and misses all the minority samples in the test data. This can be seen in Figure 3 on CIFAR-100, and the same pattern is observed on CIFAR-10. We also experimented with no pre-training at all, in neither the discriminator nor the generator, and observed exactly the same pattern. We therefore conclude that transfer learning helps unsupervised image classification in a highly imbalanced domain.

Training can be stopped as soon as the Nash equilibrium is reached, as at this point the model performs comparably on the minority and majority classes. An acceptable threshold can also be set on the absolute difference between the numbers of correctly classified samples of the two classes; for instance, training can be stopped once |correctly_classified_minority - correctly_classified_majority| ≤ 20. The precision, recall and F1-score curves on CIFAR-10, averaged over five folds at different imbalance ratios, are shown in Figure 2.
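The stopping criterion above can be checked after every epoch. A minimal sketch, assuming the per-epoch counts of correctly classified test samples are recorded as `(minority, majority)` pairs; the history below is invented for illustration and is not taken from our experiments.

```python
from typing import Optional, Sequence, Tuple


def first_equilibrium_epoch(history: Sequence[Tuple[int, int]],
                            threshold: int = 20) -> Optional[int]:
    """Return the first epoch (1-based) at which the per-class counts of correctly
    classified test samples differ by at most `threshold`, or None if never.

    history[i] = (correct_minority, correct_majority) after epoch i + 1.
    """
    for epoch, (correct_min, correct_maj) in enumerate(history, start=1):
        if abs(correct_min - correct_maj) <= threshold:
            return epoch
    return None


# Hypothetical per-epoch counts out of 1,000 test samples per class.
history = [(0, 1000), (150, 950), (800, 300), (640, 700), (690, 705)]
print(first_equilibrium_epoch(history))                # 5
print(first_equilibrium_epoch(history, threshold=60))  # 4
```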
We observed that in the early training epochs (typically epochs 1 to 40), the generator tries to fool the discriminator by generating samples resembling the majority class fed into the discriminator. As a result, more minority samples are misclassified, because the discriminator "knows" the majority distribution too well. A drastic change occurs when the generator starts generating samples from the latent vector that can fool the discriminator, as seen in the generator and discriminator losses shown in Figure 4.


4.1 Imbalance ratios
To reduce bias in the performance estimates, we conduct 5-fold cross-validation on the minority samples and average the results.
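The text does not detail how the folds are formed. The sketch below assumes the minority indices are shuffled and then split into five near-equal folds, with one model trained per fold and the per-fold metrics averaged. Both helper names are hypothetical.

```python
from statistics import mean
from typing import Dict, List, Sequence


def k_fold_indices(n: int, k: int = 5) -> List[List[int]]:
    """Split indices 0..n-1 into k contiguous folds whose sizes differ by at most one."""
    base, extra = divmod(n, k)
    folds, start = [], 0
    for f in range(k):
        size = base + (1 if f < extra else 0)
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def average_metrics(per_fold: Sequence[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric (e.g. precision, recall, F1) over the folds."""
    return {name: mean(m[name] for m in per_fold) for name in per_fold[0]}


print(k_fold_indices(10, 5))  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
```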
4.1.1 Class ratio of 1:2500
We experiment on the CIFAR-10 dataset using a deliberately imbalanced subset of the training set. At the 4th epoch, our model correctly identifies 257 minority and 821 majority samples out of 1,000 each. At epoch 5, a sharp change leads to 821 minority samples being correctly classified, but only 233 majority samples, as shown in Table 3. We also observed that at epoch 72 the performance of the network on the majority and minority classes reached a Nash equilibrium, with a difference of at most 20.
4.1.2 Class ratio of 1:1000
On CIFAR-10, a Nash equilibrium was reached at epoch 81, where 532 minority and 525 majority test samples were correctly classified. We observed another major shift in the classifier between epochs 5 and 6. At epoch 5, the best result was obtained: the network correctly classified 863 majority and 585 minority test samples out of 1,000 each. At epoch 6, 634 majority and 810 minority test samples were classified correctly, as shown in Table 4. Maximum precision, F1-score and recall of 0.88, 0.74 and 0.99 were obtained at epochs 3, 6 and 37, respectively.
Table 3: Confusion matrix on the CIFAR-10 test data at epoch 5 (imbalance ratio 1:2500).
Actual class | Predicted minority | Predicted majority |
---|---|---|
Minority | 821 | 179 |
Majority | 767 | 233 |
Table 4: Confusion matrix on the CIFAR-10 test data at epoch 6 (imbalance ratio 1:1000).
Actual class | Predicted minority | Predicted majority |
---|---|---|
Minority | 810 | 190 |
Majority | 366 | 634 |
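Treating the minority class as positive (an assumption, but one that reproduces the maximum F1-score of 0.74 reported at epoch 6 for ratio 1:1000), the metrics follow directly from the confusion-matrix counts: minority predicted as minority is a true positive, majority predicted as minority a false positive, and minority predicted as majority a false negative.

```python
from typing import Tuple


def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Precision, recall and F1 with the minority class treated as positive."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Counts from the 1:1000 confusion matrix at epoch 6 (Table 4).
p, r, f1 = precision_recall_f1(tp=810, fp=366, fn=190)
print(f"precision={p:.2f} recall={r:.2f} F1={f1:.2f}")  # precision=0.69 recall=0.81 F1=0.74
```

Applied to the 1:2500 matrix at epoch 5 (Table 3), the same function gives a precision of 0.52, a recall of 0.82 and an F1-score of 0.63: the many majority samples predicted as minority there pull precision down.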
4.1.3 Class ratio of 1:500
We use both CIFAR-10 and CIFAR-100 to experiment with imbalance ratio 1:500. On CIFAR-10, the maximum precision, recall and F1-score averaged over 5-fold cross-validation are 0.75, 0.95 and 0.70, respectively, as shown in Table 5. The maximum precision is lower on CIFAR-100, at 0.60, whereas the recall and F1-score (0.96 and 0.68) are roughly the same.
4.1.4 Class ratio of 1:50
We demonstrate our model's performance on imbalance ratio 1:50 using CIFAR-100 and CIFAR-10. For CIFAR-100, a change occurred between epochs 69 and 70, with the correctly classified counts moving from 57 majority and 59 minority samples to 55 majority and 59 minority samples. A Nash equilibrium is attained at epoch 68, with 56 correctly classified samples in each of the minority and majority classes. At epoch 97, the maximum F1-score and recall of 0.66 and 0.77, respectively, were obtained, while the maximum precision of 0.6 was obtained at epoch 74.
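Table 5: Maximum precision, recall and F1-score, averaged over 5-fold cross-validation, on CIFAR-10 with CIFAR-100 in parentheses ("-" marks a setting not used; ∗1:1 denotes training on a single minority and a single majority sample).
Imbalance ratio | Precision | Recall | F1 |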
---|---|---|---|
1:2500 | 0.72(-) | 0.94(-) | 0.69(-) |
1:1000 | 0.72(-) | 0.96(-) | 0.69(-) |
1:500 | 0.75(0.60) | 0.95(0.96) | 0.70(0.68) |
1:50 | 0.78(0.7) | 0.97(0.86) | 0.69(0.71) |
1:1 | 0.82(0.72) | 0.98(1.0) | 0.72(0.78) |
∗1:1 | -(0.53) | -(0.86) | -(0.63) |
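Table 6: Comparison of eGAN with the baseline classification model on the pneumonia subset of CheXpert.
Model | # training minor | # training major | # testing minor | # testing major | Imbalance ratio | Precision | Recall | F1 |
---|---|---|---|---|---|---|---|---|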
eGAN | 30 | 1080 | 1000 | 1000 | 1:36 | 0.51 | 0.97 | 0.67 |
baseline | 30 | 1080 | 1000 | 1000 | 1:36 | 0.5 | 1.0 | 0.67 |
4.1.5 Class ratio of 1:1
We use CIFAR-100 to demonstrate the performance of eGAN on a balanced dataset. We notice that the experimental performance follows the same pattern as on the imbalanced datasets. Training starts with almost all majority samples correctly classified and all minority samples misclassified. At epoch 48, a Nash equilibrium (with a threshold of at most 5) is achieved, with 76 minority and 71 majority samples correctly classified. The maximum F1-score of 0.78 is reached at epoch 53, as shown in Table 5. When training on a single minority and a single majority instance of CIFAR-100 (∗1:1) instead of 500 samples of each class, we obtained an F1-score of 0.63. This illustrates the impact of transfer learning on training.
4.1.6 Class ratio of 1:36
For the pneumonia subset of the CheXpert dataset with imbalance ratio 1:36, the best-performing model achieves 0.51 precision, 0.97 recall and 0.67 F1-score. The results shown in Table 6 are evaluated on 1,000 minority and 1,000 majority test samples. As the table shows, our approach did not beat the baseline classification model; we attribute this to the task being closer to anomaly detection than to classification. In addition, the source dataset of the pre-trained image classification models (ImageNet) differs from the medical domain. Exploring more complex GAN variants such as BigGAN, StyleGAN and ProGAN could help.
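As a point of reference for Table 6: on a balanced test set of 1,000 minority and 1,000 majority samples, a classifier that labels every image as minority scores a precision of 0.50, a recall of 1.00 and an F1-score of 0.67, the same values as the baseline row. The arithmetic is shown below.

```python
from typing import Tuple


def metrics_all_minority(n_minority: int, n_majority: int) -> Tuple[float, float, float]:
    """Precision, recall and F1 of a classifier that labels every sample as minority."""
    tp, fp, fn = n_minority, n_majority, 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


p, r, f1 = metrics_all_minority(1000, 1000)
print(f"{p:.2f} {r:.2f} {f1:.2f}")  # 0.50 1.00 0.67
```

On a balanced test set, precision near 0.5 together with recall near 1.0 therefore indicates a classifier that separates the two classes only weakly.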
5 CONCLUSION
In this work, we demonstrate the capability of a GAN-based unsupervised technique to address class imbalance using pre-trained models. We conduct experiments with varying imbalance ratios in the training dataset. Instead of synthesizing artificial images with the generator for data augmentation, we employ the discriminator as a classifier and formulate the loss function accordingly. Experimental results reveal that transfer learning plays a significant role in the model's performance. The performance measure of interest determines from which epoch the trained model should be deployed in production, as the model at different epochs favours different evaluation metrics (for example, sensitivity). Future work will apply this approach to anomaly detection tasks, where the features distinguishing normal (majority) from abnormal (minority) samples are less pronounced. Our work can also be extended to object detection tasks with imbalance between foreground and background.
References
- [1] Antreas Antoniou, Amos Storkey and Harrison Edwards “Data augmentation generative adversarial networks” In arXiv preprint arXiv:1711.04340, 2017
- [2] Alessandro Bria, Claudio Marrocco and Francesco Tortorella “Addressing class imbalance in deep learning for small lesion detection on medical images” In Computers in Biology and Medicine 120 Elsevier, 2020, pp. 103735
- [3] Chris Drummond and Robert C Holte “C4.5, class imbalance, and cost sensitivity: why under-sampling beats over-sampling” In Workshop on learning from imbalanced datasets II 11, 2003, pp. 1–8 Citeseer
- [4] Ian J Goodfellow et al. “Generative adversarial networks” In arXiv preprint arXiv:1406.2661, 2014
- [5] Gao Huang, Zhuang Liu, Laurens Van Der Maaten and Kilian Q Weinberger “Densely connected convolutional networks” In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2017, pp. 4700–4708
- [6] Yangru Huang, Yi Jin, Yidong Li and Zhiping Lin “Towards Imbalanced Image Classification: A Generative Adversarial Network Ensemble Learning Method” In IEEE Access 8 IEEE, 2020, pp. 88399–88409
- [7] Jeremy Irvin et al. “Chexpert: A large chest radiograph dataset with uncertainty labels and expert comparison” In Proceedings of the AAAI Conference on Artificial Intelligence 33.01, 2019, pp. 590–597
- [8] Diederik P Kingma and Max Welling “Auto-encoding variational bayes” In arXiv preprint arXiv:1312.6114, 2013
- [9] Alex Krizhevsky and Geoffrey Hinton “Learning multiple layers of features from tiny images” Citeseer, 2009
- [10] Charles X Ling and Victor S Sheng “Cost-sensitive learning and the class imbalance problem” In Encyclopedia of machine learning 2011 Citeseer, 2008, pp. 231–235
- [11] Nasik Muhammad Nafi and William H Hsu “Addressing class imbalance in image-based plant disease detection: Deep generative vs. sampling-based approaches” In 2020 International Conference on Systems, Signals and Image Processing (IWSSIP), 2020, pp. 243–248 IEEE
- [12] Alec Radford, Luke Metz and Soumith Chintala “Unsupervised representation learning with deep convolutional generative adversarial networks” In arXiv preprint arXiv:1511.06434, 2015
- [13] Yanmin Sun, Mohamed S Kamel, Andrew KC Wong and Yang Wang “Cost-sensitive boosting for classification of imbalanced data” In Pattern Recognition 40.12 Elsevier, 2007, pp. 3358–3378
- [14] Yaxing Wang et al. “Transferring gans: generating images from limited data” In Proceedings of the European Conference on Computer Vision, 2018, pp. 218–234
- [15] Chaoyun Zhang, Pan Zhou, Chenghua Li and Lijun Liu “A convolutional neural network for leaves recognition using data augmentation” In 2015 IEEE International Conference on Computer and Information Technology; Ubiquitous Computing and Communications; Dependable, Autonomic and Secure Computing; Pervasive Intelligence and Computing, 2015, pp. 2143–2150 IEEE
- [16] Jun-Yan Zhu, Taesung Park, Phillip Isola and Alexei A Efros “Unpaired image-to-image translation using cycle-consistent adversarial networks” In Proceedings of the IEEE International Conference on Computer Vision, 2017, pp. 2223–2232