Yitong Chen+1, Yanjun Lin+1, Zheng Qin+1, Siyin Wang+1, Xingsong Ye+1
1School of Computer Science, Fudan University.
+Equal contribution
Artificial Intelligence Course (COMP130031.02) project at Fudan University.
To address limited training data and overfitting in few-shot image classification, we propose two approaches: leveraging a large amount of unlabeled data and employing data augmentation. For the former, we introduce an "AI learn from AI" method that enables the model to learn from unlabeled data more effectively. Experimental results show that incorporating unlabeled data improves model performance, and that knowledge distillation reduces the number of parameters with only a small loss in performance. In addition, data augmentation significantly improves performance: we transform the training images in five different ways and add the transformed images to the training set.
Transfer Learning, Knowledge Distillation, Data Augmentation, Image Classification, Few-shot Learning
In image classification, few-shot learning is a challenging task with two main difficulties. First, limited training data often leads to poor performance. For example, ViT-B/16 accuracy on CIFAR-100 drops from 90.20 when fine-tuned on the entire dataset to 46.31 in the few-shot setting. Second, overfitting is common: the model performs well on the validation set but poorly on the test set. For instance, on Country211 the validation accuracy reaches 34.83, while the test accuracy is only 3.24.
To deal with the problem of limited data, we focus on leveraging a large amount of unlabeled data and enabling the model to learn knowledge from it through pseudo-labeling, knowledge distillation, etc. Additionally, we tackle overfitting by employing data augmentation techniques that allow the model to learn from a wider range of sample variations and improve its robustness.
To exploit unlabeled data, we propose an "AI learn from AI" approach in which one AI model learns from the knowledge of other AI models. Specifically, we use CLIP to label the unlabeled data and add these pseudo-labeled samples to the training set. We also train smaller models to match the soft-label probability distributions of larger models, which transfers additional knowledge.
Through extensive experiments on five benchmark datasets under three experimental settings, we observe that incorporating unlabeled data improves model performance. Knowledge distillation reduces the number of parameters with only minor performance degradation. Finally, data augmentation significantly improves performance: adding images transformed in five different ways to the training set yields substantial gains.
For a large amount of unlabeled data, the conventional approach is to label it with the model itself after training it on the few labeled samples. However, a model trained on so few samples is weak, which prevents the unlabeled data from being fully exploited. To address this, we propose the "AI learn from AI" idea, in which the model learns from the predictions of other, stronger AI models. Using the state-of-the-art (SOTA) model for each dataset, however, raises a problem: its category labels are not aligned with those of our data.
We therefore use CLIP to label the unlabeled data: it matches images directly against textual class labels, so no label alignment is needed. Moreover, CLIP (ViT-L/14) performs remarkably well on these five datasets, reaching an average accuracy of 70 and even surpassing open-source SOTA models on some of them. We further improve CLIP's accuracy by wrapping the textual labels in template-based prompts (e.g., 'a photo from ..., it's a country').
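The way CLIP turns prompts into labels can be sketched as follows. This is a minimal, dependency-free sketch, assuming the image and text embeddings have already been produced by CLIP's image and text encoders (not shown). The helper names `build_prompts` and `zero_shot_probs`, the pet-style template string, and the logit scale of 100 are illustrative assumptions rather than the project's code.

```python
import math
from typing import List, Sequence

def build_prompts(class_names: Sequence[str], template: str) -> List[str]:
    """Wrap each class name in a prompt template before text encoding."""
    return [template.format(name) for name in class_names]

def _l2_normalize(vec: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def zero_shot_probs(image_emb: Sequence[float],
                    text_embs: Sequence[Sequence[float]],
                    logit_scale: float = 100.0) -> List[float]:
    """Softmax over scaled cosine similarities between one image and every class prompt."""
    img = _l2_normalize(image_emb)
    logits = [logit_scale * sum(a * b for a, b in zip(img, _l2_normalize(t)))
              for t in text_embs]
    peak = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical template and toy 3-d embeddings standing in for real CLIP encoder outputs.
prompts = build_prompts(["Abyssinian", "Bengal"], "a photo of a {}, a type of pet.")
probs = zero_shot_probs([0.9, 0.1, 0.0], [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
pseudo_label = max(range(len(probs)), key=lambda c: probs[c])  # → 0
```

The class with the highest probability becomes the pseudo-label, and its probability serves as the confidence used for filtering.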
With the labels obtained through CLIP, we apply semi-supervised learning, which trains on both labeled and unlabeled data. We select the top-k most confident predictions as pseudo-labeled data and use them to complement the training set.
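A minimal sketch of confidence-based pseudo-label selection, assuming every unlabeled image already has a class-probability vector from CLIP. It combines a confidence threshold (the parenthesized 0.5 and 0.9 values in the results table appear to be such thresholds) with a per-class top-k cap. The function name, the per-class reading of "top-k", and the default values are assumptions.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

def select_pseudo_labels(probs: Sequence[Sequence[float]],
                         threshold: float = 0.9,
                         k: int = 2) -> List[Tuple[int, int]]:
    """Return sorted (sample_index, predicted_class) pairs kept as pseudo-labeled data.

    A sample is kept only if its top probability is at least `threshold`;
    within each predicted class, at most the `k` most confident samples survive.
    """
    by_class: Dict[int, List[Tuple[float, int]]] = defaultdict(list)
    for idx, p in enumerate(probs):
        cls = max(range(len(p)), key=lambda c: p[c])
        if p[cls] >= threshold:
            by_class[cls].append((p[cls], idx))
    selected = []
    for cls, items in by_class.items():
        items.sort(reverse=True)  # most confident first
        selected.extend((idx, cls) for _, idx in items[:k])
    return sorted(selected)

probs = [
    [0.95, 0.05],  # confident, class 0
    [0.60, 0.40],  # below the threshold, dropped
    [0.97, 0.03],  # confident, class 0
    [0.92, 0.08],  # confident, class 0, but only third in its class
    [0.02, 0.98],  # confident, class 1
]
kept = select_pseudo_labels(probs, threshold=0.9, k=2)  # → [(0, 0), (2, 0), (4, 1)]
```

The per-class cap keeps classes that CLIP finds easy from dominating the pseudo-labeled set.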
Knowledge distillation is employed to transfer knowledge from large models to small models. The objective is to enable the small model to mimic the performance of the large model by learning from its output probability distribution. In our case, we use the probabilities predicted by the large model as the soft labels, which provide additional information for training the small model.
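For reference, the standard soft-label objective of Hinton et al. is given below. It is the textbook formulation, not a confirmed detail of this project, and the temperature $T$ is an assumption because the text does not state one. With teacher logits $z$ and student logits $v$ over $C$ classes:

$$
p_i^{(T)} = \frac{\exp(z_i / T)}{\sum_{j=1}^{C} \exp(z_j / T)}, \qquad
q_i^{(T)} = \frac{\exp(v_i / T)}{\sum_{j=1}^{C} \exp(v_j / T)}
$$

$$
\mathcal{L}_{\mathrm{KD}} = T^2 \, \mathrm{KL}\left(p^{(T)} \,\middle\|\, q^{(T)}\right) = T^2 \sum_{i=1}^{C} p_i^{(T)} \log \frac{p_i^{(T)}}{q_i^{(T)}}
$$

Raising $T$ above 1 flattens the teacher distribution, exposing how it ranks the non-target classes; the $T^2$ factor keeps gradient magnitudes comparable across temperatures.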
During training, we combine the pseudo-labeled data from the previous step with the soft labels produced by the large model. This lets the small model benefit from the knowledge distilled from the large model even though only a few labeled samples are available.
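A minimal, dependency-free sketch of such a combined objective for a single sample: cross-entropy against the CLIP pseudo-label plus a temperature-scaled KL term towards the teacher's soft labels. The weighting `alpha`, the temperature of 4.0, and the function names are illustrative assumptions, not values reported by the project.

```python
import math
from typing import List, Sequence

def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; temperature > 1 produces softer distributions."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits: Sequence[float],
                      teacher_logits: Sequence[float],
                      pseudo_label: int,
                      temperature: float = 4.0,
                      alpha: float = 0.5) -> float:
    """(1 - alpha) * CE(pseudo_label, student) + alpha * T^2 * KL(teacher_T || student_T)."""
    # Hard term: cross-entropy against the pseudo-label at temperature 1.
    ce = -math.log(softmax(student_logits)[pseudo_label])
    # Soft term: KL divergence from the softened teacher to the softened student.
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    # The T^2 factor keeps the soft term's gradient scale comparable to the hard term's.
    return (1 - alpha) * ce + alpha * temperature ** 2 * kl

agree = distillation_loss([2.0, 0.0], [2.0, 0.0], pseudo_label=0)     # KL term is 0
disagree = distillation_loss([2.0, 0.0], [0.0, 2.0], pseudo_label=0)  # teacher disagrees
```

In a PyTorch training loop the same objective would typically use `F.cross_entropy` for the hard term and `F.kl_div(F.log_softmax(student / T, dim=1), F.softmax(teacher / T, dim=1), reduction="batchmean")` for the soft term, computed over whole batches.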
Task | Model | n_param(weight) | Cifar100 | Country211 | Food101 | Pets37 | Cars196 | score |
---|---|---|---|---|---|---|---|---|
B/C | ViT-B/16 (raw) | 86100104 (0.76) | 73.19 | 5.67 | 62.69 | 80.98 | 59.64 | 43.02 |
B/C | ViT-B/16 (0.5) | 86100104 (0.76) | 78.59 | 8.26 | 87.24 | 91.25 | 58.34 | 49.35 |
B/C | ViT-B/16 (0.9) | 86100104 (0.76) | 91.65 | 9.53 | 89.57 | 92.23 | 74.48 | 54.50 |
B/C | ViT-T/16 (0.9)+distill | 5605862 (0.97) | 82.47 | 6.00 | 81.49 | 83.65 | 50.11 | 59.28 |
A | ViT-B/16 (0.9) | 86100104 (0.76) | 86.54 | 9.30 | 84.95 | 91.58 | 75.60 | 53.17 |
A | ViT-T/16 (0.9) | 5605862 (0.97) | 66.53 | 3.89 | 81.30 | 82.07 | 26.21 | 50.44 |
A | ViT-T/16 (0.9)+distill | 5605862 (0.97) | 84.81 | 5.95 | 81.95 | 86.51 | 55.17 | 61.41 |
Values in parentheses after the model names give the confidence threshold used to filter the pseudo-labeled unlabeled data; "raw" means no confidence filtering was applied. All results in this table use the five groups of data augmentation to enlarge the training set.
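A quick check of the parameter savings, using the n_param column of the table above:

$$
\frac{5\,605\,862}{86\,100\,104} \approx 0.065
$$

so ViT-T/16 has about 6.5% of the parameters of ViT-B/16, roughly 15.4 times fewer.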
Data augmentation techniques are essential for few-shot learning. By applying various transformations to the input images, we can generate new samples that capture different variations and improve the model's ability to generalize.
In our approach, we employ five different image transformations: random cropping, horizontal flipping, random rotation, color jittering, and random erasing. Each transformation is applied with certain probability and parameter settings, resulting in a diverse set of augmented images.
During the training process, the original training set is augmented with the transformed images. This expanded dataset allows the model to learn from a broader range of variations and improves its robustness to different image conditions.
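Enlarging the training set, as opposed to transforming images on the fly, can be sketched as follows: every original sample is kept, and each augmentation group contributes one transformed copy carrying the original label. The function name `expand_dataset` and the one-copy-per-group scheme are assumptions; in practice each group would be a torchvision pipeline, while here toy string transforms keep the sketch self-contained.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

ImageT = TypeVar("ImageT")  # any image type: PIL image, tensor, or (here) a string

def expand_dataset(samples: Sequence[Tuple[ImageT, int]],
                   groups: Sequence[Callable[[ImageT], ImageT]]) -> List[Tuple[ImageT, int]]:
    """Keep every original sample and append one augmented copy per augmentation group."""
    expanded = list(samples)
    for augment in groups:
        # Each transformed copy keeps the label of the image it came from.
        expanded.extend((augment(x), y) for x, y in samples)
    return expanded

# Toy data: strings stand in for images, string edits stand in for torchvision pipelines.
train = [("img_a", 0), ("img_b", 1)]
groups = [lambda x: x + "_flip", lambda x: x + "_rot"]
enlarged = expand_dataset(train, groups)
# → [('img_a', 0), ('img_b', 1), ('img_a_flip', 0), ('img_b_flip', 1), ('img_a_rot', 0), ('img_b_rot', 1)]
```

With N original images and G groups, the enlarged set holds N(G + 1) samples. Because each random transform is sampled once, these copies stay fixed across epochs, unlike on-the-fly augmentation.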
Table A: Data augmentation without enlarging the training set
flip | rotation | color-jitter | random-choice | random-crop | Grade |
---|---|---|---|---|---|
N | N | N | N | N | 16.8 |
N | N | 0.3 | Y | Y | 21.1 |
0.5 | 30 | N | N | N | 19.6 |
0.5 | 30 | 0.3 | Y | N | 20.8 |
Table B: Data augmentation with an enlarged training set
Task | flip | rotation | color-jitter | Group | Grade |
---|---|---|---|---|---|
B | 0.1 | 30 | 0.3 | 1, 2, 4 | 35 |
B | 0.1 | 30 | 0.3 | 1, 2, 3, 4, 5 | 41.8 |
C | 0.1 | 30 | 0.3 | 1, 2, 4 | 38.9 |
C | 0.1 | 30 | 0.3 | 1, 2, 3, 4, 5 | 43 |
We conduct extensive experiments on five benchmark datasets: CIFAR-100, Country211, Food-101, Oxford-IIIT Pets, and Stanford Cars. For each dataset, we evaluate our approach under three experimental settings: few-shot learning, semi-supervised learning, and knowledge distillation.
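The experimental results support the proposed approach. Incorporating unlabeled data through pseudo-labeling raises the overall score of ViT-B/16 from 43.02 without confidence filtering to 49.35 and 54.50 with confidence thresholds of 0.5 and 0.9, respectively. Knowledge distillation lets the much smaller ViT-T/16 learn from the large model: in Task A it lifts ViT-T/16's score from 50.44 to 61.41, although its per-dataset accuracies remain below those of ViT-B/16. Finally, data augmentation improves performance substantially: applying transforms without enlarging the training set raises the grade from 16.8 to at most 21.1 (Table A), and adding all five augmentation groups to the training set raises the grade from 35 to 41.8 in Task B and from 38.9 to 43 in Task C compared with three groups (Table B).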
In this paper, we propose two approaches to address limited training data and overfitting in few-shot image classification. Leveraging a large amount of unlabeled data through pseudo-labeling and knowledge distillation improves overall performance. Additionally, data augmentation improves the model's ability to generalize and its robustness to image variations.
Our experimental results demonstrate the effectiveness of these approaches on five benchmark datasets. The proposed methods should also be applicable to other image classification tasks, especially those with limited training data.
@article{bigdan202307,
title={BigDan: Better Image-classification Grade with Distillation and Augmentation Network},
author={Chen, Yitong and Lin, Yanjun and Qin, Zheng and Wang, Siyin and Ye, Xingsong},
course={Artificial Intelligence A (COMP130031.02), School of Computer Science, Fudan University},
year={2023}
}
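The AI Course's Big Elixir (Dan) Furnace (人工智能课程大炼丹炉); in Chinese slang, refining elixirs (炼丹) means training deep learning models.
A two-member team can be called Fudan, so a five-member team should be called BigDan.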
Our code base is developed and tested with PyTorch 1.7.0, TorchVision 0.8.1, CUDA 10.2, and Python 3.7.
conda create -n baseline python=3.7 -y
conda activate baseline
conda install pytorch==1.7.0 torchvision==0.8.1 cudatoolkit=10.2 -c pytorch
pip install -r requirements.txt
Loading pre-trained weights is allowed, but only models pre-trained on ImageNet-1k may be used; weights pre-trained on other datasets, such as ImageNet-21k, CC3M, or LAION, are not allowed.
The five provided datasets are: '10shot_cifar100_20200721', '10shot_country211_20210924', '10shot_food_101_20211007', '10shot_oxford_iiit_pets_20211007', and '10shot_stanford_cars_20211007'.
Pretrained models are available from timm; you can browse the pretrained timm models and use them.
python main.py --batch-size 64 --data-path ../../share/course23/aicourse_dataset_final/ --output_dir output/baseline --epochs 50 --lr 1e-4 --weight-decay 0.01
There are three modes for running the code:
- Operate on each dataset separately. Change the `--dataset_list` option to do this.
- Operate on known datasets. The dataset that each given image belongs to is provided; see the `--known_data_source` option.
- Operate on unknown datasets. The dataset that each given image belongs to is not provided, so you must predict both the dataset and the label of each image; see the `--unknown_data_source` option.
After obtaining a checkpoint for the chosen mode, run with `--test_only` to produce the prediction file `pred_all.json` in your output directory.
python main.py --batch-size 64 --data-path ../../share/course23/aicourse_dataset_final/ --output_dir output/baseline --epochs 50 --lr 1e-4 --weight-decay 0.01 --test_only
Submit a zip file containing `pred_all.json` to the colab website.
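The packaging step can also be scripted. This is a minimal sketch using only the Python standard library; the function name `make_submission`, the archive name `submission.zip`, and placing `pred_all.json` at the archive root are assumptions, so check the submission site's requirements.

```python
import zipfile
from pathlib import Path

def make_submission(output_dir: str, zip_path: str = "submission.zip") -> Path:
    """Pack <output_dir>/pred_all.json into a zip archive, at the archive root."""
    pred = Path(output_dir) / "pred_all.json"
    if not pred.is_file():
        raise FileNotFoundError(f"{pred} not found; run main.py with --test_only first")
    archive = Path(zip_path)
    with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        zf.write(pred, arcname="pred_all.json")  # drop the directory prefix inside the zip
    return archive
```

With the commands above, `make_submission("output/baseline")` would create `submission.zip` after the `--test_only` run.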