热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用关键点进行小目标检测

【GiantPandaCV导语】本文是笔者出于兴趣搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要采用了回归和heat

【GiantPandaCV导语】本文是笔者出于兴趣搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要采用了回归和heatmap两种方式来回归关键点,是一个很简单基础的项目,代码量很小,可供新手学习。

1. 数据来源

数据集:数据来源自小武,经过小武的授权使用,但不会公开。本项目只用了其中很少一部分共108张图片。

标注工具:https://github.com/pprp/landmark_annotation

标注工具也可以在GiantPandaCV公众号后台回复“landmark”关键字获取

部分样例展示

上图是数据集中的两张图片,红圈代表对应的目标,标注的时候只需要在其中心点一下即可得到该点对应的横纵坐标。

该数据集有一个特点,每张图只有一个目标(不然没法用简单的方法回归),多余一个目标的图片被剔除了。

1
0.42 0.596

以上是一个标注文件的例子,1.jpg对应1.txt

2. 回归确定关键点

回归确定关键点比较简单,网络部分采用手工构建的一个两层的小网络,训练采用的是MSELoss。

这部分代码在:https://github.com/pprp/SimpleCVReproduction/tree/master/simple_keypoint/regression

2.1 数据加载

数据的组织比较简单,按照以下格式组织:

- data
	- images
		- 1.jpg
		- 2.jpg
		- ...
	- labels
		- 1.txt
		- 2.txt
		- ...

重写一下Dataset类,用于加载数据集。

class KeyPointDatasets(Dataset):
    def __init__(self, root_dir="./data", transforms=None):
        super(KeyPointDatasets, self).__init__()
        self.img_path = os.path.join(root_dir, "images")
        # self.txt_path = os.path.join(root_dir, "labels")

        self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))
        self.txt_list = [item.replace(".jpg", ".txt").replace(
            "images", "labels") for item in self.img_list]

        if transforms is not None:
            self.transforms = transforms

    def __getitem__(self, index):
        img = self.img_list[index]
        txt = self.txt_list[index]

        img = cv2.imread(img)

        if self.transforms:
            img = self.transforms(img)

        label = []

        with open(txt, "r") as f:
            for i, line in enumerate(f):
                if i == 0:
                    # 第一行
                    num_point = int(line.strip())
                else:
                    x1, y1 = [(t.strip()) for t in line.split()]
                    # range from 0 to 1
                    x1, y1 = float(x1), float(y1)

                    tmp_label = (x1, y1)
                    label.append(tmp_label)

        return img, torch.tensor(label[0])

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def collect_fn(batch):
        imgs, labels = zip(*batch)
        return torch.stack(imgs, 0), torch.stack(labels, 0)

返回的结果是图片和对应坐标位置。

2.2 网络模型

import torch
import torch.nn as nn

class KeyPointModel(nn.Module):
    def __init__(self):
        super(KeyPointModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU(True)
        self.maxpool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(12)
        self.relu2 = nn.ReLU(True)
        self.maxpool2 = nn.MaxPool2d((2, 2))

        self.gap = nn.AdaptiveMaxPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(12, 2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.gap(x)
        x = x.view(x.shape[0], -1)
        return self.classifier(x)

其结构就是卷积+pooling+卷积+pooling+global average pooling+Linear,返回长度为2的tensor。

2.3 训练

def train(model, epoch, dataloader, optimizer, criterion):
    model.train()
    for itr, (image, label) in enumerate(dataloader):
        bs = image.shape[0]
        output = model(image)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if itr % 4 == 0:
            print("epoch:%2d|step:%04d|loss:%.6f" % (epoch, itr, loss.item()/bs))
            vis.plot_many_stack({"train_loss": loss.item()*100/bs})
            
total_epoch = 300
bs = 10
########################################
transforms_all = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((360,480)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4372, 0.4372, 0.4373],
                         std=[0.2479, 0.2475, 0.2485])
])

datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)

data_loader = DataLoader(datasets, shuffle=True,
                         batch_size=bs, collate_fn=datasets.collect_fn)

model = KeyPointModel()

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# criterion = torch.nn.SmoothL1Loss()
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=30,
                                            gamma=0.1)

for epoch in range(total_epoch):
    train(model, epoch, data_loader, optimizer, criterion)
    loss = test(model, epoch, data_loader, criterion)

    if epoch % 10 == 0:
        torch.save(model.state_dict(),
                   "weights/epoch_%d_%.3f.pt" % (epoch, loss*1000))

loss部分使用Smooth L1 loss或者MSE loss均可。

MSE Loss:

\[loss(x,y)=\frac{1}{n}\sum(x_i-y_i)^2 \]

Smooth L1 Loss:

\[smooth_{L_1}(x)= \begin{cases} 0.5x^2 & if |x|<1 \\ |x|-0.5 & otherwise \end{cases} \]

2.4 测试结果

3. heatmap确定关键点

这部分代码很多参考了CenterNet,不过曾经尝试CenterNet中的loss在这个问题上收敛效果不好,所以参考了kaggle人脸关键点定位的解决方法,发现使用简单的MSELoss效果就很好。

3.1 数据加载

这部分和CenterNet构建heatmap的过程类似,不过半径的确定是人工的。因为数据集中的目标都比较小,半径的范围最大不超过半径为30个像素的圆。

class KeyPointDatasets(Dataset):
    def __init__(self, root_dir="./data", transforms=None):
        super(KeyPointDatasets, self).__init__()

        self.down_ratio = 1
        self.img_w = 480 // self.down_ratio
        self.img_h = 360 // self.down_ratio

        self.img_path = os.path.join(root_dir, "images")

        self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))
        self.txt_list = [item.replace(".jpg", ".txt").replace(
            "images", "labels") for item in self.img_list]

        if transforms is not None:
            self.transforms = transforms

    def __getitem__(self, index):
        img = self.img_list[index]
        txt = self.txt_list[index]

        img = cv2.imread(img)

        if self.transforms:
            img = self.transforms(img)

        label = []

        with open(txt, "r") as f:
            for i, line in enumerate(f):
                if i == 0:
                    # 第一行
                    num_point = int(line.strip())
                else:
                    x1, y1 = [(t.strip()) for t in line.split()]
                    # range from 0 to 1
                    x1, y1 = float(x1), float(y1)
                    cx, cy = x1 * self.img_w, y1 * self.img_h
                    heatmap = np.zeros((self.img_h, self.img_w))
                    draw_umich_gaussian(heatmap, (cx, cy), 30)
        return img, torch.tensor(heatmap).unsqueeze(0)

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def collect_fn(batch):
        imgs, labels = zip(*batch)
        return torch.stack(imgs, 0), torch.stack(labels, 0)

核心函数是draw_umich_gaussian,具体如下:

def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h  0 and min(masked_heatmap.shape) > 0:  # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
        # 将高斯分布覆盖到heatmap上,取最大,而不是叠加
    return heatmap

sigma参数直接沿用了CenterNet中的设置,没有调节这个超参数。

3.2 网络结构

网络结构参考了知乎上一个复现YOLOv3中提到的模块,Sematic Embbed Block(SEB)用于上采样部分,将来自低分辨率的特征图进行上采样,然后使用3x3卷积和1x1卷积统一通道个数,最后将低分辨率特征图和高分辨率特征图相乘得到融合结果。

class SematicEmbbedBlock(nn.Module):
    def __init__(self, high_in_plane, low_in_plane, out_plane):
        super(SematicEmbbedBlock, self).__init__()
        self.conv3x3 = nn.Conv2d(high_in_plane, out_plane, 3, 1, 1)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

        self.conv1x1 = nn.Conv2d(low_in_plane, out_plane, 1)

    def forward(self, high_x, low_x):
        high_x = self.upsample(self.conv3x3(high_x))
        low_x = self.conv1x1(low_x)
        return high_x * low_x


class KeyPointModel(nn.Module):
    """
    downsample ratio=2
    """

    def __init__(self):
        super(KeyPointModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU(True)
        self.maxpool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(12)
        self.relu2 = nn.ReLU(True)
        self.maxpool2 = nn.MaxPool2d((2, 2))

        self.conv3 = nn.Conv2d(12, 20, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(20)
        self.relu3 = nn.ReLU(True)
        self.maxpool3 = nn.MaxPool2d((2, 2))

        self.conv4 = nn.Conv2d(20, 40, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(40)
        self.relu4 = nn.ReLU(True)

        self.seb1 = SematicEmbbedBlock(40, 20, 20)
        self.seb2 = SematicEmbbedBlock(20, 12, 12)
        self.seb3 = SematicEmbbedBlock(12, 6, 6)

        self.heatmap = nn.Conv2d(6, 1, 1)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu1(x1)

        m1 = self.maxpool1(x1)

        x2 = self.conv2(m1)
        x2 = self.bn2(x2)
        x2 = self.relu2(x2)

        m2 = self.maxpool2(x2)

        x3 = self.conv3(m2)
        x3 = self.bn3(x3)
        x3 = self.relu3(x3)

        m3 = self.maxpool3(x3)

        x4 = self.conv4(m3)
        x4 = self.bn4(x4)
        x4 = self.relu4(x4)

        up1 = self.seb1(x4, x3)
        up2 = self.seb2(up1, x2)
        up3 = self.seb3(up2, x1)

        out = self.heatmap(up3)
        return out

网络模型也是自己写的小网络,用了四个卷积层,三个池化层,然后进行了三次上采样。最终输出分辨率和输入分辨率相同。

3.3 训练过程

训练过程和基于回归的方法几乎一样,代码如下:

datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)

data_loader = DataLoader(datasets, shuffle=True,
                         batch_size=bs, collate_fn=datasets.collect_fn)

model = KeyPointModel()

if torch.cuda.is_available():
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
criterion = torch.nn.MSELoss()  # compute_loss
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=30,
                                                gamma=0.1)

for epoch in range(total_epoch):
    train(model, epoch, data_loader, optimizer, criterion, scheduler)
    loss = test(model, epoch, data_loader, criterion)

    if epoch % 5 == 0:
        torch.save(model.state_dict(),
                   "weights/epoch_%d_%.3f.pt" % (epoch, loss*10000))

用的是MSELoss进行监督,训练曲线如下:

训练过程中的loss曲线

3.4 测试过程

测试过程和CenterNet的推理过程一致,也用到了3x3的maxpooling来筛选极大值点

for iter, (image, label) in enumerate(dataloader):
    # print(image.shape)
    bs = image.shape[0]
    hm = model(image)
    hm = _nms(hm)
    hm = hm.detach().numpy()

    for i in range(bs):
        hm = hm[i]
        hm = np.maximum(hm, 0)
        hm = hm/np.max(hm)
        hm = normalization(hm)
        hm = np.uint8(255 * hm)
        hm = hm[0]
        # heatmap = torch.sigmoid(heatmap)
        # hm = cv2.cvtColor(hm, cv2.COLOR_RGB2BGR)
        hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
        cv2.imwrite("./test_output/output_%d_%d.jpg" % (iter, i), hm)
        cv2.waitKey(0)

以上的nms和topk代码都在CenterNet系列最后一篇讲过了。这里直接对模型输出结果使用nms,然后进行可视化,结果如下:

放大结果

上图中白色的点就是目标位置,为了更形象的查看结果,detect.py部分负责可视化。

3.5 可视化

可视化的问题经常遇见,比如CAM、Grad CAM等可视化特征图的时候就会碰到。以下是可视化的一个简单的方法(参考了CSDN的一位博主的方案,具体链接因太过久远找不到了)。

可视化流程

具体实现代码如下:

def normalization(data):
    _range = np.max(data) - np.min(data)
    return (data - np.min(data)) / _range

heatmap = model(img_tensor_list)
heatmap = heatmap.squeeze().cpu()

for i in range(bs):
    img_path = img_list[i]
    img = cv2.imread(img_path)
    img = cv2.resize(img, (480, 360))
    single_map = heatmap[i]
    hm = single_map.detach().numpy()
    hm = np.maximum(hm, 0)
    hm = hm/np.max(hm)
    hm = normalization(hm)
    hm = np.uint8(255 * hm)
    hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
    hm = cv2.resize(hm, (480, 360))
    superimposed_img = hm * 0.2 + img
    coord_x, coord_y = landmark_coord[i]
    cv2.circle(superimposed_img, (int(coord_x), int(coord_y)), 2, (0, 0, 0), thickness=-1)
    cv2.imwrite("./output2/%s_out.jpg" % (img_name_list[i]), superimposed_img)

注意通过处理以后的hm和原图叠加的时候0.2只是一个参考值,这个值既不会影响原图显示又能将heatmap中重点关注的位置可视化出来。

结果如下:

可视化结果

可以看到,定位结果要比回归更准一些,图中黑色点是获取到最终坐标的位置,几乎和目标是重叠的状态,效果比较理想。

4. 总结

笔者做这个小项目初心是想搞清楚如何用关键点进行定位的,关键点被用在很多领域比如人脸关键点定位、车牌定位、人体姿态检测、目标检测等等领域。当时用小武的数据的时候,发现这个数据集的特点就是目标很小,比较适合用关键点来做。之后又开始陆陆续续看CenterNet源码,借鉴了其中很多代码,这才完成了这个小项目。

由于本人水平有限,可能使用heatmap进行关键点定位的方式有些地方并不合理,是东拼西凑而成的,如果有建议可以添加笔者微信topeijie。


推荐阅读
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • Windows7 64位系统安装PLSQL Developer的步骤和注意事项
    本文介绍了在Windows7 64位系统上安装PLSQL Developer的步骤和注意事项。首先下载并安装PLSQL Developer,注意不要安装在默认目录下。然后下载Windows 32位的oracle instant client,并解压到指定路径。最后,按照自己的喜好对解压后的文件进行命名和压缩。 ... [详细]
  • 突破MIUI14限制,自定义胶囊图标、大图标样式,支持任意APP
    本文介绍了如何突破MIUI14的限制,实现自定义胶囊图标和大图标样式,并支持任意APP。需要一定的动手能力和主题设计师账号权限或者会主题pojie。详细步骤包括应用包名获取、素材制作和封包获取等。 ... [详细]
  • 模板引擎StringTemplate的使用方法和特点
    本文介绍了模板引擎StringTemplate的使用方法和特点,包括强制Model和View的分离、Lazy-Evaluation、Recursive enable等。同时,还介绍了StringTemplate语法中的属性和普通字符的使用方法,并提供了向模板填充属性的示例代码。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
  • HashMap的相关问题及其底层数据结构和操作流程
    本文介绍了关于HashMap的相关问题,包括其底层数据结构、JDK1.7和JDK1.8的差异、红黑树的使用、扩容和树化的条件、退化为链表的情况、索引的计算方法、hashcode和hash()方法的作用、数组容量的选择、Put方法的流程以及并发问题下的操作。文章还提到了扩容死链和数据错乱的问题,并探讨了key的设计要求。对于对Java面试中的HashMap问题感兴趣的读者,本文将为您提供一些有用的技术和经验。 ... [详细]
  • 本文介绍了如何使用n3-charts绘制以日期为x轴的数据,并提供了相应的代码示例。通过设置x轴的类型为日期,可以实现对日期数据的正确显示和处理。同时,还介绍了如何设置y轴的类型和其他相关参数。通过本文的学习,读者可以掌握使用n3-charts绘制日期数据的方法。 ... [详细]
  • PatchODAX8: ... [详细]
  • IhaveconfiguredanactionforaremotenotificationwhenitarrivestomyiOsapp.Iwanttwodiff ... [详细]
  • 本文讨论了在Windows 8上安装gvim中插件时出现的错误加载问题。作者将EasyMotion插件放在了正确的位置,但加载时却出现了错误。作者提供了下载链接和之前放置插件的位置,并列出了出现的错误信息。 ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
  • Sleuth+zipkin链路追踪SpringCloud微服务的解决方案
    在庞大的微服务群中,随着业务扩展,微服务个数增多,系统调用链路复杂化。Sleuth+zipkin是解决SpringCloud微服务定位和追踪的方案。通过TraceId将不同服务调用的日志串联起来,实现请求链路跟踪。通过Feign调用和Request传递TraceId,将整个调用链路的服务日志归组合并,提供定位和追踪的功能。 ... [详细]
  • 1关于字符串相邻的两个或多个字符串字面值(引号引起来的字符)将会自动连接到一起:str_catpython!str_cat输出:python!把很长 ... [详细]
author-avatar
陈亮跑街向_114
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有