Given a good teacher, the hope of Pseudo Labels is that the obtained $\theta_S^{\text{PL}}$ would ultimately achieve a low loss on labeled data, i.e.

$$\mathbb{E}_{x_l, y_l}\Big[\mathrm{CE}\big(y_l, S(x_l; \theta_S^{\text{PL}})\big)\Big] =: \mathcal{L}_l\big(\theta_S^{\text{PL}}\big)$$

Notice that the ultimate student loss on labeled data, $\mathcal{L}_l\big(\theta_S^{\text{PL}}(\theta_T)\big)$, is also a "function" of the teacher's parameters $\theta_T$.
Therefore, we could further optimize $\mathcal{L}_l$ with respect to $\theta_T$:

$$\min_{\theta_T} \; \mathcal{L}_l\big(\theta_S^{\text{PL}}(\theta_T)\big), \quad \text{where} \quad \theta_S^{\text{PL}}(\theta_T) = \arg\min_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S) \tag{2}$$
Intuitively, by optimizing the teacher's parameters according to the performance of the student on labeled data, the pseudo labels can be adjusted accordingly to further improve the student's performance.
However, the dependency of $\theta_S^{\text{PL}}(\theta_T)$ on $\theta_T$ is extremely complicated, as computing the gradient $\nabla_{\theta_T} \theta_S^{\text{PL}}(\theta_T)$ requires unrolling the entire student training process.
To make Meta Pseudo Labels feasible, we borrow ideas from previous work in meta learning [40, 15] and approximate the multi-step $\arg\min_{\theta_S}$ with the one-step gradient update of $\theta_S$:

$$\theta_S^{\text{PL}}(\theta_T) \approx \theta_S - \eta_S \cdot \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S)$$

where $\eta_S$ is the student's learning rate. Plugging this approximation into the optimization problem above yields the practical teacher objective in Meta Pseudo Labels:

$$\min_{\theta_T} \; \mathcal{L}_l\Big(\theta_S - \eta_S \cdot \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S)\Big) \tag{3}$$

Note that, if soft pseudo labels are used, i.e. $T(x_u; \theta_T)$ is the full distribution predicted by the teacher, the objective above is fully differentiable with respect to $\theta_T$, and we can perform standard back-propagation to get the gradient.
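To make this concrete, below is a minimal JAX sketch of the soft-pseudo-label case, where the teacher gradient comes from back-propagating through the one-step student update. The linear `apply_model`, the loss functions, and the 10-class setup are illustrative assumptions, not the architectures used in the paper.

```python
import jax
import jax.numpy as jnp

def apply_model(params, x):
    # Toy linear classifier standing in for the real teacher/student networks.
    return x @ params["w"] + params["b"]

def cross_entropy(logits, targets):
    # `targets` is a full distribution (soft labels).
    return -jnp.mean(jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1))

def student_loss_u(theta_s, theta_t, x_u):
    # L_u: student cross-entropy against the teacher's soft predictions.
    soft_labels = jax.nn.softmax(apply_model(theta_t, x_u))
    return cross_entropy(apply_model(theta_s, x_u), soft_labels)

def teacher_objective(theta_t, theta_s, x_u, x_l, y_l, eta_s):
    # One-step lookahead: theta_s' = theta_s - eta_s * grad L_u, then L_l(theta_s').
    grads = jax.grad(student_loss_u)(theta_s, theta_t, x_u)
    theta_s_new = jax.tree_util.tree_map(lambda p, g: p - eta_s * g, theta_s, grads)
    return cross_entropy(apply_model(theta_s_new, x_l), jax.nn.one_hot(y_l, 10))

# With soft labels, everything above is differentiable, so the teacher
# gradient of Equation 3 is a single nested-grad call:
teacher_grad_fn = jax.grad(teacher_objective)
```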
However, in this work, we sample hard pseudo labels from the teacher's distribution to train the student. We use hard pseudo labels because they result in smaller computational graphs, which are necessary for our large-scale experiments in Section 4.
For smaller experiments, where we can use either soft or hard pseudo labels, we do not find a significant performance difference between them.
A caveat of using hard pseudo labels is that we need to rely on a slightly modified version of REINFORCE to obtain the approximated gradient of $\mathcal{L}_l$ in Equation 3 with respect to $\theta_T$.
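For the hard-label case, a hedged sketch of a score-function (REINFORCE-style) estimator is shown below, reusing `apply_model` and `cross_entropy` from the previous sketch. The scalar reward `h` (the change in the student's labeled loss across its update) is our reading of the first-order approximation; the exact estimator is derived in the paper.

```python
def labeled_loss(theta, x, y):
    # L_l: cross-entropy on labeled data (10 classes assumed, as above).
    return cross_entropy(apply_model(theta, x), jax.nn.one_hot(y, 10))

def teacher_grad_hard(theta_t, theta_s, theta_s_new, x_u, y_u_hard, x_l, y_l):
    # h > 0 when the student's one-step update reduced its labeled loss.
    h = labeled_loss(theta_s, x_l, y_l) - labeled_loss(theta_s_new, x_l, y_l)

    def log_lik(params):
        # Log-likelihood the teacher assigns to the sampled hard pseudo labels.
        log_probs = jax.nn.log_softmax(apply_model(params, x_u))
        return jnp.mean(log_probs[jnp.arange(x_u.shape[0]), y_u_hard])

    # Descending -h * grad(log_lik) raises the teacher's probability of pseudo
    # labels that helped the student and lowers it for those that hurt.
    g = jax.grad(log_lik)(theta_t)
    return jax.tree_util.tree_map(lambda gi: -h * gi, g)
```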
More interestingly, the student's parameter update can be reused in the one-step approximation of the teacher's objective, which naturally gives rise to an alternating optimization procedure between the student update and the teacher update:
Student: draw a batch of unlabeled data $x_u$, then sample pseudo labels $\hat{y}_u \sim T(x_u; \theta_T)$ from the teacher's prediction, and optimize objective 1 with SGD: $\theta_S' = \theta_S - \eta_S \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S)$.
Teacher: draw a batch of labeled data $(x_l, y_l)$, and "reuse" the student's update to optimize objective 3 with SGD: $\theta_T' = \theta_T - \eta_T \nabla_{\theta_T} \mathcal{L}_l\big(\underbrace{\theta_S - \eta_S \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S)}_{=\,\theta_S' \text{, reused from the student's update}}\big)$.
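Reusing the functions from the earlier sketches, one alternating step might look like the following (the default learning rates are arbitrary placeholders):

```python
def mpl_step(theta_t, theta_s, x_u, x_l, y_l, eta_s=0.1, eta_t=0.1):
    # Student: one SGD step on L_u(theta_t, theta_s) (objective 1).
    s_grads = jax.grad(student_loss_u)(theta_s, theta_t, x_u)
    theta_s_new = jax.tree_util.tree_map(lambda p, g: p - eta_s * g, theta_s, s_grads)

    # Teacher: one SGD step on objective 3. teacher_objective recomputes the
    # same one-step student update, which is how that update gets "reused".
    t_grads = teacher_grad_fn(theta_t, theta_s, x_u, x_l, y_l, eta_s)
    theta_t_new = jax.tree_util.tree_map(lambda p, g: p - eta_t * g, theta_t, t_grads)
    return theta_t_new, theta_s_new
```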
Meta Pseudo Labels works even better if the teacher is jointly trained with other auxiliary objectives. Therefore, in our implementation, we augment the teacher’s training with a supervised learning objective and a semi-supervised learning objective.
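A rough sketch of how the teacher's gradient could combine these terms is below. The generic consistency term and the equal weighting are our assumptions for illustration (the paper uses a UDA-style semi-supervised objective); the helpers come from the earlier sketches.

```python
def consistency_loss(theta_t, x_u, x_u_aug):
    # Generic consistency term standing in for the semi-supervised objective:
    # predictions on augmented data should match (fixed) predictions on clean data.
    target = jax.lax.stop_gradient(jax.nn.softmax(apply_model(theta_t, x_u)))
    return cross_entropy(apply_model(theta_t, x_u_aug), target)

def teacher_full_grad(theta_t, theta_s, x_u, x_u_aug, x_l, y_l, eta_s):
    g_mpl = teacher_grad_fn(theta_t, theta_s, x_u, x_l, y_l, eta_s)  # MPL feedback
    g_sup = jax.grad(labeled_loss)(theta_t, x_l, y_l)                # supervised
    g_ssl = jax.grad(consistency_loss)(theta_t, x_u, x_u_aug)        # semi-supervised
    return jax.tree_util.tree_map(lambda a, b, c: a + b + c, g_mpl, g_sup, g_ssl)
```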
For this experiment, we generate our own version of the TwoMoon dataset.
| Dataset | Image Resolution | #-Labeled Examples | #-Unlabeled Examples | #-Test Examples |
|---|---|---|---|---|
| CIFAR-10-4K | 32x32 | 4,000 | 41,000 | 10,000 |
| SVHN-1K | 32x32 | 1,000 | 603,000 | 26,032 |
| ImageNet-10% | 224x224 | 128,000 | 1,280,000 | 50,000 |
After training both the teacher and student with Meta Pseudo Labels, we finetune the student on the labeled dataset.
Finetuning phase:
Since the number of labeled examples is limited for all three datasets, we do not use any heldout validation set. Instead, we return the model at the final checkpoint.
To ensure a fair comparison, we only compare Meta Pseudo Labels against methods that use the same architectures and do not compare against methods that use larger architectures.
We also do not compare Meta Pseudo Labels with training procedures that include self-distillation or distillation from a larger teacher [8, 9].
Since these methods do not share the same controlled environment, the comparison to them is not direct, and should be contextualized as suggested by [48].
The purpose of this experiment is to verify whether Meta Pseudo Labels works well on the widely used ResNet-50 architecture [24] before we conduct larger-scale experiments on EfficientNet (Section 4).
JFT dataset
We also make sure that none of the 12.8 million images that we use overlaps with the ILSVRC 2012 validation set of ImageNet. This procedure of filtering extra unlabeled data has been used by UDA [76] and Noisy Student [77].
Training on ImageNet + JFT
Finetuning the student
Supervised learning using ResNet-50
Semi-supervised learning
This is particularly impressive since Billion-scale SSL pre-trains their ResNet-50 on weakly-supervised images from Instagram.
Due to the memory footprint of our networks, keeping two such networks in memory for the teacher and the student would vastly exceed the available memory of our accelerators.
We thus design a hybrid model-data parallelism framework to run Meta Pseudo Labels.
Specifically, our training process runs on a cluster of 2,048 TPUv3 cores.
We divide these cores into 128 identical replicas to run with standard data parallelism with synchronized gradients.
Within each replica, which runs on 2,048/128=16 cores, we implement two types of model parallelism.
We implement our hybrid data-model parallelism in the XLA-Sharding framework [37].
With this parallelism, we can fit a batch size of 2,048 labeled images and 16,384 unlabeled images into each training step.
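As a quick sanity check on these numbers, the per-replica shares implied by this layout are:

```python
total_cores = 2048
replicas = 128
cores_per_replica = total_cores // replicas   # 16 cores per replica
labeled_per_replica = 2048 // replicas        # 16 labeled images per replica
unlabeled_per_replica = 16384 // replicas     # 128 unlabeled images per replica
```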
We train the model for 1 million steps in total, which takes about 11 days for EfficientNet-L2 and 10 days for EfficientNet-B6-Wide. After finishing the Meta Pseudo Labels training phase, we finetune the models on our labeled dataset for 20,000 steps.
Given the expensive training cost of Meta Pseudo Labels, we design a lightweight version of Meta Pseudo Labels, termed Reduced Meta Pseudo Labels (Appendix E).
To avoid using proprietary data like JFT, we use the ImageNet training set as labeled data and the YFCC100M dataset [65] as unlabeled data.
Reduced Meta Pseudo Labels allows us to implement the feedback mechanism of Meta Pseudo Labels while avoiding the need to keep two networks in memory.
We achieve 86.9% top-1 accuracy on the ImageNet ILSVRC 2012 validation set with EfficientNet-B7.
Pseudo Labels methods have been applied to improve a variety of tasks.
However, vanilla Pseudo Labels methods keep a pre-trained teacher fixed during the student's learning, leading to confirmation bias [2] when the pseudo labels are inaccurate.
Other typical SSL methods often train a single model by optimizing an objective function that combines a supervised loss on labeled data and an unsupervised loss on unlabeled data.
Self-supervised losses typically encourage the model to develop a common sense about images, such as in-painting, solving jigsaw puzzles, or predicting image rotations.
Label propagation losses typically enforce that the model is invariant against certain transformations of the data, such as data augmentations, dropout, or adversarial perturbations.
Meta Pseudo Labels is distinct from the aforementioned SSL methods in two notable ways.
The teacher in Meta Pseudo Labels uses its softmax predictions on unlabeled data to teach the student.
These softmax predictions are generally called the soft labels, which have been widely utilized in the literature on knowledge distillation [26, 17, 86].
Outside the line of work on distillation, manually designed soft labels, such as label smoothing and temperature sampling, have also been shown to improve models' generalization.
Both of these methods can be seen as adjusting the labels of the training examples to improve optimization and generalization.
Like the other SSL methods above, however, these label adjustments do not receive any feedback from the student's performance, as is proposed in this paper.
We use the word Meta in our method's name because our technique of deriving the teacher's update rule from the student's feedback is based on a bi-level optimization problem, which appears frequently in the literature on meta-learning.
Similar bi-level optimization problems have been proposed to optimize a model's learning process, such as learning the learning rate schedule, designing neural network architectures, correcting noisy training labels, generating training examples, and re-weighting training data.
Meta Pseudo Labels uses the same bi-level optimization technique from this line of work to derive the teacher's gradient from the student's feedback.
The difference between Meta Pseudo Labels and these methods is that Meta Pseudo Labels applies the bi-level optimization technique to improve the pseudo labels generated by the teacher model.
In this paper, we proposed the Meta Pseudo Labels method for semi-supervised learning.
Key to Meta Pseudo Labels is the idea that the teacher learns from the student's feedback to generate pseudo labels in a way that best helps the student's learning.
The learning process in Meta Pseudo Labels consists of two main updates: updating the student based on the pseudo labels generated by the teacher, and updating the teacher based on the student's performance on labeled data.
Experiments on standard low-resource benchmarks such as CIFAR-10-4K, SVHN-1K, and ImageNet-10% show that Meta Pseudo Labels is better than many existing semi-supervised learning methods.
Meta Pseudo Labels also scales well to large problems, attaining 90.2% top-1 accuracy on ImageNet, which is 1.6% better than the previous state-of-the-art [16].
The consistent gains confirm the benefit of the student’s feedback to the teacher.