본문으로 바로가기

Decoupled Knowledge Distillation - CVPR2022

category AI 2023. 8. 12. 21:21

 

 

오늘 포스팅할 논문은 "Decoupled Knowledge Distillation'입니다.

이 논문은 CVPR2022에 기재되었습니다.

자세한 코드 정보는 저자들이 공유한 Github에서 확인하세요!

Contributions

  1. Provide an insightful view to study logit distillation by dividing the classical KD into TCKD and NCKD.
  2. Reveal the limitations of the classical KD loss caused by its highly coupled formulation.
  3. We propose an effective logit distillation method named DKD.

-

우선 Knowledge Distillation의 경우 Hinton이 고안한 개념으로 알고있습니다.

Knowledge Distillation의 주된 방식으로는 logit-based distillation과 feature-based distillation을 말할 수 있습니다.

 

대부분의 KD 논문에서는 더 좋은 성능을 보이고 있는 feature-based distillation을 많이 사용하고 있는 추세입니다.

하지만 Feature disttilation의 경우 extra computational cost 및 storage usage가 요구됩니다.

그와 반면에서 Logit-based Distillation의 경우 상대적으로 작은 computational cost와 storage를 요구한다는 차이점이 있습니다.

-

이 논문에서는 Logit-based KD의 upper bound를 아래의 두 개념을 통해 개선합니다.

  1. Target Classification Knowledge Distillation $($TCKD$)$
    • Only the prediction of the target class is provided while the specific prediction of each non-target class is unknown
    • Transferring knowledge about the difficulty of training samples.
  2. Non-Target Classification Knowledge Distillation$($NCKD$)$
    • Only considering the knowledge among non-target logits
    • Only applying NCKD achieves comparable or even better results than the classic KD
    • It might be "dark knowledge"

The reason why the potential of logit-based distillation is limited

  1. the NCKD loss term is weighted by a coefficient that negatively correlates with the teacher's prediction confidence on the target class. Thus larger prediction scores would lead to smaller weights.
  2. The significance of TCKD and are coupled, i.e., weighting TCKD and NCKD separately is not allowed.
    Such limitation is not preferable since TCKD and NCKD should be separately considered since their contributions are from different aspects.

Figure 1. Illustration of the classical KD and DKD.

-

c개의 class 중 t번째 class를 target으로 하는 모델의 예측 확률 값을 다음과 같이 정의한다.

$\mathbf{P}=[p_1,p_2,...,p_t, ...p_c] \in \mathbb{R}^{1\times C}$

 

$\mathbf{P}$의 i번째 class에 대해서 softmax를 취해 아래와 같이 계산된다..

Eq.1.    $p_i=\frac{exp(z_i)}{\sum_{j=1}^{C}exp(z_j)}$  

 

Eq.1.을 Target과 Non-Target으로 이진 분류했을 때 이를 $\mathbf{b}=[p_t, p_{\setminus t}] \in \mathbb{R}^{1\times2}$로 정의하며, $p_t$와 $p_{\setminus t}$는 아래와 같다.

  • Target: $p_t=\frac{exp(z_t)}{\sum_{j=1}^{C}exp(z_j)}$ 
  • Non-Target: $p_{\setminus t}=\frac{\sum_{k=1,k \neq t}^{C} exp(z_k)}{\sum_{j=1}^{C}exp(z_j)}$

target class에 대해서 고려하지 않은 모델의 예측 확률 값을 $\mathbf{\hat{p}}=[\hat{p}_1, ..., \hat{p}_{t-1}, \hat{p}_{t+1},...,\hat{p}_{c}] \in \mathbb{R}^{1\times (C-1)}$로 정의하며, i번째 값은 아래와 같이 계산된다.

 

Eq.2.    $\hat{p_i}=\frac{exp(z_i)}{\sum_{j=1, j \neq t}^{C}exp(z_j)}=p_i / p_{\setminus t}$ 

 

 

그리고 Teach Model을 $\mathcal{T} $, Student Model을 $\mathcal{S}$로 정의한다

이제 대부분의 정이가 끝났고, KD loss를 refomulation하면 끝이 난다.

-

 

기존에 KD에서 사용되는 KL loss를 Eq.3 Eq.1. Eq.2.를 이용해서 다시 써본다면 아래와 같다.

 

Eq.3.    $\mathbf{KD}=\mathbf{KL}(\mathbf{p^\mathcal{T}\parallel \mathbf{p^\mathcal{S}}})$

 

                   $=p_t^\mathcal{T} log (\frac{p_t^\mathcal{T}}{p_t^\mathcal{S}}) + \sum_{i=1,i\neq t}^{C}p_i^\mathcal{T} log \frac{p_i^\mathcal{T}}{p_i^\mathcal{S}}$

 

                   $=p_{t}^{\mathcal{T}}log(\frac{p_{t}^{\mathcal{T}}}{p_{t}^{\mathcal{S}}})+\sum_{i=1,i\neq t}^{C}\hat{p}_{i}^{\mathcal{T}}p_{\setminus t}^{\mathcal{T}}log(\frac{\hat{p}_{i}^{\mathcal{T}}p_{\setminus t}^{\mathcal{T}}}{\hat{p}_{i}^{\mathcal{S}}p_{\setminus t}^{\mathcal{S}}})$

 

                   $=p_{t}^{\mathcal{T}}log(\frac{p_{t}^{\mathcal{T}}}{p_{t}^{\mathcal{S}}})+\sum_{i=1,i\neq t}^{C}\hat{p}_{i}^{\mathcal{T}}p_{\setminus t}^{\mathcal{T}}(log\frac{\hat{p}_{i}^{\mathcal{T}}}{\hat{p}_{i}^{\mathcal{S}}}+log\frac{p_{\setminus t}^{\mathcal{T}}}{p_{\setminus t}^{\mathcal{S}}})$

 

                   $=p_{t}^{\mathcal{T}}log(\frac{p_{t}^{\mathcal{T}}}{p_{t}^{\mathcal{S}}})+\sum_{i=1,i\neq t}^{C}\hat{p}_{i}^{\mathcal{T}}p_{\setminus t}^{\mathcal{T}}(log\frac{\hat{p}_{i}^{\mathcal{T}}}{\hat{p}_{i}^{\mathcal{S}}})+\sum_{i=1,i\neq t}^{C}\hat{p}_{i}^{\mathcal{T}}p_{\setminus t}^{\mathcal{T}}log(\frac{p_{\setminus t}^{\mathcal{T}}}{p_{\setminus t}^{\mathcal{S}}})$

 

이 때, $p_{\setminus t}^{\mathcal{T}}$와 $p_{\setminus t}^{\mathcal{S}}$는 class index $i$와 관계 없으니 $ \sum{} $앞으로 빼낼 수 있다.

 

$\sum_{i=1,i\neq t}^{C}\hat{p}_{i}^{\mathcal{T}}p_{\setminus t}^{\mathcal{T}}log(\frac{p_{\setminus t}^{\mathcal{T}}}{p_{\setminus t}^{\mathcal{S}}})
=p_{\setminus t}^{\mathcal{T}}log(\frac{p_{\setminus t}^{\mathcal{T}}}{p_{\setminus t}^{\mathcal{S}}})
\sum_{i=1,i\neq t}^{C}\hat{p}_{i}^{\mathcal{T}}$


                                   $=p_{\setminus t}^{\mathcal{T}}log(\frac{p_{\setminus t}^{\mathcal{T}}}{p_{\setminus t}^{\mathcal{S}}})$

 

그래서 최종적으로 정리된 수식은 아래와 같다.

 

$\mathbf{KD}=\left \{   p_{t}^{\mathcal{T}}log(\frac{p_{t}^{\mathcal{T}}}{p_{t}^{\mathcal{S}}})
+p_{\setminus t}^{\mathcal{T}}log(\frac{p_{\setminus t}^{\mathcal{T}}}{p_{\setminus t}^{\mathcal{S}}})\right \}
+p_{\setminus t}^{\mathcal{T}}
\left \{ \sum_{i=1,i\neq 1}^{C} \hat{p}_i^{\mathcal{T}}log(\frac{\hat{p}_i^\mathcal{T}}{\hat{p}_i^\mathcal{S}})\right \}$

 

        $=\left \{   p_{t}^{\mathcal{T}}log(\frac{p_{t}^{\mathcal{T}}}{p_{t}^{\mathcal{S}}})
+p_{\setminus t}^{\mathcal{T}}log(\frac{p_{\setminus t}^{\mathcal{T}}}{p_{\setminus t}^{\mathcal{S}}})\right \}
+(1-p_{t}^{\mathcal{T}})
\left \{ \sum_{i=1,i\neq 1}^{C} \hat{p}_i^{\mathcal{T}}log(\frac{\hat{p}_i^\mathcal{T}}{\hat{p}_i^\mathcal{S}})\right \}$

 

        $=\mathbf{KL}(\mathbf{b}^{\mathcal{T}}\parallel \mathbf{b}^{\mathcal{S}})
+(1-p_{t}^{\mathcal{T}})\mathbf{KL}(\hat{\mathbf{b}}^{\mathcal{T}}\parallel \hat{\mathbf{b}}^{\mathcal{S}})$

 

        $=\mathbf{TCKD} + (1-p_{t}^{\mathcal{T}})\mathbf{NCKD}$

 

요약

학습이 진행 됨에 따라서 Teacher Model은 target class에 대해서 high confidence를 가지게 된다.

그렇게 되면 $p_t$가 1이 이 되어 버림으로써 NCKD$($Non-Target Class$)$에 대한 정보를 놓치게 된다.

 

또한, 더 좋은 knowledge distillation을 위해서 더 강력한 Teacher Model을 사용하면 할 수록 $p_t$가 1에 수렴하게 되고

$p_t$가 1이 되면 NCKD의 영향은 0으로 수렴하게 된다.

 

-

Experiment Results

단순히, KD loss를 decoupling & refolmulating했더니 성능이 1% point나 올라감을 보여준다.


MisoYuri's Deck
블로그 이미지 MisoYuri 님의 블로그
VISITOR 오늘 / 전체