东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 6219|回复: 3
打印 上一主题 下一主题

[课堂笔记] 【项目】02、基于Keras训练代码框架实现

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14429
QQ
跳转到指定楼层
楼主
发表于 2019-10-21 09:15:44 | 只看该作者 |只看大图 回帖奖励 |倒序浏览 |阅读模式


1、模型构建 def get_model
2、数据生成器 def batch_generator
3、开始训练网络 model.fit_generator
4、可视化 callbacks.TensorBoard   plt.savefig('train_test_loss.jpg')
5、图像预处理 数据增强等都没有实现


根据车载摄像头的画面,自动判断如何打方向盘?使用端到端(end-to-end)的深度神经网络CNN


Data Generator:无需预先生成所有图像增强后的图像,会占用太多的硬盘空间,会增加读取硬盘文件所需的时间


数据driving_log.csv文件中右 center left right steering角度   throttle油门 brake刹车  speed速度
我们这里只用 center 和 steering角度
csv文件的读取:
data_path = 'F:\\AI_Study_dfy\\项目:自动驾驶之方向盘转动角度预测data\\'
    with open(data_path + 'driving_log.csv', 'r') as csvfile:
        file_reader = csv.reader(csvfile, delimiter=',')
        log = []
        for row in file_reader:
            log.append(row)
    log = np.array(log)
    # 二维矩阵里面都是字符串
    print(log.shape)
    print(log.ndim)
    print(log.dtype)
    print(log[:5, :])


遇到问题:OpenCV无法读取中文路径
解决:http://www.ai111.vip/thread-861-1-1.html


  1. # -*- coding: utf-8 -*-
  2. __author__ = u'东方耀 微信:dfy_88888'
  3. __date__ = '2019/10/18 11:49'
  4. __product__ = 'PyCharm'
  5. __filename__ = 'train_dfy'

  6. import tensorflow as tf
  7. import numpy as np
  8. import torch as T
  9. from keras.layers import Conv2D, MaxPooling2D, Flatten, PReLU
  10. from keras.layers.core import Dense, Dropout, Activation
  11. from keras.optimizers import SGD, Adam
  12. from keras.models import Model, Sequential
  13. from keras import backend as K
  14. from keras.regularizers import l2
  15. import os.path
  16. import cv2
  17. import skimage.io as iio
  18. import csv
  19. import glob
  20. import pickle
  21. from sklearn.utils import shuffle
  22. from sklearn.model_selection import cross_validate
  23. from sklearn.model_selection import train_test_split
  24. import json
  25. from keras import callbacks
  26. import math
  27. import matplotlib.pyplot as plt

  28. SEED = 666

  29. print(tf.__version__)
  30. print(T.__version__)


  31. def get_model(shape):
  32.     """
  33.     预测方向盘角度,以图像为输入,预测方向盘的转动角度
  34.     :param shape: 图像尺寸 (128, 128, 3)  NHC
  35.     :return:
  36.     """
  37.     model = Sequential(name='dfy_seq_model')
  38.     # 第一层需要指定input_shape 后面不需要
  39.     model.add(Conv2D(filters=24, kernel_size=(5, 5), strides=(2, 2),
  40.                      padding='valid', data_format='channels_last', activation='relu', input_shape=shape))
  41.     model.add(Conv2D(filters=36, kernel_size=(5, 5), strides=(2, 2),
  42.                      padding='valid', data_format='channels_last', activation='relu'))
  43.     model.add(Conv2D(filters=48, kernel_size=(5, 5), strides=(2, 2),
  44.                      padding='valid', data_format='channels_last', activation='relu'))

  45.     model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1),
  46.                      padding='valid', data_format='channels_last', activation='relu'))
  47.     model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1),
  48.                      padding='valid', data_format='channels_last', activation='relu'))

  49.     model.add(Flatten(data_format='channels_last'))
  50.     model.add(Dense(units=1164, activation='relu'))
  51.     model.add(Dense(units=100, activation='relu'))
  52.     model.add(Dense(units=50, activation='relu'))
  53.     model.add(Dense(units=10, activation='relu'))
  54.     # 由于输出的角度是 (-pi/2, pi/2) 要选择好的激活函数
  55.     model.add(Dense(units=1, activation='linear'))
  56.     # compile: 1、指定优化器  2、损失函数
  57.     model.compile(optimizer=Adam(learning_rate=0.001), loss='mean_squared_error')

  58.     return model


  59. # 开始数据增强(基于现有样本数据产生新的更多的训练数据)
  60. def random_brightness(img, degree):
  61.     """
  62.     随机调整输入图像的亮度,调整强度于0.1(变黑)和1(无变化)之间
  63.     :param img: 输入图像
  64.     :param degree: 输入图像对应的转动角度
  65.     :return:
  66.     """

  67.     return (img, degree)


  68. def horizontal_flip(img, degree):
  69.     """
  70.     按照50%的概率水平翻转图像
  71.     :param img: 输入图像
  72.     :param degree: 输入图像对应的转动角度
  73.     :return:
  74.     """
  75.     pass
  76.     return (img, degree)


  77. def left_right_random_swap(img_address, degree, degree_corr=1.0 / 4):
  78.     """
  79.     随机从左、中、右图像中选择一张图像,并相应调整转动的角度
  80.     :param img_address: 中间图像的文件路径
  81.     :param degree: 中间图像对应的方向盘转动角度
  82.     :param degree_corr: 方向盘转动角度调整的值
  83.     :return:
  84.     """
  85.     return (img_address, degree)


  86. def discard_zero_steering(degrees, rate):
  87.     """
  88.     从角度为0的index中随机选择部分index返回
  89.     :param degrees: 输入的角度值
  90.     :param rate: 丢弃率 rate=0.8 表示80%的index返回
  91.     :return:
  92.     """
  93.     return degrees


  94. def image_transformation(img_address, degree, data_dir):
  95.     # img_address, degree = left_right_random_swap(img_address, degree)
  96.     # opencv 读出来的图像是 bgr的
  97.     # print('开始读的图片地址(路径有中文):', data_dir+img_address)
  98.     img = iio.imread(data_dir + img_address)
  99.     # cvt convert 颜色空间的转换
  100.     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  101.     # img, degree = random_brightness(img, degree)
  102.     # img, degree = horizontal_flip(img, degree)
  103.     return (img, degree)


  104. def batch_generator(x, y, batch_size, shape, training=True, data_dir='data/', discard_rate=0.95):
  105.     """
  106.     产生批处理的数据Generator 高效读取数据
  107.     Data Generator 无需预先生成所有图像增强后的图像,会占用太多的硬盘空间与增加读取硬盘文件所需的时间
  108.     :param x: 图像文件路径list
  109.     :param y: 方向盘的角度
  110.     :param batch_size:
  111.     :param shape: 输入图像的尺寸(HWC)
  112.     :param training: True时产生训练数据 False时产生validation数据
  113.     :param data_dir: 数据目录,包含一个IMG文件夹
  114.     :param discard_rate: 随机丢弃角度=0的训练数据的比率
  115.     :return:
  116.     """
  117.     if training:
  118.         x, y = shuffle(x, y)
  119.         rand_zero_idx = discard_zero_steering(y, rate=discard_rate)
  120.         new_x = np.delete(x, rand_zero_idx, axis=0)
  121.         new_y = np.delete(y, rand_zero_idx, axis=0)
  122.     else:
  123.         new_x = x
  124.         new_y = y
  125.     offset = 0
  126.     while True:
  127.         X = np.empty(shape=(batch_size, *shape))
  128.         Y = np.empty(shape=(batch_size, 1))
  129.         for example in range(batch_size):
  130.             img_address, img_steering = new_x[example + offset], new_y[example + offset]
  131.             if training:
  132.                 img, img_steering = image_transformation(img_address, img_steering, data_dir)
  133.             else:
  134.                 img = iio.imread(data_dir + img_address)
  135.                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  136.             # 先截取图像  后resize 再 归一化到[-0.5, 0.5]
  137.             # [NHWC] 四维
  138.             X[example, :, :, :] = cv2.resize(img[80:140, 0:320], (shape[0], shape[1])) / 255 - 0.5
  139.             Y[example] = img_steering

  140.             if (example + 1) + offset > len(new_y) - 1:
  141.                 # 达到了原来数据的尾部,从头开始
  142.                 x, y = shuffle(x, y)
  143.                 rand_zero_idx = discard_zero_steering(y, rate=discard_rate)
  144.                 new_x = x
  145.                 new_y = y
  146.                 new_x = np.delete(new_x, rand_zero_idx, axis=0)
  147.                 new_y = np.delete(new_y, rand_zero_idx, axis=0)
  148.                 offset = 0
  149.         yield (X, Y) # 类似return 但有区别
  150.         offset = offset + batch_size


  151. if __name__ == '__main__':
  152.     data_path = 'F:\\AI_Study_dfy\\项目:自动驾驶之方向盘转动角度预测data\\'
  153.     with open(data_path + 'driving_log.csv', 'r') as csvfile:
  154.         file_reader = csv.reader(csvfile, delimiter=',')
  155.         log = []
  156.         for row in file_reader:
  157.             log.append(row)
  158.     log = np.array(log)
  159.     # 二维矩阵里面都是字符串
  160.     print(log.shape)
  161.     print(log.ndim)
  162.     print(log.dtype)
  163.     print(log[:5, :])
  164.     # 去掉第一行 表头数据
  165.     log = log[1:, :]
  166.     ls_imgs = glob.glob(data_path + 'IMG/*.jpg')
  167.     print('一共有%d张图片(包括中间、左边、右边)' % len(ls_imgs))
  168.     assert len(ls_imgs) == len(log) * 3, 'number of images does not match!'
  169.     # 使用20%的数据作为测试数据集
  170.     test_ratio = 0.2
  171.     shape = (128, 128, 3)
  172.     batch_size = 20
  173.     # 所有的数据跑多少轮?epoch
  174.     nb_epoch = 30

  175.     # 中间摄像头的图片路径 str
  176.     x_ = log[:, 0]
  177.     # steering  str---> float
  178.     y_ = log[:, 3].astype(float)
  179.     x_, y_ = shuffle(x_, y_)
  180.     X_train, X_test, y_train, y_test = train_test_split(x_, y_, random_state=SEED, test_size=test_ratio)
  181.     print('训练集大小:{}, 测试集大小:{}'.format(len(X_train), len(X_test)))

  182.     steps_per_epoch = 20
  183.     # 使得test数据集大小为batch_size的整数倍
  184.     nb_test_samples = len(y_test) - len(y_test) % batch_size
  185.     print('nb_test_samples size:', nb_test_samples)

  186.     model = get_model(shape)
  187.     # 生成模型结构汇总 挺好!
  188.     print(model.summary())
  189.     # 根据test loss保存最优模型
  190.     save_best = callbacks.ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', verbose=1,
  191.                                           save_best_only=True, mode='min')
  192.     # 如果训练连续patience=15(向后看多少步)val_loss did not improve(网络不收敛),提前结束训练
  193.     # if (last_loss - current_loss) > min_delta 才算网络是在优化 否则表示loss没有在下降
  194.     early_stop = callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-5, patience=15, verbose=0, mode='auto')

  195.     tb_callback = callbacks.TensorBoard(log_dir='./Graph', write_graph=True)

  196.     callback_lists = [early_stop, save_best, tb_callback]

  197.     # 开始训练网络
  198.     history = model.fit_generator(
  199.         generator=batch_generator(X_train, y_train, batch_size, shape, training=True, data_dir=data_path),
  200.         steps_per_epoch=steps_per_epoch, validation_steps=nb_test_samples // batch_size,
  201.         validation_data=batch_generator(X_test, y_test, batch_size, shape, training=False, data_dir=data_path),
  202.         epochs=nb_epoch, verbose=1, callbacks=callback_lists)

  203.     with open('./trainHistoryDict.pickle', 'wb') as file_pickle:
  204.         pickle.dump(history.history, file_pickle)

  205.     plt.plot(history.history['loss'])
  206.     plt.plot(history.history['val_loss'])
  207.     plt.title('model train vs test loss')
  208.     plt.xlabel('epoch')
  209.     plt.ylabel('loss')
  210.     plt.legend(['train', 'test'], loc='upper right')
  211.     plt.savefig('train_test_loss.jpg')
  212.     # 模型保存
  213.     with open('model.json', 'w') as f:
  214.         f.write(model.to_json())
  215.     model.save('model.h5')
  216.     print('Done!')
复制代码



东方老师AI官网:http://www.ai111.vip
有任何问题可联系东方老师微信:dfy_88888
【微信二维码图片】

00.png (74.07 KB, 下载次数: 511)

00.png

driving_log.csv

1.12 MB, 阅读权限: 188, 下载次数: 4

让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

117

帖子

258

积分

中级会员

Rank: 3Rank: 3

积分
258
QQ
沙发
发表于 2020-2-3 15:41:12 | 只看该作者
谢谢老师提供的资料。
回复

使用道具 举报

0

主题

96

帖子

204

积分

中级会员

Rank: 3Rank: 3

积分
204
板凳
发表于 2022-3-18 00:35:46 | 只看该作者
dsffdsfsdfdsfds
回复

使用道具 举报

0

主题

96

帖子

204

积分

中级会员

Rank: 3Rank: 3

积分
204
地板
发表于 2022-3-18 00:36:06 | 只看该作者
dsfsfsdfdsfsdfsdfds
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-20 07:23 , Processed in 0.175103 second(s), 21 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表