NeurIPS 2024 Accepted Paper
Link: https://arxiv.org/pdf/2410.19635
Summary
Introduction or Motivation
- Understanding an image at both global and local levels is a key factor for a wide range of vision perception tasks.
- Conflict in Object Detection:
- There is always a conflict between detecting the whole object and its parts.
- This conflict arises because parts of the object are also annotated in many scenarios, such as when objects are under occlusion.
- Co-occurrence of Objects:
- The co-occurrence of objects can facilitate the detection of some missing objects.
- Frozen-DETR
- Functionality:
- Utilizes a frozen foundation model as a plug-and-play module to boost the performance of query-based detectors.
- Foundation Model as Feature Enhancer:
- Global Image Understanding:
- Uses the class token from the foundation model as the full image representation, termed image query.
- Fine-Grained Patch Tokens:
- Considers fine-grained patch tokens with high-level semantic cues from the foundation model as another level of the feature pyramid.
- These patch tokens are fused with the detector’s feature pyramid via the encoder layers.
- Advantages:
- No architecture constraints.
- Plug-and-play capability.
- Asymmetric input sizes (the detector and the foundation model can use different input resolutions).
Method
Enhancing Decoder by Treating Class Tokens as Image Queries
- Characteristics:
- Pre-trained foundation models have a strong ability to understand complex images at a global level
- Take advantage of their scene-understanding ability
- Treat the class token as the image query.
- With the image query as context, object queries can be better classified.
- Flow (a minimal sketch follows this list):
- Project the image query to the same dimension as the object queries.
- Concatenate the two kinds of queries before feeding them into the self-attention module.
- Run self-attention in the decoder layer.
- Discard the image query right after self-attention.
- Run cross-attention as usual.
- Multiple Image Queries:
- A single global image query, i.e., the conventional cls token.
- Multiple cls tokens, obtained either by encoding cropped sub-images (one cls token per crop) or by adding extra cls tokens that group and aggregate different regions of patch tokens via an attention mask.
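A minimal sketch of the decoder-side flow described above, using hypothetical module and tensor names; a real DETR decoder layer also has residual connections, norms, and an FFN, which are omitted here:

import torch
import torch.nn as nn

class DecoderLayerWithImageQuery(nn.Module):
    # Sketch only: the image query participates in self-attention, then is discarded.
    def __init__(self, d_model=256, clip_dim=1024, nhead=8):
        super().__init__()
        self.img_query_proj = nn.Linear(clip_dim, d_model)   # match image query to object-query dim
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead)

    def forward(self, obj_queries, image_query, memory):
        # obj_queries: (num_query, B, d_model); image_query: (num_img_q, B, clip_dim); memory: (HW, B, d_model)
        img_q = self.img_query_proj(image_query)
        q = torch.cat([obj_queries, img_q], dim=0)            # concatenate before self-attention
        q, _ = self.self_attn(q, q, q)
        q = q[:obj_queries.shape[0]]                          # discard the image query right after self-attention
        out, _ = self.cross_attn(q, memory, memory)           # cross-attention over encoder features only
        return out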
Enhancing Encoder by Feature Fusion
- In the same way that an additional pyramid level would be created by down-sampling and concatenation, the CLIP patch tokens are projected to match the detector's feature dimension, concatenated with all of the detector's features, fed through DETR's encoder, and then decoupled (split off again) before being passed to the decoder.
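A rough sketch of that fusion, assuming hypothetical names, a 256-d detector feature dimension, and 1024-d CLIP patch tokens (not the repository's exact code):

import torch
import torch.nn as nn

proj = nn.Conv2d(1024, 256, kernel_size=1)   # project CLIP patch-token dim to the detector dim

def fuse_and_decouple(patch_token, feat_pyramid, encoder):
    # patch_token: (B, 1024, 24, 24); feat_pyramid: list of (B, 256, H_i, W_i) detector features
    clip_level = proj(patch_token)                              # treat CLIP patch tokens as an extra pyramid level
    levels = feat_pyramid + [clip_level]
    tokens = torch.cat([f.flatten(2).transpose(1, 2) for f in levels], dim=1)  # (B, sum H_i*W_i, 256)
    memory = encoder(tokens)                                    # joint refinement in DETR's encoder
    n_clip = clip_level.shape[-2] * clip_level.shape[-1]
    return memory[:, :-n_clip]                                  # decouple: drop the CLIP tokens before the decoder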
Code Analysis
- How is the cropped image obtained?
- The image is fed to CLIP's visual encoder together with an attn_mask to obtain the cls tokens and the feature map.
@torch.no_grad()
def get_clip_image_feat_multi_cls_token(self, imgs, boxes, img_metas):
    preprocessed = []
    for i in range(len(boxes)):
        img = imgs[i]
        device = boxes[i].device
        boxs = boxes[i]
        if len(boxs) == 0:
            continue
        img_shape = img_metas[i]  # (W, H)
        # round box coordinates outward to integer pixel borders
        boxs = torch.stack([torch.floor(boxs[:, 0] - 0.001), torch.floor(boxs[:, 1] - 0.001),
                            torch.ceil(boxs[:, 2]), torch.ceil(boxs[:, 3])], dim=1).to(torch.int)
        # clamp boxes to the image bounds
        boxs[:, [0, 2]] = boxs[:, [0, 2]].clamp(min=0, max=img_shape[0])
        boxs[:, [1, 3]] = boxs[:, [1, 3]].clamp(min=0, max=img_shape[1])
        boxs = boxs.detach().cpu().numpy()
        for box in boxs:
            croped = img.crop(box)           # crop the region from the PIL image
            croped = self.preprocess(croped) # CLIP preprocessing (resize + normalize)
            preprocessed.append(croped)
    if len(preprocessed) == 0:
        # no boxes at all: return an empty tensor
        return torch.zeros((0, self.global_token_dim), device=device)
    preprocessed = torch.stack(preprocessed).to(device)
    self.foundation_model.eval()
    # run the frozen CLIP visual encoder with the attention mask to get
    # multiple cls tokens plus the patch tokens
    features, patch_token = self.foundation_model.forward_multi_cls_token(
        preprocessed, self.attn_mask, self.num_global_token)
    # reorder cls tokens: move the first (num_global_token - 1) tokens to the end
    features = torch.cat([features[:, self.num_global_token - 1:],
                          features[:, :self.num_global_token - 1]], dim=1)
    b, _, c = patch_token.shape
    # reshape patch tokens back into a 2D (24 x 24) feature map
    return features, patch_token.permute(0, 2, 1).reshape(b, c, 24, 24)
...
image_box = self.generate_image_box(img_metas, torch.float, samples.device)  # presumably one whole-image box per image
single_image_box = [box[:1] for box in image_box]                            # keep a single (full-image) box per image
image_query, patch_token1 = self.get_clip_image_feat_multi_cls_token(img_no_normalize, single_image_box, img_metas)
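The reshape to a 24 x 24 grid presumably corresponds to a CLIP ViT-L/14 visual encoder run at 336 x 336 resolution (336 / 14 = 24 patches per side); this is an assumption, not stated in the snippet itself.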
- How is the cls token used in the decoder?
- In each decoder layer, the image_query is projected by a layer-specific projection layer and concatenated with the object queries.
- Unlike Figure 3 of the paper, decoupling happens when the output passes through the final regression and classification layers.
- Attention Mask
- Four attention masks (shown on the right) are generated; reshaping and flattening them produces the long attention mask shown below.
- The attention mask shown below was cropped so that the cls-token region is clearly visible.
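A guess at how such a mask could be built for a 24 x 24 patch grid with four region cls tokens; the function name and the quadrant layout are assumptions, not the repository's exact code:

import torch

def build_quadrant_attn_mask(grid=24, num_region_token=4):
    # One boolean mask row per extra cls token over the patch grid
    # (True = position is masked out, i.e. not attended).
    mask = torch.ones(num_region_token, grid, grid, dtype=torch.bool)
    half = grid // 2
    mask[0, :half, :half] = False   # token 0 attends to the top-left quadrant
    mask[1, :half, half:] = False   # token 1: top-right
    mask[2, half:, :half] = False   # token 2: bottom-left
    mask[3, half:, half:] = False   # token 3: bottom-right
    # Reshape and flatten each 2D quadrant mask into one long row of length grid*grid.
    return mask.reshape(num_region_token, grid * grid)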
- The decoder loop below shows the per-layer projection and concatenation of the image_query, followed by the decoupling:
for layer_id, layer in enumerate(self.layers):
    if image_query is not None:
        # project the image query with this layer's projection + norm, then append it
        # to the object queries and append its box to the reference points
        image_query_per_layer = self.image_query_norm[layer_id](self.image_query_proj[layer_id](image_query))
        output = torch.cat([output, image_query_per_layer.permute(1, 0, 2)], dim=0)
        reference_points = torch.cat([reference_points, image_box.permute(1, 0, 2)], dim=0)
    ...
    if image_query is not None:
        # decouple: keep only the original object queries for the prediction heads
        output = output[:num_query]
        reference_points = reference_points[:num_query]