[Class Notes] Summary of the Advanced Driver Assistance System (ADAS) Hands-On Project

Posted on 2019-9-4 20:36:38
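Overview of the ADAS business scenario
An Advanced Driver Assistance System (ADAS) is an active safety technology. It uses the various sensors installed on a vehicle to collect environmental data from inside and outside the vehicle in real time, and then identifies, detects, and tracks static and dynamic objects. This lets the driver notice a possible hazard as early as possible, which draws attention to it and improves safety.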


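ADAS research directions:
Navigation and real-time traffic system (TMC)
Adaptive light control
Electronic police system
Pedestrian protection system
Internet of Vehicles
Automatic parking system
Adaptive cruise control (ACC)
Traffic sign recognition
Blind spot detection
Lane departure warning system
Driver fatigue detection
Lane keeping system
Hill descent control system
Collision avoidance or pre-crash system
Electric vehicle warning system
Night vision system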


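Main task: detect motor vehicles, non-motorized vehicles, pedestrians, and traffic signs in in-vehicle video data.
This is a standard object detection problem.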


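How to judge whether an algorithm performs well:
1. Detection rate and false alarm rate
Each ground-truth label may be matched by at most one detection.
Duplicate detections of the same label count as false detections.
2. AP and mAP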
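The two criteria above can be sketched in a few lines. This is a minimal illustration, not the official KITTI or VOC evaluation code: the IoU threshold of 0.5, the greedy matching order, and the function names `iou`, `match_detections`, and `average_precision` are assumptions made for the example. It enforces the "one detection per label" rule, so a duplicate detection becomes a false positive, and then integrates the precision/recall curve to get AP for one class.

```python
import numpy as np


def iou(box_a, box_b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / float(area_a + area_b - inter)


def match_detections(gt_boxes, det_boxes, det_scores, iou_thresh=0.5):
    """Return one TP (True) / FP (False) flag per detection, highest score first."""
    order = np.argsort(-np.asarray(det_scores, dtype=float))
    used = [False] * len(gt_boxes)
    flags = []
    for d in order:
        best_iou, best_gt = 0.0, -1
        for g, gt in enumerate(gt_boxes):
            overlap = iou(det_boxes[d], gt)
            if overlap > best_iou:
                best_iou, best_gt = overlap, g
        if best_iou >= iou_thresh and not used[best_gt]:
            used[best_gt] = True      # this label is now taken
            flags.append(True)
        else:
            flags.append(False)       # no match, or a duplicate detection
    return flags


def average_precision(flags, num_gt):
    """Area under the precision/recall curve (all-point interpolation)."""
    tp = np.cumsum(flags).astype(float)
    fp = np.cumsum([not f for f in flags]).astype(float)
    recall = tp / float(num_gt)
    precision = tp / np.maximum(tp + fp, 1.0)
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    # Make precision monotonically non-increasing from right to left
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))


# Two labels, three detections; the second detection duplicates the first label
gt = [(0, 0, 10, 10), (20, 20, 30, 30)]
dets = [(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
flags = match_detections(gt, dets, scores)
ap = average_precision(flags, num_gt=len(gt))
print(flags, ap)
```

mAP is then the mean of the per-class AP values. The detection rate at a given score threshold is the number of true positives divided by the number of labels.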


Datasets for ADAS scenarios:
KITTI dataset, MOT dataset, Berkeley large-scale autonomous driving video dataset


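Difficulties in detecting motor vehicles, non-motorized vehicles, and pedestrians (outdoors):
1. Detection on overcast days, in rain, and at night
2. Detection in crowded scenes
3. Detection difficulties caused by the non-rigid motion of pedestrians (changing poses)
4. Small-object detection
5. Occlusion, and so on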


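KITTI dataset (illustrated with figures)
Download link: http://www.cvlibs.net/datasets/k ... hp?obj_benchmark=2d
The KITTI dataset was created jointly by the Karlsruhe Institute of Technology in Germany and the Toyota Technological Institute at Chicago.
It is a well-known international benchmark dataset for evaluating computer vision algorithms in autonomous driving scenarios.


The dataset is used to evaluate computer vision techniques such as stereo, optical flow, visual odometry, 3D object detection, and 3D tracking in an in-vehicle setting.
KITTI contains real image data captured in urban, rural, and highway scenes. A single image holds up to 15 cars and 30 pedestrians, with varying degrees of occlusion and truncation.
The whole dataset consists of 389 stereo and optical flow image pairs, 39.2 km of visual odometry sequences, and images with more than 200k 3D-annotated objects, sampled and synchronized at 10 Hz.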

List all object categories that appear in the KITTI label files:
# -*- coding: utf-8 -*-
__author__ = u'东方耀 WeChat: dfy_88888'
__date__ = '2019/8/9 5:19 PM'
__product__ = 'PyCharm'
__filename__ = 'demo_category'

import glob

list_anno_files = glob.glob('/home/dfy888/DataSets/Kitti_voc/training/label_2/*')

# Number of label files: 7481
print(len(list_anno_files))

category_list = []

for file_path in list_anno_files:
    with open(file_path) as f:
        # Each line describes one object; the first field is its category
        for anno_item in f:
            category_list.append(anno_item.split(' ')[0])

print('All object categories in the KITTI label files:')
# ['Cyclist', 'Van', 'Tram', 'Car', 'Misc', 'Pedestrian', 'Truck', 'Person_sitting', 'DontCare']
# Misc and DontCare are dropped later when the data is packaged;
# they are not written to the VOC-format XML files
print(set(category_list))
print(len(category_list))
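Besides the set of category names, the number of boxes per category is often useful, because it shows how unbalanced the classes are. The following is a small sketch under stated assumptions: `count_categories` is a hypothetical helper, and the three sample lines are made-up values laid out in the 15-field KITTI label format rather than real annotations.

```python
from collections import Counter


def count_categories(label_lines):
    """Count how often each category (first field of a label line) occurs."""
    counts = Counter()
    for line in label_lines:
        fields = line.split()
        if fields:  # skip blank lines
            counts[fields[0]] += 1
    return counts


# Hypothetical label lines: type, truncated, occluded, alpha, bbox (4 values),
# dimensions (3 values), location (3 values), rotation_y
sample_lines = [
    'Car 0.00 0 -1.57 100.0 120.0 200.0 220.0 1.5 1.6 3.9 1.0 1.7 20.0 -1.5',
    'Car 0.00 1 -1.50 300.0 150.0 380.0 210.0 1.4 1.6 3.7 5.0 1.7 30.0 -1.4',
    'DontCare -1 -1 -10 500.0 160.0 520.0 180.0 -1 -1 -1 -1000 -1000 -1000 -10',
]
counts = count_categories(sample_lines)
print(counts.most_common())
```

To run it on the real data, pass the lines of every file under label_2 instead of `sample_lines`.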


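Converting the KITTI dataset to a VOC-format dataset. The other annotation fields, truncation and occlusion, are also worth considering: for example, heavily occluded boxes can be filtered out so that the samples are of better quality.
The trade-off is that the model will then tend to miss heavily occluded objects at prediction time, so you have to weigh this yourself.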
# -*- coding: utf-8 -*-
__author__ = u'东方耀 WeChat: dfy_88888'
__date__ = '2019/7/15 3:23 PM'
__product__ = 'PyCharm'
__filename__ = 'kitti2voc'

import cv2
import glob
from xml.dom.minidom import Document

list_anno_files = glob.glob('/home/dfy888/DataSets/Kitti_voc/training/label_2/*')


def writexml(filename, saveimg, bboxes, xmlpath, typename):
    """
    Write a generic VOC-format XML annotation file.
    :param filename: image file name
    :param saveimg: image object loaded with cv2
    :param bboxes: list of bounding boxes, each (xmin, ymin, xmax, ymax)
    :param xmlpath: path of the XML file to write
    :param typename: list of category names, one per bounding box
    :return:
    """
    doc = Document()
    # Root node
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    folder = doc.createElement('folder')
    # Note: this is the name of the folder that holds the VOC-format data
    folder_name = doc.createTextNode('Kitti_voc')
    folder.appendChild(folder_name)
    annotation.appendChild(folder)

    filenamenode = doc.createElement('filename')
    filename_name = doc.createTextNode(filename)
    filenamenode.appendChild(filename_name)
    annotation.appendChild(filenamenode)

    source = doc.createElement('source')
    annotation.appendChild(source)

    database = doc.createElement('database')
    database.appendChild(doc.createTextNode('Kitti Database'))
    source.appendChild(database)

    annotation_s = doc.createElement('annotation')
    annotation_s.appendChild(doc.createTextNode('PASCAL VOC2007'))
    source.appendChild(annotation_s)

    image = doc.createElement('image')
    image.appendChild(doc.createTextNode('flickr'))
    source.appendChild(image)

    flickrid = doc.createElement('flickrid')
    flickrid.appendChild(doc.createTextNode('-1'))
    source.appendChild(flickrid)

    owner = doc.createElement('owner')
    annotation.appendChild(owner)

    flickrid_o = doc.createElement('flickrid')
    flickrid_o.appendChild(doc.createTextNode('dfy_88888'))
    owner.appendChild(flickrid_o)

    name_o = doc.createElement('name')
    name_o.appendChild(doc.createTextNode('dfy_88888'))
    owner.appendChild(name_o)

    size = doc.createElement('size')
    annotation.appendChild(size)

    width = doc.createElement('width')
    width.appendChild(doc.createTextNode(str(saveimg.shape[1])))
    height = doc.createElement('height')
    height.appendChild(doc.createTextNode(str(saveimg.shape[0])))
    depth = doc.createElement('depth')
    depth.appendChild(doc.createTextNode(str(saveimg.shape[2])))
    size.appendChild(width)
    size.appendChild(height)
    size.appendChild(depth)

    segmented = doc.createElement('segmented')
    segmented.appendChild(doc.createTextNode('0'))
    annotation.appendChild(segmented)

    for i in range(len(bboxes)):
        # bbox is a 4-tuple: (xmin, ymin, xmax, ymax)
        bbox = bboxes[i]
        objects = doc.createElement('object')
        annotation.appendChild(objects)

        object_name = doc.createElement('name')
        # Not only faces: in the ADAS scenario the categories are
        # pedestrians, vehicles, and traffic signs
        object_name.appendChild(doc.createTextNode(typename[i]))
        objects.appendChild(object_name)

        pose = doc.createElement('pose')
        pose.appendChild(doc.createTextNode('Unspecified'))
        objects.appendChild(pose)

        # The truncated flag is hard-coded; the KITTI truncation value
        # is not carried over
        truncated = doc.createElement('truncated')
        truncated.appendChild(doc.createTextNode('1'))
        objects.appendChild(truncated)

        difficult = doc.createElement('difficult')
        difficult.appendChild(doc.createTextNode('0'))
        objects.appendChild(difficult)

        bndbox = doc.createElement('bndbox')
        objects.appendChild(bndbox)
        # xmin, ymin: top-left corner of the box
        xmin = doc.createElement('xmin')
        xmin.appendChild(doc.createTextNode(str(bbox[0])))
        bndbox.appendChild(xmin)
        ymin = doc.createElement('ymin')
        ymin.appendChild(doc.createTextNode(str(bbox[1])))
        bndbox.appendChild(ymin)
        # xmax, ymax: bottom-right corner of the box
        xmax = doc.createElement('xmax')
        xmax.appendChild(doc.createTextNode(str(bbox[2])))
        bndbox.appendChild(xmax)
        ymax = doc.createElement('ymax')
        ymax.appendChild(doc.createTextNode(str(bbox[3])))
        bndbox.appendChild(ymax)

    with open(xmlpath, 'w') as f:
        f.write(doc.toprettyxml(indent=''))


# Convert the dataset (KITTI ---> VOC)

trainval = open('/home/dfy888/DataSets/Kitti_voc/ImageSets/Main/trainval.txt', 'w')
train = open('/home/dfy888/DataSets/Kitti_voc/ImageSets/Main/train.txt', 'w')
val = open('/home/dfy888/DataSets/Kitti_voc/ImageSets/Main/val.txt', 'w')
test = open('/home/dfy888/DataSets/Kitti_voc/ImageSets/Main/test.txt', 'w')

index = 0
# 7481 label files
for file_path in list_anno_files:
    with open(file_path) as f:
        # Each annotation file is a plain-text file, one object per line
        anno_infos = f.readlines()
        bboxes = []
        typename = []
        for anno_item in anno_infos:
            # Parse the fields of one line
            anno_item_infos = anno_item.split()
            if anno_item_infos[0] == 'Misc' or anno_item_infos[0] == 'DontCare':
                # Drop Misc and DontCare so the model is easier to train
                continue
            else:
                typename.append(anno_item_infos[0])
                # Fields 4..7 are xmin, ymin, xmax, ymax
                bbox = (int(float(anno_item_infos[4])), int(float(anno_item_infos[5])),
                        int(float(anno_item_infos[6])), int(float(anno_item_infos[7])))
                bboxes.append(bbox)

        filename = file_path.split('/')[-1].replace('txt', 'png')
        xmlpath = '/home/dfy888/DataSets/Kitti_voc/Annotations/' + filename.replace('png', 'xml')

        img_path = '/home/dfy888/DataSets/Kitti_voc/JPEGImages/' + filename
        saveimg = cv2.imread(img_path)
        writexml(filename, saveimg, bboxes, xmlpath, typename)

        # Image set type: trainval, val, test, or train
        # trainval 90%, test 10%
        # train 70%, val 20%
        if index > len(list_anno_files) * 0.9:
            test.write(filename.replace('.png', '\n'))
        else:
            trainval.write(filename.replace('.png', '\n'))
            if index > len(list_anno_files) * 0.7:
                val.write(filename.replace('.png', '\n'))
            else:
                train.write(filename.replace('.png', '\n'))

        print('Processing: ' + str(index))
        index += 1

train.close()
trainval.close()
test.close()
val.close()
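The conversion script only filters by category. The optional filtering by truncation and occlusion mentioned before the script could look roughly like the sketch below. It is an illustration under assumptions: `parse_kitti_line` and `keep_object` are hypothetical helpers, the thresholds `max_occluded=1` and `max_truncated=0.5` are arbitrary example values, and the sample lines are made up. In the KITTI label format, field 1 is the truncation (0.0 to 1.0) and field 2 is the occlusion state (0 fully visible, 1 partly occluded, 2 largely occluded, 3 unknown).

```python
def parse_kitti_line(line):
    """Parse the fields of one KITTI label line that matter for filtering."""
    fields = line.split()
    return {
        'type': fields[0],
        'truncated': float(fields[1]),
        'occluded': int(fields[2]),
        # xmin, ymin, xmax, ymax, truncated to integers as in the script above
        'bbox': tuple(int(float(v)) for v in fields[4:8]),
    }


def keep_object(obj, max_occluded=1, max_truncated=0.5,
                ignored_types=('Misc', 'DontCare')):
    """Decide whether an object should be written to the VOC annotation."""
    if obj['type'] in ignored_types:
        return False
    if obj['occluded'] > max_occluded:      # largely occluded or unknown
        return False
    if obj['truncated'] > max_truncated:    # mostly outside the image
        return False
    return True


# Hypothetical label lines in the 15-field KITTI format
lines = [
    'Car 0.00 0 -1.57 100.9 120.2 200.5 220.7 1.5 1.6 3.9 1.0 1.7 20.0 -1.5',
    'Pedestrian 0.00 2 0.20 300.0 150.0 330.0 230.0 1.8 0.6 0.8 4.0 1.6 15.0 0.3',
    'DontCare -1 -1 -10 500.0 160.0 520.0 180.0 -1 -1 -1 -1000 -1000 -1000 -10',
]
objects = [parse_kitti_line(l) for l in lines]
kept = [o for o in objects if keep_object(o)]
print([(o['type'], o['bbox']) for o in kept])
```

Tightening the thresholds gives cleaner training samples at the cost of more missed occluded objects at prediction time, which is the trade-off described above.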

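Setting up the environment for the Faster RCNN detection model:
https://github.com/rbgirshick/py-faster-rcnn
git clone --recursive https://github.com/rbgirshick/py-faster-rcnn.git
The --recursive flag makes sure that the caffe-fast-rcnn submodule inside the project is downloaded as well. There is one pitfall here: the bundled Caffe version is too old.
For the fix, see: http://www.ai111.vip/thread-788-1-1.html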

Testing the Faster RCNN detection model: python tools/demo_detector_dfy.py
# -*- coding: utf-8 -*-
__author__ = u'东方耀 WeChat: dfy_88888'
__date__ = '2019/9/8 10:48 AM'
__product__ = 'PyCharm'
__filename__ = 'demo_detector_dfy.py'
"""
Demo script showing detections in sample images.

Modified by 东方耀: object detection demo using the pretrained faster_rcnn VGG16 model.
Uses Python 2 syntax (print statements, xrange, iteritems).
"""

import _init_paths
from fast_rcnn.config import cfg
from nms.gpu_nms import gpu_nms
from nms.cpu_nms import cpu_nms
import time
import matplotlib.pyplot as plt
import numpy as np
import caffe
import os
import cv2

# Pascal VOC has 21 classes (including background)
CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')


class Timer(object):
    """A simple timer."""

    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        # using time.time instead of time.clock because time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        else:
            return self.diff


def nms(dets, thresh, force_cpu=False):
    """Dispatch to either CPU or GPU NMS implementations."""

    if dets.shape[0] == 0:
        return []
    if cfg.USE_GPU_NMS and not force_cpu:
        return gpu_nms(dets, thresh, device_id=cfg.GPU_ID)
    else:
        return cpu_nms(dets, thresh)


def bbox_transform_inv(boxes, deltas):
    """Apply predicted regression deltas to boxes given as (x1, y1, x2, y2)."""
    if boxes.shape[0] == 0:
        return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)

    boxes = boxes.astype(deltas.dtype, copy=False)

    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    dx = deltas[:, 0::4]
    dy = deltas[:, 1::4]
    dw = deltas[:, 2::4]
    dh = deltas[:, 3::4]

    pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
    pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
    pred_w = np.exp(dw) * widths[:, np.newaxis]
    pred_h = np.exp(dh) * heights[:, np.newaxis]

    pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
    # x1
    pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
    # y1
    pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
    # x2
    pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w
    # y2
    pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h

    return pred_boxes


def clip_boxes(boxes, im_shape):
    """
    Clip boxes to image boundaries.
    """

    # x1 >= 0
    boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
    # y1 >= 0
    boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
    # x2 < im_shape[1]
    boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
    # y2 < im_shape[0]
    boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
    return boxes


def im_list_to_blob(ims):
    """Convert a list of images into a network input.
    ims: a list of images (already resized with cv2.resize)

    Assumes images are already prepared (means subtracted, BGR order, ...).
    """
    # max_shape: the largest size along each axis over all images
    max_shape = np.array([im.shape for im in ims]).max(axis=0)
    num_images = len(ims)

    # Allocate a zero-filled 4-D array of shape (N, H, W, 3)
    blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),
                    dtype=np.float32)

    for i in xrange(num_images):
        im = ims[i]
        # Copy the actual image data into the array
        blob[i, 0:im.shape[0], 0:im.shape[1], :] = im
    # Move channels (axis 3) to axis 1
    # Axis order will become: (batch elem, channel, height, width)
    channel_swap = (0, 3, 1, 2)
    blob = blob.transpose(channel_swap)
    return blob


def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    # BGR -> RGB for matplotlib
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()


def _get_image_blob(im):
    """Converts an image into a network input.

    Arguments:
        im (ndarray): a color image in BGR order

    Returns:
        blob (ndarray): a data blob holding an image pyramid
        im_scale_factors (list): list of image scales (relative to im) used
            in the image pyramid
    """
    im_orig = im.astype(np.float32, copy=True)
    # Subtract the pixel means
    im_orig -= cfg.PIXEL_MEANS
    # Shape of the original image
    im_shape = im_orig.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    processed_ims = []
    im_scale_factors = []
    # cfg.TEST.SCALES = (600,)
    for target_size in cfg.TEST.SCALES:
        im_scale = float(target_size) / float(im_size_min)
        # Prevent the biggest axis from being more than MAX_SIZE
        # cfg.TEST.MAX_SIZE = 1000
        if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
            # Make sure the longer side after scaling (width or height)
            # does not exceed the configured maximum size
            im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
        # im_scale now satisfies both constraints
        im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
                        interpolation=cv2.INTER_LINEAR)
        im_scale_factors.append(im_scale)
        processed_ims.append(im)

    # Create a blob to hold the input images
    blob = im_list_to_blob(processed_ims)

    return blob, np.array(im_scale_factors)


def _project_im_rois(im_rois, scales):
    """Project image RoIs into the image pyramid built by _get_image_blob.

    Arguments:
        im_rois (ndarray): R x 4 matrix of RoIs in original image coordinates
        scales (list): scale factors as returned by _get_image_blob

    Returns:
        rois (ndarray): R x 4 matrix of projected RoI coordinates
        levels (list): image pyramid levels used by each projected RoI
    """
    im_rois = im_rois.astype(np.float, copy=False)

    if len(scales) > 1:
        widths = im_rois[:, 2] - im_rois[:, 0] + 1
        heights = im_rois[:, 3] - im_rois[:, 1] + 1

        areas = widths * heights
        scaled_areas = areas[:, np.newaxis] * (scales[np.newaxis, :] ** 2)
        diff_areas = np.abs(scaled_areas - 224 * 224)
        levels = diff_areas.argmin(axis=1)[:, np.newaxis]
    else:
        levels = np.zeros((im_rois.shape[0], 1), dtype=np.int)

    rois = im_rois * scales[levels]

    return rois, levels


def _get_rois_blob(im_rois, im_scale_factors):
    """Converts RoIs into network inputs.

    Arguments:
        im_rois (ndarray): R x 4 matrix of RoIs in original image coordinates
        im_scale_factors (list): scale factors as returned by _get_image_blob

    Returns:
        blob (ndarray): R x 5 matrix of RoIs in the image pyramid
    """
    rois, levels = _project_im_rois(im_rois, im_scale_factors)
    rois_blob = np.hstack((levels, rois))
    return rois_blob.astype(np.float32, copy=False)


def _get_blobs(im, rois):
    """Convert an image and RoIs within that image into network inputs."""
    blobs = {'data': None, 'rois': None}

    blobs['data'], im_scale_factors = _get_image_blob(im)
    if not cfg.TEST.HAS_RPN:
        print >> dfy_log_file_writer, 'No RPN: rois is not None'
        blobs['rois'] = _get_rois_blob(rois, im_scale_factors)

    print >> dfy_log_file_writer, 'blobs:', blobs.keys(), blobs['data'].shape, im_scale_factors
    return blobs, im_scale_factors


def im_detect(net, im, boxes=None):
    """Detect object classes in an image given object proposals.

    Arguments:
        net (caffe.Net): Fast R-CNN network to use
        im (ndarray): color image to test (in BGR order)
        boxes (ndarray): R x 4 array of object proposals or None (for RPN)

    Returns:
        scores (ndarray): R x K array of object class scores (K includes
            background as object category 0)
        boxes (ndarray): R x (4*K) array of predicted bounding boxes
    """

    # Rescale the original image, keeping the aspect ratio
    blobs, im_scales = _get_blobs(im, boxes)

    if cfg.TEST.HAS_RPN:
        im_blob = blobs['data']
        blobs['im_info'] = np.array(
            [[im_blob.shape[2], im_blob.shape[3], im_scales[0]]],
            dtype=np.float32)
        print >> dfy_log_file_writer, 'With RPN:', blobs.keys(), blobs['im_info']

    # reshape network inputs
    print >> dfy_log_file_writer, 'reshape:', blobs['data'].shape
    net.blobs['data'].reshape(*(blobs['data'].shape))

    if cfg.TEST.HAS_RPN:
        print >> dfy_log_file_writer, 'reshape:', blobs['im_info'].shape
        net.blobs['im_info'].reshape(*(blobs['im_info'].shape))
    else:
        net.blobs['rois'].reshape(*(blobs['rois'].shape))

    # do forward
    forward_kwargs = {'data': blobs['data'].astype(np.float32, copy=False)}

    if cfg.TEST.HAS_RPN:
        forward_kwargs['im_info'] = blobs['im_info'].astype(np.float32, copy=False)
    else:
        forward_kwargs['rois'] = blobs['rois'].astype(np.float32, copy=False)

    # Forward pass
    blobs_out = net.forward(**forward_kwargs)

    # https://blog.csdn.net/tina_ttl/article/details/51033660
    # (how to visualize the output of each CNN layer in Caffe)

    for layer_name, blob in net.blobs.iteritems():
        print >> dfy_log_file_writer, 'layer name + shape: ' + layer_name + '\t' + str(blob.data.shape)

    for layer_name, param in net.params.iteritems():
        print >> dfy_log_file_writer, 'layer name + shapes of W and b: ' + layer_name + '\t' + str(param[0].data.shape), str(param[1].data.shape)

    print >> dfy_log_file_writer, 'forward pass result blobs_out:\n', blobs_out.keys(), \
        blobs_out['bbox_pred'].shape, \
        blobs_out['cls_prob'].shape

    if cfg.TEST.HAS_RPN:
        assert len(im_scales) == 1, "Only single-image batch implemented"
        rois = net.blobs['rois'].data.copy()
        print >> dfy_log_file_writer, '\nrois:', rois, rois.shape
        # unscale back to raw image space
        boxes = rois[:, 1:5] / im_scales[0]
        print >> dfy_log_file_writer, '\nboxes:', boxes, boxes.shape

    # use softmax estimated probabilities
    scores = blobs_out['cls_prob']

    if cfg.TEST.BBOX_REG:
        # True by default
        # Apply bounding-box regression deltas
        box_deltas = blobs_out['bbox_pred']

        pred_boxes = bbox_transform_inv(boxes, box_deltas)
        pred_boxes = clip_boxes(pred_boxes, im.shape)
    else:
        # Without regression, repeat the proposal boxes once per class
        # so that pred_boxes is always defined
        pred_boxes = np.tile(boxes, (1, scores.shape[1]))

    return scores, pred_boxes


def demo(net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    print >> dfy_log_file_writer, 'start detection, image path:', im_file
    im = cv2.imread(im_file)
    print >> dfy_log_file_writer, 'original image size (H, W, C):', im.shape

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    # ('scores.shape:', (300, 21))
    print('scores.shape:', scores.shape)
    # ('boxes.shape:', (300, 84))
    print('boxes.shape:', boxes.shape)
    print ('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0])

    # Visualize detections for each class
    # Confidence threshold
    CONF_THRESH = 0.95
    # NMS threshold
    NMS_THRESH = 0.6

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        # Apply the NMS threshold first
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        # Apply the confidence threshold only when visualizing
        vis_detections(im, cls, dets, thresh=CONF_THRESH)


if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    # Network definition (the counterpart of deploy.prototxt) and the caffemodel weights
    prototxt = 'models/pascal_voc/VGG16/faster_rcnn_alt_opt/faster_rcnn_test.pt'

    caffemodel = 'data/faster_rcnn_models/VGG16_faster_rcnn_final.caffemodel'

    caffe.set_device(0)
    caffe.set_mode_gpu()

    # Log file for the debug output
    # https://blog.csdn.net/jiongnima/article/details/80016683
    # (detailed Faster R-CNN source walkthrough: line-by-line analysis of ROI pooling)
    dfy_log_file = "tools/demo_detector_dfy.log"
    dfy_log_file_writer = open(dfy_log_file, 'w')

    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    # im_names = ['001763.jpg', '004545.jpg']
    im_names = ['test01.png']
    for im_name in im_names:
        print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
        print 'Demo for data/demo/{}'.format(im_name)
        demo(net, im_name)

    plt.show()
    dfy_log_file_writer.close()
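The image rescaling rule inside _get_image_blob is easy to check by hand. The sketch below restates it without Caffe; `compute_im_scale` is a hypothetical helper, and the defaults 600 and 1000 are taken from the cfg.TEST.SCALES and cfg.TEST.MAX_SIZE comments in the script. The 375 x 1242 input is an assumed example, roughly the size of a KITTI frame.

```python
def compute_im_scale(height, width, target_size=600, max_size=1000):
    """Scale factor that brings the shorter side to target_size,
    unless that would push the longer side past max_size."""
    size_min = min(height, width)
    size_max = max(height, width)
    im_scale = float(target_size) / float(size_min)
    if round(im_scale * size_max) > max_size:
        # The longer side would be too large, so it decides the scale instead
        im_scale = float(max_size) / float(size_max)
    return im_scale


# A wide frame: 600 / 375 = 1.6 would give a width of about 1987 > 1000,
# so the scale falls back to 1000 / 1242
wide_scale = compute_im_scale(375, 1242)

# A 480 x 640 image: 600 / 480 = 1.25 gives a width of 800, within the limit
normal_scale = compute_im_scale(480, 640)

print(wide_scale, normal_scale)
```

For very wide images the shorter side therefore ends up well below 600 pixels after scaling, which makes small objects even smaller. This is one reason the test scales in lib/fast_rcnn/config.py are worth revisiting for KITTI.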

For a detailed explanation of the Faster RCNN network structure, see http://www.ai111.vip/thread-800-1-1.html
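The nms() function in the demo script only dispatches to compiled CPU or GPU kernels, so the algorithm itself is not visible there. The following is a plain NumPy sketch of greedy non-maximum suppression, written for illustration; `nms_numpy` is a hypothetical name, and the `+ 1` in the width and height follows the inclusive pixel convention that bbox_transform_inv uses in the script.

```python
import numpy as np


def nms_numpy(dets, thresh):
    """Greedy NMS. dets is an N x 5 array of (x1, y1, x2, y2, score).
    Returns the indices of the boxes that are kept, best score first."""
    x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    scores = dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]   # indices sorted by descending score

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Intersection of the best box with every remaining box
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        overlap = inter / (areas[i] + areas[rest] - inter)
        # Keep only the boxes that overlap the best box by at most thresh
        order = rest[overlap <= thresh]
    return keep


dets = np.array([
    [0, 0, 10, 10, 0.9],     # best box
    [1, 1, 11, 11, 0.8],     # overlaps the best box with IoU of about 0.70
    [50, 50, 60, 60, 0.7],   # far away
], dtype=np.float32)
keep = nms_numpy(dets, thresh=0.6)
print(keep)
```

With the demo's NMS threshold of 0.6 the second box is suppressed. A higher threshold keeps more overlapping boxes, which can help in crowded scenes at the cost of more duplicate detections.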


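Training on the KITTI dataset with the Faster RCNN algorithm
1. Modify the network definition files of the train and test networks
The model was previously used on the Pascal VOC dataset (21 classes, including background). It now has to handle the KITTI dataset (8 classes, including background).
models/pascal_voc/VGG_CNN_M_1024/faster_rcnn_end2end/train.prototxt
Change 1: layer name: 'input-data', type: 'Python', set num_classes=8
Change 2: layer name: 'roi-data', type: 'Python', set num_classes=8
Change 3: layer name: "cls_score", type: "InnerProduct", set num_output: 8
Change 4: layer name: "bbox_pred", type: "InnerProduct", set num_output: 32 (= 8 * 4)
models/pascal_voc/VGG_CNN_M_1024/faster_rcnn_end2end/test.prototxt
Change 1: layer name: "cls_score", type: "InnerProduct", set num_output: 8
Change 2: layer name: "bbox_pred", type: "InnerProduct", set num_output: 32 (= 8 * 4)


2. Modify the class information in lib/datasets/pascal_voc.py
        # Changed from 21 classes to 8 classes; note that all names are lowercase
        self._classes = ('__background__',  # always index 0
                         'person_sitting', 'truck', 'van', 'pedestrian',
                         'cyclist', 'tram', 'car')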
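The numbers in steps 1 and 2 have to agree with each other: the class tuple (background included) determines num_classes, the cls_score output size, and the bbox_pred output size. A quick sanity check might look like this; `head_sizes` is a hypothetical helper written for this note and is not part of py-faster-rcnn.

```python
KITTI_CLASSES = ('__background__',  # always index 0
                 'person_sitting', 'truck', 'van', 'pedestrian',
                 'cyclist', 'tram', 'car')


def head_sizes(classes):
    """Derive the prototxt values from the class tuple (background included)."""
    num_classes = len(classes)
    return {
        'num_classes': num_classes,                # input-data and roi-data layers
        'cls_score_num_output': num_classes,       # one score per class
        'bbox_pred_num_output': 4 * num_classes,   # 4 box deltas per class
    }


sizes = head_sizes(KITTI_CLASSES)
print(sizes)
```

If a class is later added or removed, all four values in train.prototxt and both values in test.prototxt need to change together.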


3. Modify the data paths in lib/datasets/pascal_voc.py
self._data_path = '/home/dfy888/DataSets/Kitti_voc'
self._image_ext = '.png'
Also comment out the following code:
# assert os.path.exists(self._devkit_path), \
#     'VOCdevkit path does not exist: {}'.format(self._devkit_path)
Search for _devkit_path and replace each use with the corresponding _data_path path.
Modify the function _load_pascal_annotation in the same file (1 must not be subtracted from x1 and y1):
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text)
            y1 = float(bbox.find('ymin').text)
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
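A likely reason for not subtracting 1 from x1 and y1 is integer wraparound. This is an assumption to verify in your own checkout: it presumes that the loader stores the boxes in an unsigned 16-bit array, and that KITTI boxes can start at coordinate 0 (the conversion script above truncates the float coordinates with int()). Under those assumptions, 0 - 1 wraps around to 65535 and the box becomes invalid. The sketch uses a hypothetical helper, `to_uint16_box`, and integer inputs so that the cast is well defined.

```python
import numpy as np


def to_uint16_box(xmin, ymin, xmax, ymax, subtract_one_from_min):
    """Mimic storing a box in a uint16 array, with or without the -1 on xmin/ymin."""
    offset = 1 if subtract_one_from_min else 0
    coords = np.array([xmin - offset, ymin - offset, xmax - 1, ymax - 1],
                      dtype=np.int64)
    # Integer-to-integer casts wrap around modulo 2**16
    return coords.astype(np.uint16)


# A box that touches the left image border (xmin == 0)
wrapped = to_uint16_box(0, 150, 120, 300, subtract_one_from_min=True)
safe = to_uint16_box(0, 150, 120, 300, subtract_one_from_min=False)

print(wrapped)  # x1 has wrapped around to a huge value, so x1 > x2
print(safe)     # x1 stays at 0
```

A box whose x1 is larger than its x2 is exactly the kind of input that trips sanity checks on box coordinates later in training.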
4. Drop the pretrained model
In tools/train_net.py, change the call to:
train_net(args.solver, roidb, output_dir,
              pretrained_model=None,
              max_iters=args.max_iters)
5. Bugs encountered and their fixes
http://www.ai111.vip/thread-790-1-1.html
http://www.ai111.vip/thread-789-1-1.html
http://www.ai111.vip/thread-788-1-1.html
http://www.ai111.vip/thread-791-1-1.html
http://www.ai111.vip/thread-792-1-1.html
6. Start training: python tools/train_net.py --gpu 0
7. Delete data/cache before retraining
8. Start testing: python tools/test_net.py --gpu 0
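Step 7 is easy to forget. The cache presumably holds the annotation database built from the previous dataset configuration, so stale entries would be reused after the classes or paths change. A small helper can clear it before each retraining run. This is a sketch: `clear_cache` is a hypothetical function, and the demonstration runs on a throwaway temporary directory with a made-up cache file name rather than on a real checkout.

```python
import os
import shutil
import tempfile


def clear_cache(cache_dir):
    """Delete cache_dir and everything in it. Returns True if it existed."""
    if os.path.isdir(cache_dir):
        shutil.rmtree(cache_dir)
        return True
    return False


# Demonstration on a temporary directory that stands in for the project root
demo_root = tempfile.mkdtemp()
demo_cache = os.path.join(demo_root, 'data', 'cache')
os.makedirs(demo_cache)
# Hypothetical cached file
open(os.path.join(demo_cache, 'example_gt_roidb.pkl'), 'w').close()

removed_first = clear_cache(demo_cache)    # the directory existed
removed_second = clear_cache(demo_cache)   # nothing left to delete
cache_exists = os.path.exists(demo_cache)
shutil.rmtree(demo_root)                   # clean up the demonstration directory
print(removed_first, removed_second, cache_exists)
```

In a real checkout the call would be clear_cache('data/cache'), run from the project root before step 6.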

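Ideas for optimizing the model (in real engineering projects the overall framework is usually left unchanged)
1. Train for more iterations, for example 500,000, to make sure the network converges
2. Change the backbone CNN in train.prototxt (key point)
3. Modify the Python layers defined in train.prototxt
They generally live under the lib directory, for example roi_data_layer, nms, and so on
4. Improve the input data (key point)
5. Fine-tune the network hyperparameters in lib/fast_rcnn/config.py, which holds many thresholds and similar settings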











Attached figures:
KITTI files to download
Faster RCNN model test (dfy)
Faster RCNN train network structure 01
Faster RCNN train network structure 02
Faster RCNN test network structure 01
Faster RCNN test network structure 02
Faster RCNN optimization directions
Light-Head RCNN