
Morphing Tokens Draw Strong Masked Image Models

category AI 2024. 11. 5. 19:26

Link: https://arxiv.org/abs/2401.00254

Introduction or Motivation

Introduce a novel token contextualization method, Dynamic Token Morphing (DTM)

  • Token morphing is a process that links contextually similar tokens and aggregates them for coherent and comprehensive representations.
  • Encode the token-wise target representations and derive matching relations among tokens through DTM.
  • Merging is applied to both online and target tokens according to the matching relation, and each morphed online token is aligned with its corresponding morphed target token.
  • The range of morphing can vary from a single token to the entire set of tokens, covering token-wise to image-level representation learning.
  • While there are many options, we simply opt for bipartite matching for morphing, achieving both efficiency and efficacy.

Spatial Inconsistency

  • CLIP-B/16: 98 of the 196 total tokens are used for token aggregation
  • Zero-shot or linear probing performs better when spatial information is preserved (with token aggregation) than when it is broken
  • → Isn't this expected? In classification, unlike MIM's reconstruction training, when only partial tokens are provided, token aggregation helps retain the important regions or context.

Spatial Inconsistency in Representation Learning

Assumption: the spatial inconsistency in the self-distillation signal could potentially diminish the quality of representations.

  • Pre-train ViT-B/16 for 50 epochs and perform linear probing for 50 epochs on ImageNet-1K
  • Token morphing outperforms the token-wise distillation approach
    • tested with both bipartite matching and super-pixel clustering

Method - Dynamic Token Morphing

Definition:

  • Contextually aggregate tokens to derive a random number of morphed tokens, aiming to encapsulate diversified semantic information.

Observations:

  1. Pre-trained models often produce noisy, spatially inconsistent token-level targets for learning
  2. Token aggregation methods could handle spatial inconsistency but are insufficient as a supervisory signal.
  3. A well-designed method is needed that considers context and reduces noise more effectively.

The core idea of DTM:

  • The DTM module is simply added on top of the baseline MIM scheme.
  • Token morphing function $\phi_R:\mathbb{R}^{N \times d} \rightarrow \{0,1\}^{\bar{n}\times N}$ based on the morphing schedule $R$.
    • calculates similarity using a matching algorithm like Bipartite Matching
    • returns a token morphing matrix $M=[M_{ij}]\in\{0,1\}^{\bar{n}\times N}$
      • $\bar{n}$: number of token groups after morphing

Morphing numerous tokens enhances the denoising effect, while morphing fewer tokens retains detailed token representations.

Since morphing a fixed number of tokens may benefit only one side, DTM balances these advantages by morphing a diverse range of tokens.
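To make this trade-off concrete, here is a minimal sketch of a single morphing step that merges `r` tokens into their nearest counterparts via bipartite matching. The alternating split, cosine similarity, and mean aggregation are illustrative assumptions, not the authors' exact implementation:

```python
import numpy as np

def bipartite_morph_step(tokens: np.ndarray, r: int) -> np.ndarray:
    """Merge the r most similar group-A tokens into their group-B matches.

    tokens: (n, d) array of token representations.
    Returns an (n - r, d) array of morphed tokens.
    """
    n = tokens.shape[0]
    # Split tokens into two alternating groups A and B (one common choice).
    a_idx, b_idx = np.arange(0, n, 2), np.arange(1, n, 2)
    a, b = tokens[a_idx], tokens[b_idx]
    # Cosine similarity between every A token and every B token.
    a_n = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
    sim = a_n @ b_n.T                 # (|A|, |B|)
    best = sim.argmax(axis=1)         # closest B token for each A token
    score = sim.max(axis=1)
    # Merge only the r highest-similarity A tokens into their B matches.
    merge = np.argsort(-score)[:r]
    keep_a = np.setdiff1d(np.arange(len(a_idx)), merge)
    merged_b = b.copy()
    counts = np.ones(len(b_idx))
    for i in merge:
        j = best[i]
        # Running mean: aggregate token a[i] into its matched B token.
        merged_b[j] = (merged_b[j] * counts[j] + a[i]) / (counts[j] + 1)
        counts[j] += 1
    return np.concatenate([a[keep_a], merged_b], axis=0)
```

Setting `r` small keeps fine-grained tokens; setting `r` large produces fewer, more contextual groups, which is the dial DTM randomizes.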

How to Morph Dynamically?

  1. Dynamic Scheduler
    • sample the final number of morphed tokens and the number of iterations
      • $\bar{n} \sim \mathcal{U}(\bar{N}, N)$ from a uniform distribution
        • $\bar{N}$: minimum number of morphed tokens
      • $k \sim \mathcal{U}(1,K)$ from a uniform distribution
        • $K$: maximum number of iterations for sampling
    • $r_p$: sequence of token numbers that determines the number of tokens to morph for each iteration
      • $R=\{r_p\}^k_{p=1} ~\text{and}~ \sum^{k}_{p=1}r_p=N-\bar{n}$
      $$ r_p = \begin{cases} \left\lfloor \frac{(N - \bar{n})}{k} \right\rfloor, & \text{if } p < k \\N - \bar{n} - (k-1)\left\lfloor \frac{(N - \bar{n})}{k} \right\rfloor, & \text{if } p = k \end{cases} $$
  2. Token Morphing via Scheduler
    • Split tokens into two groups, with each token in the first group matched to its closest cosine similarity counterpart in the second group.
    • obtain the token morphing matrix $M$ from the target token representations $\{\text{v}_i\}_{i=1}^{N}$
      • $M_{ij}=1$ indicates that the $j^{th}$ token $\text{v}_j$ will be aggregated to the $i^{th}$ token $\text{v}_i$.
    • Constraints:
      • Every token should be assigned to a specific cluster, even in cases where it forms a single token cluster
      • Each token should retain its exclusive association with a single cluster
      • representations of online morphed tokens: $\hat{u}_i = \sum_j{M_{ij}u_j}$
      • representations of target morphed tokens: $\hat{v}_i = \sum_j{M_{ij}v_j}$
  3. Aligning Morphed Tokens
    • $L_{DTM}(\bar{n},k)=\sum^{\bar{n}}_{i=1}w_id(\hat{u}_i, \hat{v}_i)$
      • $d(\cdot)$: distance function
      • $w_i=\sum_j{M_{ij}}$: the number of tokens aggregated into the $i^{th}$ online and target morphed tokens
      • When $\bar{n}=N$ this reduces to a token-wise loss, and when $\bar{n}=1$ to an image-level loss.
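Putting the scheduler and the aligning loss together, a minimal sketch might look like the following (the smooth-L1 distance for $d$ and the binary morphing matrix $M$ are assumptions for illustration):

```python
import numpy as np

def sample_schedule(N: int, N_min: int, K: int, rng) -> list[int]:
    """Sample R = {r_1..r_k} with sum(R) = N - n_bar.

    n_bar ~ U(N_min, N) and k ~ U(1, K), matching the paper's scheduler.
    """
    n_bar = int(rng.integers(N_min, N + 1))   # final number of morphed tokens
    k = int(rng.integers(1, K + 1))           # number of morphing iterations
    total = N - n_bar
    base = total // k                         # r_p for p < k
    return [base] * (k - 1) + [total - (k - 1) * base]  # remainder at p = k

def dtm_loss(u: np.ndarray, v: np.ndarray, M: np.ndarray) -> float:
    """L_DTM = sum_i w_i * d(u_hat_i, v_hat_i), w_i = sum_j M_ij.

    u: (N, d) online tokens, v: (N, d) target tokens,
    M: (n_bar, N) binary morphing matrix. d is smooth-L1 here (an assumption).
    """
    u_hat = M @ u                  # morphed online tokens
    v_hat = M @ v                  # morphed target tokens
    w = M.sum(axis=1)              # tokens aggregated per morphed group
    diff = np.abs(u_hat - v_hat)
    smooth_l1 = np.where(diff < 1, 0.5 * diff**2, diff - 0.5).mean(axis=1)
    return float((w * smooth_l1).sum())
```

With `M` equal to the identity ($\bar{n}=N$) this degenerates to a token-wise loss, and with a single all-ones row ($\bar{n}=1$) to an image-level loss, mirroring the two extremes noted above.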

Experiment

The results are strong relative to the number of pre-training epochs used.

However, since the supervision comes from CLIP-B/16, this differs from pixel-reconstruction approaches such as MAE.

The improvement can thus be analyzed as a form of distillation.

For example, when compared against BEiT v2, the performance gap is not large.

 

