ICCV 2023 Accepted Paper
Link: https://arxiv.org/abs/2210.04845
Summary
- A practical FSOD system must fulfil the following desiderata:
- (a) it must be usable as is, without requiring any fine-tuning at test time
- (b) it must be able to process an arbitrary number of novel objects concurrently, while supporting an arbitrary number of examples (shots) from each class
- (c) it must achieve accuracy comparable to that of a closed system
- The paper proposes a few-shot detection transformer (FS-DETR) based on visual prompting that addresses desiderata (a) and (b)
- Two Keys:
- feed the provided visual templates of the novel classes as visual prompts during test time
- “stamp” these prompts with pseudo-class embeddings (akin to soft prompting), which are then predicted at the output of the decoder.
Introduction or Motivation
- Few-Shot Object Detection (FSOD):
- the problem of detecting novel classes not seen during training, which can potentially address many of the aforementioned challenges
- existing FSOD methods typically require fine-tuning on the novel classes; this re-training makes them significantly harder to deploy on the fly, in real time, or on devices with limited training capability
- the visual template(s) from the new class(es) are used in two ways during test time:
- in FS-DETR’s encoder to filter the backbone’s image features via cross-attention, and more importantly
- as visual prompts in FS-DETR’s decoder, “stamped” with special pseudo-class encodings and prepended to the learnable object queries
Method
- Definition of Novel set and Base set:
- $C=C_{novel} \cup C_{base}$
- $C_{novel} \cap C_{base} = \emptyset$
Template encoding:
- Let $T_{i,j} \in \mathbb{R}^{H_p \times W_p \times 3}$, where $i \in \{1, \ldots, m\}$ and $j \in \{1, \ldots, k\}$, be the template images of the available classes (sampled from $C_{\text{base}}$ during training).
- $m$ is the number of classes at the current training iteration (can vary).
- $k$ is the number of examples per class (k-shots)
- A CNN backbone (e.g., ResNet-50) generates template features $\mathbf{X} = \text{CNN}(\mathbf{T})$, with $\mathbf{X} \in \mathbb{R}^{mk \times d}$
- each template's spatial feature map is pooled into a single $d$-dimensional vector using either average or attention pooling
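The pooling step above can be sketched in plain Python; the dimensions and the all-stdlib `average_pool` helper are toy stand-ins, not the paper's actual backbone or sizes.

```python
# Sketch: pooling each template's CNN feature map into one d-dim vector,
# stacking them into X with mk rows (toy numbers, illustrative names).

def average_pool(feature_map):
    """Average a (H, W, d) feature map (nested lists) into a d-dim vector."""
    H, W, d = len(feature_map), len(feature_map[0]), len(feature_map[0][0])
    pooled = [0.0] * d
    for row in feature_map:
        for cell in row:
            for c in range(d):
                pooled[c] += cell[c]
    return [v / (H * W) for v in pooled]

m, k, d = 2, 3, 4            # m classes, k shots, feature dim d
H, W = 5, 5                  # spatial size of the backbone output
# One toy feature map per template T_{i,j}; X stacks the mk pooled vectors.
X = [average_pool([[[float(i)] * d for _ in range(W)] for _ in range(H)])
     for i in range(m * k)]
print(len(X), len(X[0]))     # mk rows of d features -> 6 4
```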
Pseudo-class embeddings:
- Propose to dynamically and randomly associate, at each training iteration, the $k$ template prompts in $\mathbf{X}$ belonging to the $i^{th}$ class (for that iteration) with a pseudo-class represented by a pseudo-class embedding $\mathbf{c}_i^s \in \mathbb{R}^d$, which are added to the templates as follows:
- $\mathbf{X}^s = \mathbf{X} + \mathbf{C}^s$ where $\mathbf{C}^s \in \mathbb{R}^{mk \times d}$ contains the pseudo-class embeddings for all templates at the current iteration.
- The pseudo-class embeddings are initialized from a normal distribution and learned during training.
- They are not determined by the ground-truth categories and are class-agnostic.
- At each step, the template prompts of some class are arbitrarily assigned the $i^{th}$ embedding
- The goal is then to retrieve the pseudo-class $i$
- The actual class information is not used.
- Since the assigned embedding changes at every iteration, there is no correlation between the actual classes and the learned embeddings.
- each decoded object query $o_i$ in $\mathbf{O}$ will attempt to predict a pseudo-class using a classifier.
- Pseudo-class embeddings add a signature to each visual prompt
- allowing the network to track the template within the model and dissociate it from templates belonging to other classes
- As transformers are permutation-invariant, these vectors are required to track the visual prompt within the model.
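The random per-iteration stamping can be sketched as follows; the embedding-table size, dimensions, and function names are our toy choices, not the paper's code.

```python
import random

# Sketch: at each iteration, draw fresh pseudo-class embeddings and add one
# to all k templates of each class: X^s = X + C^s (toy values, our names).

m, k, d = 3, 2, 4                                      # classes, shots, embed dim
random.seed(0)
# Learnable pseudo-class embedding table (here: fixed toy values).
embedding_table = [[float(i)] * d for i in range(10)]  # 10 >= m slots

def stamp_templates(X, m, k):
    """Add a randomly drawn pseudo-class embedding to each class's k templates."""
    pseudo_ids = random.sample(range(len(embedding_table)), m)  # fresh draw each iteration
    X_s = []
    for cls in range(m):
        c = embedding_table[pseudo_ids[cls]]
        for shot in range(k):
            x = X[cls * k + shot]
            X_s.append([xv + cv for xv, cv in zip(x, c)])
    return X_s, pseudo_ids

X = [[0.0] * d for _ in range(m * k)]                  # toy template features
X_s, ids = stamp_templates(X, m, k)
# All k shots of a class share one embedding; the draw changes every iteration,
# so no correlation between actual classes and embeddings can form.
assert X_s[0] == X_s[1] and X_s[0] != X_s[2]
```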
Templates as visual prompts
- $\mathbf{O}' = [\mathbf{X}^s~ \mathbf{O}], \quad \mathbf{O}' \in \mathbb{R}^{(mk + N) \times d}$
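The concatenation above is just a prepend along the sequence axis; a minimal sketch with toy shapes:

```python
# Sketch: prepending the stamped template prompts to the learnable object
# queries to form the decoder input O' (toy dimensions, our names).

m, k, N, d = 2, 3, 5, 4
X_s = [[1.0] * d for _ in range(m * k)]   # stamped template prompts
O   = [[0.0] * d for _ in range(N)]       # learnable object queries
O_prime = X_s + O                         # O' = [X^s  O]
assert len(O_prime) == m * k + N          # (mk + N) x d
```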
FS-DETR encoder
- The difference from the original DETR is that an MHCA block is added to each encoder layer.
- This MHCA takes the pseudo-class-stamped templates $\mathbf{X}^s$ as an additional input (keys/values).
- $\mathbf{Z}' = \text{MHSA}(\text{LN}(\mathbf{Z}^{l-1})) + \mathbf{Z}^{l-1}$
- $\mathbf{Z}'' = \text{MHCA}(\text{LN}(\mathbf{Z}'), \mathbf{X}^s) + \mathbf{Z}'$
- $\mathbf{Z}^l = \text{MLP}(\text{LN}(\mathbf{Z}'')) + \mathbf{Z}''$
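The three encoder equations can be traced with a minimal single-head sketch; the attention has no learned projections and the MLP is a toy stand-in, so this shows only the pre-LN residual structure, not the paper's parameters.

```python
import math

# Minimal single-head sketch of the FS-DETR encoder layer:
#   Z'  = MHSA(LN(Z)) + Z;  Z'' = MHCA(LN(Z'), X^s) + Z';  Z^l = MLP(LN(Z'')) + Z''

def layer_norm(x, eps=1e-5):
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def attention(queries, keys_values):
    """softmax(Q K^T / sqrt(d)) V with K = V = keys_values (no projections)."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, kv)) / math.sqrt(d)
                  for kv in keys_values]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(wi * kv[c] for wi, kv in zip(w, keys_values)) / z
                    for c in range(d)])
    return out

def mlp(x):
    return [2.0 * v for v in x]   # toy stand-in for the 2-layer MLP

def encoder_layer(Z, X_s):
    ln_Z = [layer_norm(z) for z in Z]
    Z1 = [[a + b for a, b in zip(r, z)]                      # MHSA + residual
          for r, z in zip(attention(ln_Z, ln_Z), Z)]
    ln_Z1 = [layer_norm(z) for z in Z1]
    Z2 = [[a + b for a, b in zip(r, z)]                      # MHCA over templates
          for r, z in zip(attention(ln_Z1, X_s), Z1)]
    out = []
    for z in Z2:
        h = mlp(layer_norm(z))                               # MLP + residual
        out.append([hv + zv for hv, zv in zip(h, z)])
    return out

Z   = [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]]   # image tokens
X_s = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]   # stamped templates
Z_out = encoder_layer(Z, X_s)
assert len(Z_out) == len(Z) and len(Z_out[0]) == 4
```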
FS-DETR decoder
In each decoding layer $l$, the output features from the previous layer $\mathbf{V}^{l-1}$ are processed as follows:
$\mathbf{V}' = \text{MHSA}(\text{LN}(\mathbf{V}^{l-1}) + \mathbf{O}') + \mathbf{V}^{l-1}$
$\mathbf{V}'' = \text{MHCA}(\text{LN}(\mathbf{V}'), \mathbf{O}', \mathbf{Z}') + \mathbf{V}'$
$\mathbf{V}^l = \text{MLP}(\text{LN}(\mathbf{V}'')) + \mathbf{V}''$
- Where $\mathbf{V}^0 = \begin{bmatrix} \mathbf{X}^s & \text{zeros}(N, d) \end{bmatrix}$
- Separate MLPs are used to process the decoder's output features $\mathbf{V} = \begin{bmatrix} \mathbf{V}_{\mathbf{X}^s} & \mathbf{V}_{\mathbf{O}} \end{bmatrix}$ for the templates $\mathbf{V}_{\mathbf{X}^s}$ and the object queries $\mathbf{V}_{\mathbf{O}}$
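The decoder's input/output bookkeeping can be sketched as below; the shapes and the two toy head functions are our illustrative stand-ins for the separate MLPs.

```python
# Sketch: V^0 prepends the stamped templates to N zero-initialized slots,
# and the output V is split so separate heads process templates vs queries.

m, k, N, d = 2, 2, 3, 4
X_s = [[1.0] * d for _ in range(m * k)]
V0 = X_s + [[0.0] * d for _ in range(N)]      # V^0 = [X^s  zeros(N, d)]
assert len(V0) == m * k + N

V = V0                                        # stand-in for the decoder output
V_Xs, V_O = V[:m * k], V[m * k:]              # split: templates vs object queries

def template_mlp(x):                          # toy separate heads
    return [v + 1.0 for v in x]

def query_mlp(x):
    return [v - 1.0 for v in x]

out_templates = [template_mlp(v) for v in V_Xs]
out_queries   = [query_mlp(v) for v in V_O]
assert len(out_templates) == m * k and len(out_queries) == N
```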
FS-DETR training and loss functions
- $L = \sum_{i=1}^{N} \left[ \lambda_1 L_{\text{CE}}\big(c_i^s, \hat{p}_{\sigma(i)}(c^s)\big) + \lambda_2 \, \| b_i - \hat{b}_{\sigma(i)} \|_1 + \lambda_3 \, L_{\text{IoU}}\big(b_i, \hat{b}_{\sigma(i)}\big) \right]$
- $\hat{p}_{\sigma(i)}(c^s)$: the predicted pseudo-class probabilities (the paper describes them as such; they can be understood as the matched prediction under the permutation $\sigma$)
- $\mathbf{V}_{\mathbf{X}^s}$: not used in the loss
- $\mathbf{V}_{\mathbf{O}}$: used in the loss
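A toy evaluation of the three loss terms for one already-matched pair ($\sigma$ applied) can be sketched as below; the $\lambda$ values, the box format (x1, y1, x2, y2), and the plain IoU loss are illustrative only — the paper follows DETR, which uses Hungarian matching and the GIoU loss.

```python
import math

# Toy computation of L for one matched pair: lambda_1 * CE + lambda_2 * L1 + lambda_3 * IoU-loss.

def ce_loss(probs, target):                  # L_CE on pseudo-class probabilities
    return -math.log(probs[target])

def l1_loss(b, b_hat):
    return sum(abs(a - c) for a, c in zip(b, b_hat))

def iou_loss(b, b_hat):                      # 1 - IoU for (x1, y1, x2, y2) boxes
    ix = max(0.0, min(b[2], b_hat[2]) - max(b[0], b_hat[0]))
    iy = max(0.0, min(b[3], b_hat[3]) - max(b[1], b_hat[1]))
    inter = ix * iy
    area = (b[2] - b[0]) * (b[3] - b[1]) + (b_hat[2] - b_hat[0]) * (b_hat[3] - b_hat[1])
    return 1.0 - inter / (area - inter)

lam1, lam2, lam3 = 1.0, 5.0, 2.0             # toy weights, not the paper's
p_hat = [0.1, 0.8, 0.1]                      # predicted pseudo-class probs
b, b_hat = (0.0, 0.0, 1.0, 1.0), (0.0, 0.0, 1.0, 1.0)
L = lam1 * ce_loss(p_hat, 1) + lam2 * l1_loss(b, b_hat) + lam3 * iou_loss(b, b_hat)
# Perfect box: L1 and IoU terms vanish, only the CE term remains.
assert abs(L - (-math.log(0.8))) < 1e-9
```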