siamfc-pytorch代码讲解(三):demo&track

我之前的两篇博客:

今天主要看一下demo的部分,也就是涉及到测试tracking的部分。
直接上代码:

一、demo.py

from __future__ import absolute_import

import os
import glob
import numpy as np

from siamfc import TrackerSiamFC


if __name__ == '__main__':
    seq_dir = os.path.expanduser('D:\\OTB\\Crossing\\')
    img_files = sorted(glob.glob(seq_dir + 'img/*.jpg'))
    anno = np.loadtxt(seq_dir + 'groundtruth_rect.txt', delimiter=',')
    
    net_path = 'pretrained/siamfc_alexnet_e50.pth'
    tracker = TrackerSiamFC(net_path=net_path)
    tracker.track(img_files, anno[0], visualize=True)
  1. 上面的第11行路径自己该,我这次是windows测试的,所以这样写了(看着有点不规范)。
  2. 13行我多加了一点代码:, delimiter=',',不加这个会报这样的错:

ValueError: could not convert string to float

  1. 下面几行就是用训练好的siamfc_alexnet_e50.pth模型进行tracking,给定的是img_files:视频序列;anno[0]就是第一帧中的ground truth bbox。

二、track

现在就来看一下类TrackerSiamFC下的track方法。这个函数的作用就是传入video sequence和first frame中的ground truth bbox,然后通过模型,得到后续帧的目标位置,可以看到主要有两个函数实现:initupdate,这也是继承Tracker需要重写的两个方法:

  • init:就是传入第一帧的标签和图片,初始化一些参数,计算一些之后搜索区域的中心等等
  • update:就是传入后续帧,然后根据SiamFC网络返回目标的box坐标,之后就是根据这些坐标来show,起到一个demo的效果。
def track(self, img_files, box, visualize=False):
    frame_num = len(img_files)
    boxes = np.zeros((frame_num, 4))
    boxes[0] = box
    times = np.zeros(frame_num)

    for f, img_file in enumerate(img_files):
        img = ops.read_image(img_file)

        begin = time.time()
        if f == 0:
            self.init(img, box)
        else:
            boxes[f, :] = self.update(img)
        times[f] = time.time() - begin

        if visualize:
            ops.show_image(img, boxes[f, :])

    return boxes, times

2.1 init(self, img, box)

\color{#FF0000}{我强烈建议可以用两个设备,一个看代码,一个用来看我下边的长图,对照着分析}

def init(self, img, box):
    # set to evaluation mode
    self.net.eval()

    # convert box to 0-indexed and center based [y, x, h, w]
    box = np.array([
        box[1] - 1 + (box[3] - 1) / 2,
        box[0] - 1 + (box[2] - 1) / 2,
        box[3], box[2]], dtype=np.float32)
    self.center, self.target_sz = box[:2], box[2:]

    # create hanning window
    self.upscale_sz = self.cfg.response_up * self.cfg.response_sz  # 272
    self.hann_window = np.outer(
        np.hanning(self.upscale_sz),
        np.hanning(self.upscale_sz))
    self.hann_window /= self.hann_window.sum()

    # search scale factors
    self.scale_factors = self.cfg.scale_step ** np.linspace(
        -(self.cfg.scale_num // 2),
        self.cfg.scale_num // 2, self.cfg.scale_num)  # 1.0375**(-2,-0.5,1)

    # exemplar and search sizes
    context = self.cfg.context * np.sum(self.target_sz)
    self.z_sz = np.sqrt(np.prod(self.target_sz + context))
    self.x_sz = self.z_sz * \
                self.cfg.instance_sz / self.cfg.exemplar_sz

    # exemplar image
    self.avg_color = np.mean(img, axis=(0, 1))
    z = ops.crop_and_resize(
        img, self.center, self.z_sz,
        out_size=self.cfg.exemplar_sz,
        border_value=self.avg_color)

    # print(z.shape) # [127,127,3]
    # exemplar features [H,W,C]->[C,H,W]
    z = torch.from_numpy(z).to(
        self.device).permute(2, 0, 1).unsqueeze(0).float()
    self.kernel = self.net.backbone(z)  # torch.Size([1, 256, 6, 6])
  1. 一开始,就是把输入的ltwh格式的box转变为[y, x, h, w]格式的,这个看过我第二篇的就很清楚了,然后记录bbox的中心和宽高size信息,以备后用(如下图黑色字体表示的)
  2. 这里计算了响应图上采样后的大小upscale_sz,因为论文中有这样一句话:
    We found that upsampling the score map using bicubic interpolation, from 17 × 17 to 272 × 272, results in more accurate localization since the original map is relatively coarse.也就是17×16=272
  3. 然后创建了一个汉宁窗(hanning window),也叫余弦窗【可以看这里】,论文中说是增加惩罚:Online, ... and a cosine window is added to the score map to penalize large displacements
  4. 论文中提到两个变体,一个是5个尺度的,一个是3个尺度的(这里就是),5个尺度依次是1.025^{[-2,-1,0,1,2]},代码中3个尺度是1.0375^{[-2,-0.5,1]}
  5. context 就是边界的语义信息,为了计算z_szx_sz,最后送入crop_and_resize去抠出搜索区域【我第二篇博客有讲这个函数】,<font color=blue> z_sz大小可以看下面蓝色方形框</font>,<font color=#FF1493> x_sz大小可以看下面粉色方形框</font>,最后抠出z_sz大小的作为exemplar image,并送入backbone,输出embedding,也可以看作是一个固定的互相关kernel,为了之后的相似度计算用,如论文中提到:We found that updating (the feature representation of) the exemplar online through simple strategies, such as linear interpolation, does not gain much performance and thus we keep it fixed
    关于一些tensor的shape可以看代码里的注释,下面是我当时的笔记:
    博客笔记1

2.2 update(self, img)

\color{#FF0000}{我强烈建议可以用两个设备,一个看代码,一个用来看我下边的长图,对照着分析}

def update(self, img):
    # set to evaluation mode
    self.net.eval()

    # search images
    x = [ops.crop_and_resize(
        img, self.center, self.x_sz * f,
        out_size=self.cfg.instance_sz,
        border_value=self.avg_color) for f in self.scale_factors]
    x = np.stack(x, axis=0)  # [3, 255, 255, 3]
    x = torch.from_numpy(x).to(
        self.device).permute(0, 3, 1, 2).float()

    # responses
    x = self.net.backbone(x)  # [3, 256, 22, 22]
    responses = self.net.head(self.kernel, x)  # [3, 1, 17, 17]
    responses = responses.squeeze(1).cpu().numpy()  # [3, 17, 17]

    # upsample responses and penalize scale changes
    responses = np.stack([cv2.resize(
        u, (self.upscale_sz, self.upscale_sz),
        interpolation=cv2.INTER_CUBIC)
        for u in responses])  # [3, 272, 272]
    responses[:self.cfg.scale_num // 2] *= self.cfg.scale_penalty
    responses[self.cfg.scale_num // 2 + 1:] *= self.cfg.scale_penalty

    # peak scale
    scale_id = np.argmax(np.amax(responses, axis=(1, 2)))  # which channel is max

    # peak location
    response = responses[scale_id]
    response -= response.min()
    response /= response.sum() + 1e-16
    response = (1 - self.cfg.window_influence) * response + \
               self.cfg.window_influence * self.hann_window
    loc = np.unravel_index(response.argmax(), response.shape)

    # locate target center: disp stand for displacement
    disp_in_response = np.array(loc) - (self.upscale_sz - 1) / 2
    disp_in_instance = disp_in_response * \
                       self.cfg.total_stride / self.cfg.response_up
    disp_in_image = disp_in_instance * self.x_sz * \
                    self.scale_factors[scale_id] / self.cfg.instance_sz
    self.center += disp_in_image

    # update target size
    scale = (1 - self.cfg.scale_lr) * 1.0 + \
            self.cfg.scale_lr * self.scale_factors[scale_id]
    self.target_sz *= scale
    self.z_sz *= scale
    self.x_sz *= scale

    # return 1-indexed and left-top based bounding box
    box = np.array([
        self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
        self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
        self.target_sz[1], self.target_sz[0]])

    return box
  1. update顾名思义就是对后续的帧更新出bbox来,因为是tracking phase,所以把模型设成eval mode。然后在这新的帧里抠出search images,根据之前init里生成的3个尺度,然后resize成255×255,特别一点,我们可以发现search images在resize之前的边长x_sz大约为target_sz的4倍,这也印证了论文中的:we only search for the object within a region of approximately four times its previous size
  2. 然后将这3个尺度的patch(也就是3个搜索范围)拼接一起,送入backbone,生成emdding后与之前的kernel进行互相关,得到score map,这些tensor的shape代码里都有标注,得到3个17×17的responses,然后对每一个response进行上采样到272×272
  3. 上面的24,25行就是对尺度进行惩罚,我是这样理解的,因为中间的尺度肯定是接近于1,其他两边的尺度不是缩一点就是放大一点,所以给以惩罚,如论文中说:Any change in scale is penalized
  4. 之后就选出这3个通道里面最大的那个,并就行归一化和余弦窗惩罚,然后通过numpy.unravel_index找到一张response上峰值点(peak location)【关于这个函数可以看这里
  5. 接下来的问题就是:在response图中找到峰值点,那这在原图img中在哪里呢?所以我们要计算位移(displacement),因为我们原本都是以目标为中心的,认为最大峰值点应该在response的中心,所以39行就是峰值点和response中心的位移。
  6. 因为之前在img上crop下一块instance patch,然后resize,然后送入CNN的backbone,然后score map又进行上采样成response,所以要根据这过程,逆回去计算对应在img上的位移,所以上面的39-43行就是在做这件事,也可以看下面的图
  7. 根据disp_in_image修正center,然后update target size,因为论文有一句:update the scale by linear interpolation with a factor of 0.35 to provide damping,但是似乎参数不太对得上,线性插值可以看下面<font color=#00BFFF>蓝色</font>的图,因为更新后的scale还是很接近1,所以bbox区域不会变化很大
  8. 最后根据ops.show_image输入的需要,又得把bbox格式改回ltwh的格式
    博客笔记2

三、checkpoint and demo

我的模型存在这里,但是只训练了GOT-10k的前500个序列,但感觉效果也还行:

Crossing-demo

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,185评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,445评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,684评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,564评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,681评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,874评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,025评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,761评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,217评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,545评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,694评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,351评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,988评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,778评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,007评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,427评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,580评论 2 349

推荐阅读更多精彩内容