Introduciton
ViT가 NLP에서 좋은 성능을 보이면서, 이에 영감을 받아 SSL에 ViT를 접목시킨 것이 이번 논문입니다. 이번 논문은 기존 SSL 방법인 contrastive learning에 ViT backbone과 fixed random patch projection을 사용하는 간단한 trick을 추가함으로써, 기존 SSL methods의 불안정성 문제를 완화하고 accuracy를 증가시켰습니다.
이외에도, SSL ViT를 학습하기 위한 몇 가지 기본 구성요소(batch size, learning rate, optimizer)의 효과를 조사하고, 여러가지 SSL framework(MoCo, SimCLR, BYOL, SwAV)에서 ViT backbone으로 실험하여 성능을 비교하였습니다.
해당 논문이 나오기 전에 SSL ViT 연구가 이미 존재하였습니다 (비슷한 시기에 나온 DINO 제외). 기존의 SSL ViT 연구는 리뷰는 하지 않았지만, masked auto-encoding으로 학습하는 방법론이었습니다. masked auto-encoding을 사용하는 기존의 방법론과 contrastive learning framework를 사용하는 ViT인 MoCo v3를 비교하였을때, 아래의 표와 같이 더 좋은 performance를 보였습니다. 여기서, MoCo v3는 backbone을 ResNet-50보다 40배 더 많은 연산을 수행하는 ViT-Large, ViT-Huge model로 진행하였습니다.
그리고, ViT-L의 pretrain은 특정 경우의 transfer learning에서 supervised pretrain을 능가하였습니다. 이외에도, 제안하는 SSL ViT model이 더 적은 biases를 사용하여 convnet에 비해 경쟁력 있는 결과를 보여줌으로써, ViT의 potential을 보여주었습니다. 하지만, 이는 저자가 SSL ViT model이 더 개선될 여지가 있다고 볼 수 있다고 하였습니다.
MoCo v3
MoCo v3는 MoCo v1과 v2에서 발생했던 문제인 단순성, 정확성, 확장성을 개선한 모델입니다. MoCo v3의 pseudo code는 아래와 같습니다.
다른 SSL methods와 같이, random data augmentation을 사용하여 2개의 crops(views)을 수행합니다. crops은 encoder인 $f_q$와 $f_k$를 거쳐 output vector $q$와 $k$를 산출하게 됩니다. 여기서 $q$는 “query”처럼 작동하고 그에 맞는 “key”를 찾는 것으로 학습이 진행됩니다. 이러한 학습 방법은 contrastive loss function 중 하나인 InfoNCE로 공식화 할 수 있습니다. InfoNCE로 공식화한 loss function은 아래의 수식과 같습니다.
여기서 $k_+$는 q의 positive sample이고 set {$k^{-}$}는 $q$의 negative sample입니다. $\tau$는 $l_2$ normalization된 $q$와 $k$에 대한 temperature hyper-parameter입니다.
그리고 pseudo code의 $ctr(q1, k2) + ctr(q2, k1)$처럼 대칭 loss를 구축하였습니다. $f_q$ encoder는 ResNet or ViT backbone, projection head, extra prediction head으로 구성하였고 $f_k$는 $f_q$에서 prediction head를 제거하여 구축하였습니다. 이 $f_k$는 $f_q$의 moving average로 update됩니다.
이러한 mechanism으로 구축한 MoCo v3를 이전의 version과 정확도를 비교하였을때, 이전 version에 비해 좋은 성능을 보였습니다. 좋은 성능을 보이는 주된 요인은 prediction head 추가와 큰 batch size 사용이라고 저자는 주장했습니다.
Stability of Self-Supervised ViT Training
기존 SSL에서는 ResNet backbone을 사용하였고 MoCo v3는 backbone을 ViT로 바꾸었습니다. ViT로 바꾸는 것은 간단했지만, 학습 과정에서 불안정했습니다. 이 불안정성을 밝히기 위해 학습 중에 kNN curves으로 performance를 모니터링하였습니다.
Empirical Observations on Basic Factors
Batch size
해당 논문은 batch size, learning rate, optimizer와 같은 기본 요소가 안정성에 어떤 영향을 미치는지 연구하였습니다. 먼저, batch size의 관점에서는 ViT가 아래의 표들과 같이 계산량이 많으며, batch size가 클때 좋은 성능을 보입니다. 이러한 큰 batch size는 최근 SSL의 정확도에도 유리한 이점이 있습니다.
큰 batch size 외에도 다양한 batch size로 실험을 하였으며, 결과는 아래의 그래프와 같습니다.
1k 와 2k batch size는 71.5%와 72.6% linear accuracy로 상당히 부드러운 curves를 생성합니다. 더 큰 batch size는 더 많은 negative sample를 생성하기 때문에, accuracy를 향상시킵니다. 반면, 4K batch size의 curve는 눈에 띄게 불안정하고 accuracy는 72.2%입니다. 2k batch size에 비해 accuracy가 약간 감소하였고 불안정해보이는 것을 확인할 수 있었습니다.
6k batch size의 curve는 더 안 좋은 pattern을 보입니다. 학습이 local적으로 restart되고, local optima를 벗어나 새로운 optima를 찾는다고 볼 수 있습니다. 결과적으로 학습은 발산하지 않았지만, 제일 낮은 accuracy를 보였습니다.
Learning rate
batch size에 이어서, learning rate도 어떤 영향을 미치는지 실험을 진행하였습니다. 일반적으로, batch size가 증가할때 학습률이 조정되는 경우가 많습니다. 먼저 실험을 진행할때, 모든 실험에서 linear scaling rule를 사용하여 진행하였습니다. 즉, $lr \times batch size/256$으로 설정하여 진행하였습니다. 그리고 아래의 그림을 통해서 lr의 영향을 확인할 수 있습니다.
lr이 작을수록 학습은 더 안정적이지만, under fitting되기 쉬웠습니다. 그리고, $lr$ = 0.5e-4는 $lr$ = 1.0e-4보다 accuracy가 1.8% 낮았습니다.(70.4% vs 72.2%). 반면, 더 큰 lr로 훈련하면 덜 안정적입니다. 위의 그림에서 $lr$ = 1.5e-4가 curve에서 불안정성을 띄고 accuracy가 더 낮다는 것을 볼 수 있습니다.
Optimizer
기본적으로 AdamW를 optimizer로 사용하는 것이 ViT 모델을 훈련하기 위한 일반적은 optimizer입니다. 반면, 최근 SSL methods에서 큰 batch size를 위한 LARS optimizer을 기반으로 연구가 진행되고 있습니다. 그래서, LARS와 AdamW를 합친 LAMB optimizer로 연구를 진행하였습니다. 진행한 실험 결과는 아래의 그래프를 통해 확인할 수 있습니다.
적절한 lr이 주어지면 LAMB는 AdamW보다 약간 더 나은 accuracy인 72.5%를 달성하는 것을 알 수 있습니다. 그러나 lr이 크면 accuracy가 급격하게 떨어집니다. 최적의 값보다 $lr=6e-4$과 $lr=8e-4$같은 lr를 사용할 경우, LAMB의 accuracy는 1.6% 과 6.0%로 낮아지는 것을 알 수 있습니다. 그리고 흥미롭게도 학습 curves는 여전히 부드럽지만 중간에 갑자기 저하되는 것을 확인할 수 있습니다. 이 실험을 통해, lr이 적절하게 선택되면 LAMB가 AdamW와 비슷한 정확도를 달성할 수 있음을 발견했습니다.
A Trick for Improving Stability
이러한 모든 실험은 불안정성이 주요 문제임을 알 수 있습니다. 논문에서는 불안정성을 개선하기 위한 간단한 trick을 적용하였고, 그 결과 다양한 경우에서 정확도가 향상되었습니다.
학습 중에 gradient의 급격한 변화로 인해 학습 curves에서 잠깐 하락하는 현상이 발생함을 알 수 있습니다. 이를 자세히 확인하기 위해, 모든 layer의 gradient를 비교하였습니다. 비교 결과는 아래의 그래프와 같이, gradient가 첫번째 layer(patch projection)에서 spike와 같은 튀는 것이 마지막 layer보다 더 일찍 발생하는 것을 확인할 수 있습니다..
이 관찰을 바탕으로 불안정성은 앝은 layer에서 더 일찍 발생하는 것을 알 수 있습니다. 이를 예방하기 위해, 학습 중에 patch projection을 고정하는 방법을 사용하였습니다. 즉, 학습되지 않은 patch를 embedding하기 위해 fixed random patch projection을 사용하는 것입니다.
Comparisons
아래의 그림은 learned patch proj vs random patch proj의 MoCo v3 결과를 보여줍니다. random patch proj의 경우는 더 부드럽고 더 나은 학습 curve로 안정화되어서, $lr=1.5e-4$에서 정확도를 73.4%로 높였습니다. 그리고 더 큰 lr에서는 정확도를 더 크게 향상하였습니다. 이 비교를 통해서 학습의 불안정성이 accuracy에 영향을 미치는 주요한 문제임을 확인할 수 있습니다.
해당 논문은 MoCo 외에도 SimCLR과 BYOL과 같은 다른 SSL methods에서도 불안정성을 확인하였습니다. 그 결과는 아래의 그래프와 같습니다.
random patch proj는 SimCLR과 BYOL 모두에서 안정성을 개선하고 정확도를 각각 0.8%와 1.3% 증가시켰습니다. 그리고, SwAV에도 적용하였을때, 정확도를 65.8%에서 66.4%로 향상하였습니다. 이 간단한 trick으로 모든 SSL framework에서 효과적인 성능을 보였습니다. 하지만, 이러한 trick이 불안정성을 완화하는 것이지 solution은 아니라고 하였습니다. 그래서 저자는 더 개선될 여지가 있다고 하였습니다.
Implementation Details
Optimizer
- default use AdamW
- batch size 4096
- search for lr and weight decay based on 100 epoch results, tehn apply it for longer training
- lr warmup for 40 epochs
- after warmup, lr follows a cosine decay schedule
MLP heads
- projection head is a 3-layer MLP
- prediction head is a 2-layer MLP
- hidden layers of both MLPs are 4096-d and are with ReLU
- output layers of both MLPs are 256-d, without ReLU
- all layers in botch MLPs have BN
Loss
- scale the contrastive loss in by a constant 2$\tau$
- scale is redundant (it can be absorbed by adjusting lr and wd)
- this scale makes it less sensitive to the $\tau$ value when lr and wd are fixed
- set $\tau = 0.2$ as the default
ViT architecture
- input patch size is 16x16 or 14x14
- after projection it results in a sequence of length 196 or 256 for a 224x224 input
- position embedding are added to the sequence (use sine-cosine variant)
- sequence is concatenated with a learnable class token
- sequence is then encoded by a stack of Transformer blocks
- class token after last block is treated as output of backbone, and is input to MLP heads.
Linear probing
- evaluate by linear probing
- after pretraining, remove MLP heads and train a supervised linear classifier on frozen features
- use SGD optimizer with batch size 4096, wd of 0, sweep lr for each case
- 90 epochs (using only random resized croping and flipping aug)
Experimental Results
ImageNet으로 SSL을 수행하고 linear probing으로 evaluation하였습니다. 아래의 표는 ViT 구성을 요약한 것입니다.
Self-supervised Learning Frameworks
해당 논문은 MoCo v3 뿐만 아니라, SimCLR, BYOL, SwAV의 SSL framework에 ViT를 benchmark하였습니다. 모두 동일한 random projection trick을 사용하였고, 각 개별 framework에 대해 lr 및 wd(weight decay)를 sweep하였습니다. 아래의 표는 ViT-S/16 및 ViT-B/16 backbone으로 사용했을때, SSL framework의 성능을 비교한 표입니다. MoCo v3는 다른 framework보다 ViT에 대한 accuracy가 높은 것을 확인할 수 있습니다.
그리고 MoCo v3 및 SimCLR은 아래의 그래프와 같이, ResNet-50보다 ViT-B에서 더 높은 accuracy를 보였습니다.
Ablations of ViT + MoCo v3
MoCo v3는 ViT backbone을 사용하기 때문에, ViT의 구성요소를 여러가지 실험을 통해 성능을 비교하였습니다.
Position embedding
첫번째는 position embedding methods를 비교하여 실험을 진행하였고, 결과는 아래의 표와 같습니다. MoCo v3의 default method는 sin-cos입니다.
learned method는 잘 작동하지만 sin-cos보다 좋지는 못했습니다. 그리고 놀랍게도.. position embedding 없이도 제대로 작동하는 것을 확인할 수 있습니다. 이 실험을 통해서 position을 encoding하는 기능은 1.6%정도만 기여하는 것을 확인할 수 있습니다.
Class token
다음은 ViT에서 class token의 역할에 대해 실험을 진행하였고, 결과는 아래의 표와 같습니다.
cls를 사용하지 않으면(w/o) ViT는 final block 직후에 global average pooling이 사용됩니다. ViT는 final block 이후 추가 layernorm(LN)을 가지며, 이 LN을 유지하고 cls를 제거하면 결과가 훨씬 더 나빠집니다(69.7%). 그러나 이 LN과 cls를 제거하면 결과는 거의 변하지 않습니다(76.3%)… 이 비교실험을 통해서 class token이 필수적이지 않음을 확인할 수 있습니다. 또한 LN의 선택이 차이를 만들 수 있음을 확인할 수 있습니다.
BatchNorm in MLP heads
ResNet과 달리 ViT model은 기본적으로 BN이 없으므로, 모든 BN layer가 MLP head에 있습니다. 아래의 표는 head에 BN이 있는 경우와 없는 경우를 비교한 표입니다.
BN을 제거하면 accurcay가 2.1% 감소하는 것을 확인할 수 있습니가. contrastive learning이 작동하는데 BN이 필수적이지는 않지만, BN을 적절하게 사용하면 accuracy를 향상시킬 수 있음을 확인할 수 있습니다.
Prediction head
MoCo v3는 prediction MLP head를 사용하는데, 사용 여부에 따른 성능을 실험하였고 결과는 아래의 표와 같습니다. prediction MLP head를 제거하면 75.5%의 괜찮은 성능을 보입니다.
Momentum encoder
다음은 momemtum encoder의 momentum coefficient(m)을 비교한 실험이고 결과는 아래의 표와 같습니다.
최적 값은 $m=0.99(default)$입니다. $m=0$의 경우 SimCLR와 유사하고 74.3%의 accuracy는 SimCLR와 비슷합니다. 이를 통해 momentum encoder를 사용하면 2.2%정도 성능이 증가하는 것을 확인할 수 있습니다.
Comparisons with Prior Art
Self-supervised Transformers
SSL와 비교하여 다양한 ViT model을 사용한 MoCo v3의 결과는 아래의 표를 통해 확인할 수 있습니다. iPGT와 masked path pred는 모두 masking autoencodering으로 분류되는 방법들입니다. MoCo 기반의 ViT는 iGPT보다 accuracy가 높고 model이 더 작은 것을 확인할 수 있습니다. 그리고 SSL ViT model인 model이 클수록 accuracy가 더 높은 것을 확인할 수 있습니다.
Comparisons with big ResNets
MoCo v3를 ResNet을 backbone으로 하는 SSL methods와 성능을 비교한 그래프가 아래와 같습니다. ResNet을 대표적으로 사용하는 SimCLR v2와 BYOL을 비교하였을때, MoCo v3 ViT-BN이 높은 성능을 달성하는 것을 확인할 수 있습니다. 또한, 동일한 parameters 수대비 높은 성능을 달성하는 것을 확인할 수 있습니다.
End-to-end fine-tuning
그리고 finetuning을 수행했을때의 성능을 비교한 표가 아래의 표와 같습니다. 모든 case에서 더 좋은 성능을 달성하는 것을 확인할 수 있습니다.
Transfer Learning
transfer learning을 수행했을때의 결과는 아래의 표와 같습니다. 4개의 downstream dataset에 적용하였을때, 거의 대부분의 dataset에서 supervised method보다 더 좋은 성능을 달성하는 것을 확인할 수 있었습니다.
MoCo v3는 ResNet-50 backbone을 대부분 사용하는 SSL methods에서 최근 핫한 ViT backbone을 사용하였습니다. 기존의 contrastive learning method는 사용하고 fixed random projection patch라는 trick을 사용함으로써, 높은 performance를 달성하였습니다. 지금까지 ViT를 backbone으로 하는 SSL이 많이 나왔으며, 앞으로 관련 논문들을 리뷰할 예정입니다.
참고