
NeurIPS 2024 Accepted Paper

Link: https://arxiv.org/pdf/2410.19635

 

Summary

Introduction or Motivation

  • Understanding an image at both global and local levels is a key factor for a wide range of vision perception tasks.
    • Conflict in Object Detection:
      1. There is always a conflict between detecting the whole object and its parts.
      2. This conflict arises because parts of the object are also annotated in many scenarios, such as when objects are under occlusion.
    • Co-occurrence of Objects:
      • The co-occurrence of objects can facilitate the detection of some missing objects.
  • Frozen-DETR
    • Functionality:
      • Utilizes a frozen foundation model as a plug-and-play module to boost the performance of query-based detectors.
    • Foundation Model as Feature Enhancer:
      • Global Image Understanding:
        • Uses the class token from the foundation model as the full image representation, termed image query.
      • Fine-Grained Patch Tokens:
        • Considers fine-grained patch tokens with high-level semantic cues from the foundation model as another level of the feature pyramid.
        • These patch tokens are fused with the detector’s feature pyramid via the encoder layers.
    • Advantages:
      1. No architecture constraints.
      2. Plug-and-play capability.
      3. Asymmetric input sizes: the detector and the frozen foundation model can take inputs of different resolutions.

Method

Enhancing Decoder by Treating Class Tokens as Image Queries

  • Characteristics:
    • Pre-trained foundation models have a strong ability to understand complex images at a global level.
    • Frozen-DETR takes advantage of this scene-understanding ability.
    • The class token is treated as the image query.
    • With the image query as context, object queries can be classified more accurately.
  • Flow (see the sketch below):
    • Project the image query to the same dimension as the object queries.
    • Concatenate the two kinds of queries before feeding them into the self-attention module.
    • Run self-attention in each decoder layer.
    • Discard the image query right after self-attention.
    • Run cross-attention with the object queries only.
  • Multiple Image Queries:
    • A global image query from the conventional cls token.
    • Additional cls tokens obtained by grouping/pooling patch tokens, e.g. cls tokens of cropped image regions, or extra cls tokens restricted to sub-regions via an attention mask.
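
Below is a minimal sketch of this decoder-side flow, assuming a standard DETR-style decoder layer; the module names, dimensions, and the omission of residual connections, norms, and FFN are illustrative assumptions, not the authors' exact implementation.

import torch
import torch.nn as nn

class DecoderLayerWithImageQuery(nn.Module):
    """Illustrative DETR-style decoder layer that uses frozen cls tokens as image queries."""
    def __init__(self, d_model=256, clip_dim=1024, nhead=8):
        super().__init__()
        self.image_query_proj = nn.Linear(clip_dim, d_model)  # project image query to the object-query dimension
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, object_queries, image_query, memory):
        # object_queries: (B, num_query, d_model); image_query: (B, num_image_query, clip_dim)
        # memory: (B, HW, d_model) encoder features
        img_q = self.image_query_proj(image_query)
        q = torch.cat([object_queries, img_q], dim=1)        # concatenate the two kinds of queries
        q, _ = self.self_attn(q, q, q)                       # object queries see the image query as context
        q = q[:, :object_queries.shape[1]]                   # discard the image query right after self-attention
        q, _ = self.cross_attn(q, memory, memory)            # cross-attention with object queries only
        return q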

Enhancing Encoder by Feature Fusion

  • In the same way that an additional pyramid feature would be downsampled and concatenated, the frozen model's patch tokens are projected to match the detector's feature dimension, concatenated with the full feature pyramid, and fed through DETR's encoder; they are then decoupled (removed) before being passed to the decoder (see the sketch below).
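
A minimal sketch of this fusion step, assuming a deformable-DETR-style encoder that operates on flattened multi-scale tokens; the function and argument names are illustrative assumptions, not the authors' exact implementation.

import torch

def fuse_and_encode(pyramid_feats, patch_tokens, proj, encoder):
    """pyramid_feats: list of (B, C, H_i, W_i) detector feature maps.
    patch_tokens:  (B, C_clip, 24, 24) patch tokens from the frozen model.
    proj:          assumed 1x1 conv projecting C_clip -> C.
    encoder:       the detector's transformer encoder over flattened tokens."""
    extra_level = proj(patch_tokens)                         # match the detector's channel dimension
    levels = pyramid_feats + [extra_level]                   # treat patch tokens as one more pyramid level
    flat = torch.cat([f.flatten(2).transpose(1, 2) for f in levels], dim=1)  # (B, sum(H_i*W_i) + 576, C)
    encoded = encoder(flat)                                   # fuse detector and frozen features via self-attention
    num_frozen = extra_level.shape[-2] * extra_level.shape[-1]
    return encoded[:, :-num_frozen]                           # decouple: drop the frozen tokens before the decoder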

Code Analysis

  • How to get cropped image?
    • The attn_mask is passed to CLIP’s visual encoder together with the preprocessed crops to obtain the cls tokens and the patch-token feature map.
@torch.no_grad()
def get_clip_image_feat_multi_cls_token(self, imgs, boxes, img_metas):
    preprocessed = []

    for i in range(len(boxes)):
        img = imgs[i]
        device = boxes[i].device
        boxs = boxes[i]
        if len(boxs) == 0:
            continue

        img_shape = img_metas[i]  # (W, H)

        # round box coordinates outward to integer pixels and clamp them to the image bounds
        boxs = torch.stack([torch.floor(boxs[:, 0] - 0.001), torch.floor(boxs[:, 1] - 0.001),
                            torch.ceil(boxs[:, 2]), torch.ceil(boxs[:, 3])], dim=1).to(torch.int)
        boxs[:, [0, 2]].clamp_(min=0, max=img_shape[0])
        boxs[:, [1, 3]].clamp_(min=0, max=img_shape[1])

        boxs = boxs.detach().cpu().numpy()

        # crop each box region and apply CLIP's preprocessing
        for box in boxs:
            cropped = img.crop(box)
            cropped = self.preprocess(cropped)
            preprocessed.append(cropped)

    if len(preprocessed) == 0:
        # no valid boxes in the whole batch: return empty query features
        return torch.zeros((0, self.global_token_dim), device=device)

    preprocessed = torch.stack(preprocessed).to(device)
    self.foundation_model.eval()
    # forward through the frozen CLIP visual encoder with the attention mask to get
    # multiple cls tokens (global + region) and the patch tokens
    features, patch_token = self.foundation_model.forward_multi_cls_token(
        preprocessed, self.attn_mask, self.num_global_token)
    # reorder the cls tokens: move tokens from index num_global_token-1 onward to the front
    features = torch.cat([features[:, self.num_global_token - 1:],
                          features[:, :self.num_global_token - 1]], dim=1)
    b, _, c = patch_token.shape

    # reshape patch tokens into a 2D grid; 24x24 assumes e.g. a 336x336 input to a ViT-L/14 encoder
    return features, patch_token.permute(0, 2, 1).reshape(b, c, 24, 24)
 
...

image_box = self.generate_image_box(img_metas, torch.float, samples.device)
single_image_box = [box[:1] for box in image_box]
image_query, patch_token1 = self.get_clip_image_feat_multi_cls_token(img_no_normalize, single_image_box, img_metas)

 

  • How to use cls token in decoder?
    • The image_query is projected through a projection layer dedicated to each decoder layer and then concatenated with the object queries.
      • Unlike Figure 3 in the paper, decoupling happens when the queries pass through the final regression and classification layers.
    • Attention Mask

      • Four per-region attention masks are generated; reshaping and flattening them produces one long attention mask (shown as an image in the original post; a sketch follows below).
      • In that image, the attention mask is cropped so that the cls-token region is clearly visible.
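
A minimal sketch of how such a mask could be constructed, under the assumption that the token sequence is [4 region cls tokens, original cls token, 24x24 patch tokens] and that each region cls token may only attend to one quadrant of the patch grid plus itself; this layout is an assumption, not the repository's exact code.

import torch

def build_multi_cls_attn_mask(grid=24):
    """Boolean attention mask of shape (L, L); True = attention blocked.
    Assumed token order: [4 region cls tokens, original cls token, grid*grid patch tokens]."""
    num_extra_cls = 4                                   # one cls token per quadrant (2x2 split)
    num_patches = grid * grid
    L = num_extra_cls + 1 + num_patches
    mask = torch.zeros(L, L, dtype=torch.bool)          # start with everything visible

    patch_ids = torch.arange(num_patches)
    rows, cols = patch_ids // grid, patch_ids % grid
    for k in range(num_extra_cls):
        top, left = (k // 2) * (grid // 2), (k % 2) * (grid // 2)
        in_quadrant = (rows >= top) & (rows < top + grid // 2) & \
                      (cols >= left) & (cols < left + grid // 2)
        mask[k, num_extra_cls + 1:] = ~in_quadrant      # block patches outside this quadrant
        mask[k, :num_extra_cls + 1] = True              # block the other cls tokens...
        mask[k, k] = False                              # ...but keep self-attention
    return mask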

 

for layer_id, layer in enumerate(self.layers):
    if image_query is not None:
        # project the frozen cls tokens with this layer's own projection and norm,
        # then append them to the object queries and their reference boxes
        image_query_per_layer = self.image_query_norm[layer_id](self.image_query_proj[layer_id](image_query))
        output = torch.cat([output, image_query_per_layer.permute(1, 0, 2)], dim=0)
        reference_points = torch.cat([reference_points, image_box.permute(1, 0, 2)], dim=0)
    ...
    if image_query is not None:
        # decouple: keep only the object queries before the prediction heads
        output = output[:num_query]
        reference_points = reference_points[:num_query]
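
Because the projected image queries are appended after the object queries and sliced off again before the regression and classification heads, they act only as extra context during attention; the number of predictions stays equal to num_query.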

Experimental Results

