DINO


Introduction

ViT(Vision Transformers) 는 최근 visual recognition을 위한 convnet의 대안으로 등장하였습니다. ViT가 convnet에 비해 경쟁적이지만, 더 까다로운 계산 + 더 많은 train dataset 필요 + ununique feature propoerties 때문에 아직 명확한 장점을 제공하지 못했습니다. 하지만, NLP에서 좋은 모습과 potential을 보여주었기 때문에, 이에 영감을 받아 ViT에 대한 self supervised pretraining의 영향을 연구한 것이 바로 해당 연구 논문 DINO입니다.

self supervised ViT는 supervised ViT나 convnet에서는 나타나지 않는 몇가지 흥미로운 property가 발견되었습니다.

  • self supervised ViT features는 아래의 그림과 같이 scene layout, 즉 object boundaries를 명시적으로 포함

  • result features의 품질을 향상시키기 위해 ViT의 더 작은 patches를 사용

  • DINO는 cross entropy loss를 사용하여 momentum encoder로 구축된 teacher network의 output을 직접 예측하여 self supervised training을 단순화

  • 기존의 SSL methods는 collapse를 피하기 위해, contrastive loss나 clustering 등을 사용하지만, DINO는 teacher network의 output을 centering과 sharpening함으로써 collapse를 방지

  • self supervised ViT features는 fine-tuning, linear classifier 또는 data augmentation 없이, 기본적인 k-NN classifier와 함께 수행되어, ImageNet에서 78.3%의 top-1 accuracy를 달성. (momentum encdoer와 multi-crop augmentation 결합의 경우 한정)

  • DINO architecture는 유연하고 내부 normalizations을 조정할 필요 없이, convnet과 ViT 모두에서 작동

img


Approach

SSL with Knowledge Distillation

DINO는 아래의 그림과 알고리즘과 같이, teacher network와 student network를 사용하는 knowledge distillation 구조를 사용하였습니다. 즉, teacher network $g_{\theta_t}$의 output과 일치하도록 student network $g_{\theta_s}$를 학습하는 paradigm입니다.

img img

DINO의 idea는 input image $x$가 주어지면, 2개의 networks 모두 $P_s$와 $P_t$로 표시된 $K$ 차원에 대한 확률 분포를 출력합니다. 확률 $P$는 아래의 수식과 같이, network $g$의 output을 softmax function으로 normalization하여 얻습니다. $\tau_s > 0$는 output 분포의 sharpness를 제어하는 temperature parameter입니다.

img

그리고 아래의 수식과 같이, $P_s$와 $P_t$ 최소화하여 $P_s$ 분포를 $P_t$ 분포에 일치시키는 것을 학습합니다. 즉, student network $\theta_s$의 parameter에 대한 cross entropy loss가 됩니다. ($H(a,b) = -a\log{b}$)

img

위 idea를 self supervised learning에 적용한 것이 DINO입니다. 먼저, multicrop을 사용하여 image의 다양한 views나 crop을 구성합니다. 즉, 주어진 image에서 다양한 views의 집합 V를 생성합니다. 이 V에는 2개의 global views $x^g_1$ 및 $x^g_2$ (원본 image의 50% 이상의 넓은 영역을 덮는 $224^2$ 해상도 view)와 작은 해상도의 여러개 local views (원본 image의 작은 영역을 덮는 $96^2$ 해상도 view)가 포함됩니다.

모든 views와 crop은 student를 통해 전달되고 global views만 teacher를 통해 전달되어, “local-to-global” 을 구성합니다. 그리고 의 수식과 같이 loss를 최소화함으로써, student와 teacher의 분포를 일치하도록 학습합니다.

img

2개의 network는 각각 parameters $\theta_s$와 $\theta_t$를 사용하고 동일한 architecture g를 공유합니다. 그리고 SGD optimizer를 사용하여 위의 loss function을 최소화하여 parameter $\theta_s$를 학습합니다.

Teacher network

teacher network는 knowledge distillation과 달리 $priori$로 주어진 teacher $g_{\theta_t}$가 없으므로, student network의 previous iterations으로 구축합니다. 이 부분은 Impact of the choice of Teacher Network 에서 자세히 확인할 수 있습니다.

student network weight에 대해서는 지수 이동 평균(Exponential Moving Average), 즉 momentum encoder를 사용하였습니다. update는 학습중에 0.996 ~ 1로 cosine schedule을 따르는 $\lambda$가 있는 $\theta_t \leftarrow \lambda\theta_t + (1-\lambda)\theta_s$를 사용하였습니다.

원래 momentum encoder는 contrastive learning에서 queue의 대안으로 도입되었습니다. 그러나, DINO framework에서는 queue나 contrastive loss가 없기 때문에 역할이 다르며, self training에 사용되는 mean teacher의 역할에 더 가깝습니다. 실제로, 이 teacher가 exponential decay로 mean을 내는 Polyak-Ruppert와 유사한 model ensembling 형식을 수행합니다. model ensembling을 위해 Polyak-Ruppert mean을 사용하는 것은 model의 성능을 향상시키기 위한 표준 관행입니다. 이 부분도 Analyzing the training dynamic 에서 자세히 확인할 수 있습니다.

이 teacher가 학습 전반에 거쳐 student보다 더 나은 수행을 하고 있는 것도 실험을 통해 확인했으며, 따라서 더 높은 품질의 target features를 제공하여 student의 학습을 제공합니다.

Network architecture

neural network($g$)는 ViT나 ResNet의 backbone($f$)과 projection head($h:g=h\circ f$)로 구성됩니다. pretext task를 수행하고 downstream task에 사용되는 features는 backbone f의 output입니다. projection head는 2048개의 node가 있는 3 layer MLP(multi-layer-perceoptron)과 $l_2$ normalization으로 구성되고, SwAV와 유사한 K dimension의 weight normalized fully connected layer로 구성됩니다.

그리고 predictor를 사용하지 않으므로, student network와 teacher network 모두에서 동일한 architecture가 생성됩니다. 특히, 표준 convnet과 달리 ViT architecture는 기본적으로 batch normalization을 사용하지 않습니다. 따라서, ViT에 DINO를 적용할때, projection head에도 batch normalization을 사용하지 않으므로, 전반적인 architecture에서 batch normalization을 사용하지 않습니다.

Avoiding collapse

기존의 SSL methods는 contrastive loss, clustering, predictor, batch normalization을 통해 collapse를 피합니다. DINO는 model collapse를 피하기 위해 momentum teacher output를 centering하고 sharpening하는 방법을 사용하였습니다. collapse를 피하기 위해 이 방법을 선택하면 batch에 대한 의존도를 낮추고 안정성이 높아집니다. centering은 teacher에게 bias term c를 추가하는 것으로 해석할 수 있습니다. (즉, $g_{t}(x)\leftarrow g_{t}(x) + c(centering)$). centering $c$는 아래의 수식과 같이, exponential MA(moving average)으로 update되고 다양한 batch size에서 잘 작동되었습니다.

img

여기서, $m > 0 $은 rate parameter이고 B는 batch size입니다. sharpening은 teacher softmax normalization에서 temperature $\tau_t$에 대해 낮은 값을 사용하여 얻습니다.

Implementation and evaluation protocols

Vision Transformer

DINO에서 사용된 다양한 network config는 아래의 표와 같습니다. ViT architecture는 N x N 해상도의 겹치지 않는 image patches의 grid를 input으로 사용합니다. DINO에서는 일반적으로 N = 16 이나 N = 8을 사용하였습니다. 그런 다음 patches는 linear layer를 통과하여 embeddings set를 형성합니다. 그리고 DINO에서는 sequence에 token을 추가하였습니다. 이 token의 역할은 전체 sequence에서 정보를 집계하고 output에 projection head h를 연결하는 것입니다.

img

일반적으로 이 token을 class token이라고 하지만, DINO는 self-supervised learning이므로 label이나 supervision은 첨부하진 않았습니다. patch tokens와 cls tokens는 “pre-norm” layer normalization을 통해 Transformer network에 공급됩니다. Transformer는 skip connections과 병렬인 self attention와 feed forward layer의 sequence입니다. self attention layer는 attention mechanism으로 다른 token representation을 살펴봄으로써 token representation을 update합니다.

Implementation details

  • ImageNet datset without labels
  • adamw optimizer and 1024 batch size
  • 16 GPUs when using ViT-S/16 !!
  • linear warmup scaling rule: lr = 0.0005 * batchsize/256
  • after warmup, decay learning rate with a cosine schedule from 0.04 to 0.4
  • temperature $\tau_s$ is set to 0.1 (use linear warmup from 0.04 to 0.07 during 30 epochs)
  • BYOL data augmentations (color jittering, Gaussian blur, solarization) and multi-crop

Evaluation protocols

이전 SSL methods와 비슷하게, standard protocols는 frozen features에 대한 linear classifier를 학습하거나 downstream task의 feature를 fine-tuning하는 것입니다. linear evaluations를 위해 학습 중에 random resize and crops 과 horizontal flips의 augmentations을 적용하고 central crop에 대한 정확도를 확인하였습니다. 그리고, fine-tuning을 위해 pretrained weight로 network를 initialize하고 학습 중에 update합니다.

그러나, 두 evaluation 모두 hyperparameters에 대해 민감했고, 특히 lr을 변경할때 정확도의 큰 차이가 있었습니다. 그래서 간단한 weighted k-NN을 사용하여 evaluation하였습니다. 특히, nearest neighbors의 수를 확인하였을때 20 NN이 대부분의 학습에서 일관되게 잘 작동하였습니다. 이 evaluation protocols은 hyperparameters 조정이나 data augmentation 등이 필요하지 않기 때문에, feature evaluation이 간소화되었습니다.

Main Results

Comparing with SSL frameworks on ImageNet

Comparing with the same architecture

아래 표에서 DINO를 ResNet50 또는 ViT-small과 같은 동일한 architecture의 다른 SSL과 비교하였습니다. ViT-S를 선택한 이유는 parameters 수가 21M vs 23M로 ResNet50와 유사하기 때문입니다. (처리량 : 1237/sec vs 1007 im/sec, supervised performance : 79.3% vs 79.8%).

img

먼저, DINO가 ResNet50에서 최신 기술과 동등한 performance를 발휘하는 것을 관찰할 수 있습니다. 그리고, ViT architecture로 switch할때, DINO는 linear evaluation에서 +3.5%, k-NN evaluation에서 +7.9%로써, BYOL, MoCo v2, SwAV를 능가하는 performance를 보여주었습니다. 여기서 흥미로운 것은, 단순한 k-NN classifier의 performance가 linear evalution과 거의 동등하다는 것입니다 (75.5% vs 77.0%).

이 property는 ViT architecture와 함께 DINO를 사용할때만 나타나며, 기존의 SSL이나 ResNet50에서는 나타나지 않았습니다.

Comparing across architectures

아래의 표는 architecture 전체에서 얻은 best performance를 비교합니다. 이 비교는 직접 비교하는 것이 아니라, 더 큰 architecture로 moving할때 DINO로 훈련된 ViT의 한계를 평가하기 위한 것입니다. DINO로 더 큰 ViT를 훈련하면, performance가 향상되지만, patches(“/8” variants)의 크기로 줄이면, performance에 더 큰 영향을 미칩니다. patches size를 줄이면, parameter가 추가되지 않지만, 실행 시간이 크게 줄어들고 memory usage가 늘어납니다.

그럼에도 불구하고, DINO로 훈련된 8x8 patches가 있는 기본 ViT는 linear evaluation에서 80.1%의 top-1을 달성하고, 이전 SSL SOTA보다 10배 적은 parameter와 1.4배 빠른 실행 시간을 가진 k-NN classifier로 77.4%를 달성하였습니다.

img

classification 외에도, retrieval, object discovery, transfer learning에서 우수한 성능을 보였습니다. 이 부분은 논문을 참고하시기 바랍니다.

Avlation Study of DINO

Importance of the Different Components

SSL에서 다른 구성 요소를 추가하는 것이 DINO의 ViT에 미치는 영향을 study하였습니다. 아래의 표는 구성 요소를 추가하고 제거할때, 다양한 model 변형을 확인한 것입니다.

img

먼저, momentum이 없으면, DINO가 작동하지 않고(2행), collapse를 피하기 위해 SK(Sinkhorn-Knopp)와 같은 고급 task가 필요하다는 것을 관찰할 수 있었습니다 (9행). 그러나, momentum이 있는 경우 SK를 사용하는 것은 거의 영향에 미치지 않았습니다 (3행). 또한, 3행과 9행을 비교하면 performance에 대한 momentum encoder이 중요한 것을 확인할 수 있습니다.

둘째, 4행과 5행에서 DINO의 multi-crop training과 cross-entropy loss가 좋은 features를 얻기 위한 중요한 요소임을 알 수 있습니다. 또한, student network에 predictor를 추가하는 것은 거의 영향을 미치지 않았습니다 (6행).

Importance of the patch size

서로 다른 patches size(16x16, 8x8, 5x5)로 학습된 ViT-S model의 k-NN classifier performance를 비교하였고, 결과는 아래의 그림과 같습니다. 또한, 16x16, 8x8 patches size의 ViT-B도 비교하였고, 모든 models는 300 epochs 학습하였습니다.

img

그 결과, patches size를 줄이면 성능이 크게 향상되는 것을 알 수 있었습니다. parameter를 추가하지 않고도 performance가 크게 향상될 수 있다는 점은 흥미로웠지만, 처리량은 높아지는 것을 알 수 있었습니다. 5x5 patches size는 44 im/s vs 8x8 patches size는 180 im/s.

Impact of the choice of Teacher Network

여기서는 DINO에서 다른 teacher network를 study하였습니다. k-NN protocol을 사용하여 300 epochs에 대해 학습된 model을 비교합니다.

Building different teachers from the student

아래 그림의 오른쪽은 momentum teacher 외에 student의 prevuiys instances에서 teacher를 구축하기 위해 다양한 실험을 비교한 결과입니다.

img

먼저, previous epoch의 student network를 teacher로 사용하는 것을 고려했습니다. 이 방법은 memory bank 또는 clustering hard-disillation의 한 형태로 사용되었습니다.

두번째, previous iteration의 student network를 copy한 student copy를 사용하는 것을 고려했습니다. 이 방법이 작동하려면 더 많은 normalization이 필요합니다.

흥미롭게도, previous epoch의 teacher를 사용하는 것이 collapse되지 않고 MoCo v2나 BYOL과 같은 기존 framework와 비교하였을때, k-NN evaluation에서 performance를 제공하는 것을 관찰하였습니다. momentum encoder를 사용하는 것은 teacher에게 우수한 performance를 제공하지만, 이것은 teacher를 위한 다른 방법을 조사할 여지가 있음을 시사합니다.

Analyzing the training dynamic

momentum teacher가 DINO에서 잘 작동하는 이유를 더 이해하기 위해 위 그림의 왼쪽 그래프에서 확인할 수 있습니다. 그래프를 통해 알 수 있는 것은, teacher가 학습중에 지속적으로 student를 능가하며 ResNet50으로 학습할 때도 동일한 양상을 보인다는 것입니다. momentum을 사용하는 다른 framework에서도 관찰되지 않았으며, teacher가 previous epoch에서 구축된 경우에도 관찰되지 않았습니다.

DINO의 momentum teacher를 기하급수적으로 감소하는 Polyak-Ruppert averaging의 한 형태로 해석하는 것을 추천하였습니다. Polyak-Ruppert averaging은 학습이 끝날 때, network의 performance를 개선하기 위해 model ensembling을 simulation하는데 자주 사용되는 방법입니다. 이 방법은 우수한 performance를 가진 model ensembling을 지속적으로 구축하기 위해 학습 중에 Polyak Ruppert averaging을 적용하는 것으로 해석될 수 있습니다.

Avoiding collapse

여기서는 collapse를 피하기 위한 centering과 target sharpening을 상호보완성에 study합니다. collapse는 2가지의 형태가 존재합니다.

  • input에 관계없이, model output을 모든 dimension에서 균일하게 하는 것.
  • one dimension이 지배하는 것.

centering은 지배적인 one dimension에 의해 유도된 collapse를 피할 수 있지만, 균일한 output을 피하진 못합니다. sharpening은 반대로 영향을 줍니다. cross entropy H를 entropy h와 Kullback-Leibler divergence(“KL”) $D_KL$로 분해하여 이 상호보완성을 보여줍니다.

img

KL이 0이면 일정한 output을 나타내므로, collapse됩니다. 아래의 그림에서 centering과 sharpening이 있거나 없는 학습 동안 entropy와 KL을 확인할 수 있습니다. 하나의 operation이 누락되 경우, KL은 0으로 수렴하여 collapse를 나타냅니다.

img

그러나 entropy h는 centering이 없는 0과 sharpening이 없는 $-log{1/K}$와 같은 다른 값으로 수렴하여 두 작업이 서로 다른 형태의 collapse를 유도함을 나타냅니다. 두 작업을 모두 적용하면 균형이 유지되어 collapse가 발생하지 않습니다.


기존의 SSL은 contrastive learning과 ResNet50으로 지속적인 연구를 진행하였습니다. 하지만 ViT가 최근에 주목받으면서, DINO와 같이 ViT를 활용한 SSL이 점점 제안되는 것 같습니다. 또한, knowledge distllation이라는 student, teacher network를 이용한 학습 방법을 사용했다는 것도 기존의 SSL 방법론과 다르게 접근했다는 점에서 큰 contribution이 있다고 생각합니다.


참고

  1. https://arxiv.org/pdf/2104.14294.pdf