
DETR Feature Map Visualization Code


The workflow consists of five steps:


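  1. Load the DETR model and its pretrained parameters
  2. Download the test image, preprocess it, and run a forward pass to get the predictions
  3. Capture the values the network produces while processing this image (key step*)
  4. Compute attn_output_weights to plot the attention weights of each head (key step*)
  5. Plot the results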

Before walking through the code, here are the key variables it uses:


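Variable | Meaning | Shape
conv_features | Feature map from the last backbone (ResNet-50) layer, stored as conv_features['0'].tensors | [1, 2048, 25, 34]
enc_attn_weights | Self-attention weights of the last encoder layer | [1, 850, 850]
dec_attn_weights | Cross-attention weights of the last decoder layer (averaged over the 8 heads) | [1, 100, 850]
memory | Encoder output, i.e. the input features of the decoder | [850, 1, 256]
cq | Output of the last decoder layer's self-attention block (after norm1) | [100, 1, 256]
pk | Positional encoding of the feature map | [1, 256, 25, 34]
pq | Trained object queries, i.e. query_embed | [100, 256]
in_proj_weight | Packed q/k/v projection weights of the last decoder layer's cross-attention (only the q and k slices are used here) | [768, 256]
in_proj_bias | Packed q/k/v projection biases of the same module | [768]

The spatial sizes (25 x 34 = 850 tokens) are those of the example image used below; other images give other sizes.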
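The 25 x 34 grid (850 tokens) in the table follows from the resize step and the backbone strides. The pure-Python sketch below reproduces it under two assumptions: torchvision's T.Resize(800) rule for an int size (shorter side set to 800, longer side scaled and truncated with int()), and the torchvision ResNet-50 layout (7x7 stride-2 conv1, 3x3 stride-2 max-pool, and a stride-2 3x3 conv at the start of layer2, layer3 and layer4). The helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def resize_shorter_side(w: int, h: int, target: int = 800) -> tuple:
    """Mimic T.Resize(target) with an int size: shorter side -> target."""
    if w <= h:
        return target, int(target * h / w)
    return int(target * w / h), target


def resnet50_feature_hw(w: int, h: int) -> tuple:
    """(height, width) of the ResNet-50 layer4 output for an image of w x h."""
    rw, rh = resize_shorter_side(w, h)
    sizes = []
    for s in (rh, rw):
        s = conv_out(s, kernel=7, stride=2, padding=3)  # conv1
        s = conv_out(s, kernel=3, stride=2, padding=1)  # max-pool
        for _ in range(3):                               # layer2, layer3, layer4
            s = conv_out(s, kernel=3, stride=2, padding=1)
        sizes.append(s)
    return tuple(sizes)


# Assumption: the COCO image 000000039769.jpg used below is 640 x 480 (width x height)
fh, fw = resnet50_feature_hw(640, 480)
print(fh, fw, fh * fw)  # 25 34 850
```

The resized input is 1066 x 800, and five stride-2 stages shrink it by roughly 32x in each direction.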


The code for each step follows.


0. Preparation

import warnings
warnings.filterwarnings("ignore")

from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from torch.nn.functional import linear, softmax

torch.set_grad_enabled(False)


def box_cxcywh_to_xyxy(x):
    # (center_x, center_y, width, height) -> (x_min, y_min, x_max, y_max)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # Convert normalized [0, 1] boxes to pixel coordinates of the original image
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
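DETR predicts boxes as normalized (center_x, center_y, width, height). The pure-Python sketch below mirrors box_cxcywh_to_xyxy followed by rescale_bboxes for a single box so the arithmetic can be checked by hand; the function name and the sample box are made up for illustration.

```python
def cxcywh_to_pixel_xyxy(box, img_w, img_h):
    """Normalized (cx, cy, w, h) -> pixel (x_min, y_min, x_max, y_max)."""
    cx, cy, w, h = box
    x_min, y_min = cx - 0.5 * w, cy - 0.5 * h
    x_max, y_max = cx + 0.5 * w, cy + 0.5 * h
    # x coordinates scale with the image width, y coordinates with the height
    return (x_min * img_w, y_min * img_h, x_max * img_w, y_max * img_h)


# A box centered in a 640 x 480 image, 20% of its width and 40% of its height
result = cxcywh_to_pixel_xyxy((0.5, 0.5, 0.2, 0.4), 640, 480)
print(result)
```

Up to floating-point rounding this gives (256, 144, 384, 336): x spans 0.4 to 0.6 of the width, y spans 0.3 to 0.7 of the height.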

1. Load the DETR model and its pretrained parameters

# ---------------------- 1. Load the model and its pretrained parameters ----------------------
# Load the model from torch.hub
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()

# Fetch the trained parameters
for name, parameters in model.named_parameters():
    # Trained object queries, i.e. pq: [100, 256]
    if name == 'query_embed.weight':
        pq = parameters
    # Packed q/k/v projection weights and biases of the cross-attention
    # in the last decoder layer (layers.5): [256*3, 256] and [768]
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_weight':
        in_proj_weight = parameters
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_bias':
        in_proj_bias = parameters
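Why 768 rows? nn.MultiheadAttention stores the query, key and value projections stacked in a single in_proj_weight of shape [3 * embed_dim, embed_dim] (when query, key and value share one dimension), in q, k, v order. The sketch below builds a fresh, untrained nn.MultiheadAttention(256, 8), matching the sizes of DETR's decoder cross-attention rather than its trained weights, and slices it the same way step 4 does.

```python
import torch
from torch import nn

# Same sizes as DETR's decoder cross-attention: d_model = 256, 8 heads
mha = nn.MultiheadAttention(embed_dim=256, num_heads=8)

w, b = mha.in_proj_weight, mha.in_proj_bias
print(w.shape, b.shape)  # torch.Size([768, 256]) torch.Size([768])

# Rows [0, 256) project queries, [256, 512) keys, [512, 768) values
w_q, w_k, w_v = w.split(256, dim=0)
b_q, b_k, b_v = b.split(256, dim=0)
print(w_q.shape, b_k.shape)  # torch.Size([256, 256]) torch.Size([256])
```

This is why step 4 uses in_proj_weight[0:256] for q and in_proj_weight[256:512] for k.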

2. Download the test image, preprocess it, and run a forward pass to get the predictions

# ---------------------- 2. Download the image, preprocess it, and run the forward pass ----------------------
# Download the image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
# Alternatively, load a local copy:
# img_path = '/home/wujian/000000039769.jpg'
# im = Image.open(img_path)

# mean-std normalize the input image (batch size: 1)
img = transform(im).unsqueeze(0)

# propagate through the model
outputs = model(img)

# keep only predictions with 0.9+ confidence (the last logit column is the "no object" class)
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9

# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
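The filtering above relies on DETR's extra "no object" class: pred_logits has one more column than there are classes, softmax runs over all columns, and the last column is dropped ([..., :-1]) before taking the max, so a query is kept only when a real class receives more than 90% of the probability. The toy logits below (two queries, two real classes plus no-object) are invented for illustration.

```python
import torch

# Toy logits for 2 queries; columns = [class 0, class 1, no-object]
logits = torch.tensor([[0.0, 0.0, 5.0],   # query 0: mostly "no object"
                       [4.0, 0.0, 0.0]])  # query 1: confident class 0

probas = logits.softmax(-1)[:, :-1]       # drop the no-object column
keep = probas.max(-1).values > 0.9        # same threshold as above
print(probas)
print(keep)  # query 0 is dropped, query 1 is kept
```

Query 1 keeps about 0.96 of the probability on class 0 (e^4 / (e^4 + 2)), while query 0 puts almost everything on no-object.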

3. Capture the values produced while the image passes through the network (key step*)

# ---------------------- 3. Store the values produced during the forward pass ----------------------
# use lists to store the outputs via up-values
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
cq = []      # query content cq of the last decoder layer
pk = []      # positional encoding pk of the encoder input
memory = []  # encoder output / decoder input features

# Register forward hooks
hooks = [
    # feature map of the last ResNet layer
    model.backbone[-2].register_forward_hook(lambda self, input, output: conv_features.append(output)),
    # encoder output (memory)
    model.transformer.encoder.register_forward_hook(lambda self, input, output: memory.append(output)),
    # self-attention weights of the last encoder layer
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(lambda self, input, output: enc_attn_weights.append(output[1])),
    # cross-attention weights of the last decoder layer
    model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_attn_weights.append(output[1])),
    # output cq of the last decoder layer's self-attention block (after norm1)
    model.transformer.decoder.layers[-1].norm1.register_forward_hook(lambda self, input, output: cq.append(output)),
    # positional encoding pk of the image feature map
    model.backbone[-1].register_forward_hook(lambda self, input, output: pk.append(output)),
]

# propagate through the model
outputs = model(img)

# remove the hooks once they have been used
for hook in hooks:
    hook.remove()

# don't need the lists anymore
conv_features = conv_features[0]        # dict; conv_features['0'].tensors: [1, 2048, 25, 34]
enc_attn_weights = enc_attn_weights[0]  # [1, 850, 850]: [N, L, S]
dec_attn_weights = dec_attn_weights[0]  # [1, 100, 850]: [N, L, S] --> [batch, tgt_len, src_len]
memory = memory[0]                      # [850, 1, 256]: encoder output / decoder input
cq = cq[0]                              # output of the last decoder self-attn block: [100, 1, 256]
pk = pk[0]                              # [1, 256, 25, 34]
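Step 3 relies on PyTorch forward hooks: register_forward_hook(fn) calls fn(module, input, output) after every forward pass of that module, and the returned handle's remove() detaches it. Because a lambda cannot assign to an outer variable, the hooks above append to lists defined outside them. The small network below is a made-up stand-in for DETR that shows the same pattern in isolation.

```python
import torch
from torch import nn

# A tiny stand-in network (not DETR) to demonstrate the hook pattern
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

captured = []  # filled by the hook, like conv_features, memory, ... above
handle = net[1].register_forward_hook(
    lambda module, inputs, output: captured.append(output))

with torch.no_grad():
    net(torch.randn(3, 4))  # hook fires and captures the ReLU output
    handle.remove()         # detach the hook once it has been used
    net(torch.randn(3, 4))  # this pass is no longer captured

print(len(captured), captured[0].shape)  # 1 torch.Size([3, 8])
```

Removing the hooks matters: otherwise every later forward pass would keep appending to the lists.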

4. Compute attn_output_weights to plot the attention weights of each head (key step*)

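The key steps for computing attn_output_weights are:

q = cq + pq

k = pk

q = linear(q, in_proj_weight[0:256], in_proj_bias[0:256])

k = linear(k, in_proj_weight[256:512], in_proj_bias[256:512])

attn_output_weights = softmax(q @ k^T / sqrt(32)), computed separately for each head  # [1, 8, 100, 850]: one attention map per head (head_dim = 256 / 8 = 32)

Note that k here is the positional encoding pk alone. In DETR's decoder the cross-attention key is memory + pk (content plus position), so these per-head maps show how the queries attend to the positional part of the keys only, and their average is not dec_attn_weights. Using k = memory + pk instead should reproduce the actual per-head cross-attention weights, whose mean over the 8 heads is dec_attn_weights.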

# ---------------------- 4. Compute attn_output_weights to plot the attention of each head ----------------------
pk = pk.flatten(-2).permute(2, 0, 1)  # [1, 256, 25, 34] -> [1, 256, 850] -> [850, 1, 256]
pq = pq.unsqueeze(1)                  # [100, 256] -> [100, 1, 256]
q = pq + cq
k = pk  # positional key only; DETR's actual cross-attention key is memory + pk

# Project q and k with the in-projection (adapted from nn.MultiheadAttention)
_b = in_proj_bias
_start = 0
_end = 256
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
q = linear(q, _w, _b)

_b = in_proj_bias
_start = 256
_end = 256 * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
k = linear(k, _w, _b)

# Scale by head_dim ** -0.5 as nn.MultiheadAttention does (head_dim = 256 / 8 = 32)
scaling = float(256 // 8) ** -0.5
q = q * scaling

# Split into 8 heads of 32 dims each
q = q.contiguous().view(100, 8, 32).transpose(0, 1)    # [8, 100, 32]
k = k.contiguous().view(-1, 8, 32).transpose(0, 1)     # [8, 850, 32]
attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # [8, 100, 850]
attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = attn_output_weights.view(1, 8, 100, 850)

# keep the per-head maps for visualizing each head later
attn_every_heads = attn_output_weights                 # [1, 8, 100, 850]
attn_output_weights = attn_output_weights.mean(dim=1)  # [1, 100, 850]
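To check the manual computation of step 4, in particular the head_dim ** -0.5 scaling, the sketch below compares it with nn.MultiheadAttention on random tensors. It uses a freshly initialized module instead of DETR's trained weights, and random stand-ins for the inputs (q_in plays the role of cq + pq, k_in the role of memory + pk); only the sizes match the table. eval() disables attention dropout, and the module returns weights averaged over heads, like dec_attn_weights.

```python
import torch
from torch import nn
from torch.nn.functional import linear, softmax

torch.manual_seed(0)
d_model, n_heads, n_q, n_k = 256, 8, 100, 850
head_dim = d_model // n_heads  # 32

mha = nn.MultiheadAttention(d_model, n_heads).eval()
q_in = torch.randn(n_q, 1, d_model)  # stand-in for cq + pq     [100, 1, 256]
k_in = torch.randn(n_k, 1, d_model)  # stand-in for memory + pk [850, 1, 256]

with torch.no_grad():
    # Reference: head-averaged weights returned by the module  [1, 100, 850]
    _, ref = mha(q_in, k_in, k_in)

    # Manual computation, following step 4
    w, b = mha.in_proj_weight, mha.in_proj_bias
    q = linear(q_in, w[:d_model], b[:d_model])
    k = linear(k_in, w[d_model:2 * d_model], b[d_model:2 * d_model])
    q = q * head_dim ** -0.5  # scale by head_dim, not d_model
    q = q.contiguous().view(n_q, n_heads, head_dim).transpose(0, 1)  # [8, 100, 32]
    k = k.contiguous().view(n_k, n_heads, head_dim).transpose(0, 1)  # [8, 850, 32]
    per_head = softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)      # [8, 100, 850]
    manual = per_head.mean(dim=0, keepdim=True)                      # [1, 100, 850]

print(manual.shape, torch.allclose(manual, ref, atol=1e-6))
```

The same comparison is a quick way to confirm that feeding k = memory + pk into step 4 matches dec_attn_weights.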

5. Plot the results

# ---------------------- 5. Plot ----------------------
# spatial size of the backbone feature map (25 x 34 for this image)
h, w = conv_features['0'].tensors.shape[-2:]

# one column per kept query; row 0: decoder attention, row 1: image + box, rows 2-9: the 8 heads
fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=10, figsize=(22, 28), squeeze=False)
colors = COLORS * 100

# visualize
for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):
    # decoder cross-attention weights (averaged over heads)
    ax = ax_i[0]
    ax.imshow(dec_attn_weights[0, idx].view(h, w))
    ax.axis('off')
    ax.set_title(f'query id: {idx.item()}', fontsize=30)
    # predicted box and class
    ax = ax_i[1]
    ax.imshow(im)
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color='blue', linewidth=3))
    ax.axis('off')
    ax.set_title(CLASSES[probas[idx].argmax()], fontsize=30)
    # attention maps of the 8 heads, one per row
    for head in range(2, 2 + 8):
        ax = ax_i[head]
        ax.imshow(attn_every_heads[0, head - 2, idx].view(h, w))
        ax.axis('off')
        ax.set_title(f'head:{head - 2}', fontsize=30)

fig.tight_layout()  # adjust the subplots so they fill the whole canvas
plt.show()

[Note]: The code above was collected from the web.



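Visualization results:

The first row of the figure is drawn from dec_attn_weights.

The remaining eight rows show the individual heads, drawn from the per-head attn_output_weights (attn_every_heads).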
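The .view(h, w) calls in the plotting code work because flatten(-2) (applied to pk in step 4, and inside DETR to the feature map that becomes memory) flattens the 25 x 34 grid row by row, so key index i corresponds to feature-map cell (i // w, i % w). The sketch below uses a synthetic attention vector with a single peak, not real DETR output, to show this mapping; with the stride-32 backbone, a cell covers roughly a 32 x 32 patch of the resized input, not of the original image.

```python
import torch

h, w = 25, 34              # feature-map size from the table
attn = torch.zeros(h * w)  # synthetic attention weights for one query
attn[5 * w + 20] = 1.0     # single peak at row 5, column 20

grid = attn.view(h, w)     # same reshape as in the plotting code
row, col = divmod(int(attn.argmax()), w)
print(row, col, grid[row, col].item())  # 5 20 1.0

# Rough top-left corner of that cell in the resized input (backbone stride 32)
print(col * 32, row * 32)  # 640 160
```

A bright spot at row r, column c of any attention map above therefore points at the image region around (32c, 32r) in the resized input.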

 


Author: mobiledu2502898543