东方耀AI技术分享

[Class Notes] 03. Tweaking the model training code, and the initial training results

Posted on 2019-11-07 09:21:57

Training history (12 epochs recorded; the loss is MSE on the steering angle, the metric is MAE):


dict_keys(['val_loss', 'val_mae', 'loss', 'mae'])
    val_loss   val_mae      loss       mae
0   0.009134  0.070623  0.013413  0.073236
1   0.011031  0.067170  0.010952  0.067101
2   0.005622  0.067855  0.010583  0.067400
3   0.005378  0.064024  0.010231  0.066075
4   0.004834  0.062514  0.009657  0.063892
5   0.005470  0.064100  0.009387  0.062929
6   0.003604  0.061187  0.009363  0.063271
7   0.004759  0.067398  0.008875  0.061294
8   0.008716  0.061494  0.008494  0.060054
9   0.009699  0.064915  0.008158  0.058782
10  0.010862  0.062815  0.007335  0.055527
11  0.004432  0.060743  0.007094  0.053928
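
The table above is simply pd.DataFrame(history.history) as printed by plot_learning_curve() in the script below. Because train() also pickles history.history to ./trainHistoryDict.pickle, the curves can be re-examined later without retraining. A minimal sketch, assuming the pickle file from the run above is present:

import pickle
import pandas as pd
import matplotlib.pyplot as plt

# Load the history dict that train() pickles at the end of a run
with open('./trainHistoryDict.pickle', 'rb') as f:
    history_dict = pickle.load(f)

df_history = pd.DataFrame(history_dict)    # columns: val_loss, val_mae, loss, mae
print(df_history)                          # reproduces the table above

df_history.plot(figsize=(8, 5))
plt.grid(True)
plt.ylim(0, 0.1)   # the values here are small, so a tighter y-range than the script's (0, 1.5) reads better
plt.show()

The full training script follows.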



# -*- coding: utf-8 -*-
__author__ = u'东方耀 微信:dfy_88888'
__date__ = '2019/11/7 08:49'
__product__ = 'PyCharm'
__filename__ = 'train_dfy'

import tensorflow as tf
import numpy as np
import torch as T
from keras.layers import Conv2D, MaxPooling2D, Flatten, PReLU
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD, Adam
from keras.models import Model, Sequential
from keras import backend as K
from keras.regularizers import l2
import os.path
import cv2
import skimage.io as iio
import csv
import glob
import pickle
from sklearn.utils import shuffle
from sklearn.model_selection import cross_validate
from sklearn.model_selection import train_test_split, GridSearchCV, RandomizedSearchCV
import json
from keras import callbacks
import math
import matplotlib.pyplot as plt
from tensorflow import gfile
import pandas as pd

SEED = 666

print(tf.__version__)
print(T.__version__)


def get_model(shape):
    """
    Build the steering-angle model: takes an image as input and predicts the
    steering wheel angle.
    :param shape: image shape, e.g. (128, 128, 3) in HWC order
    :return: a compiled Keras model
    """
    model = Sequential(name='dfy_seq_model')
    # Only the first layer needs input_shape; later layers infer it
    model.add(Conv2D(filters=24, kernel_size=(5, 5), strides=(2, 2),
                     padding='valid', data_format='channels_last', activation='relu', input_shape=shape))
    model.add(Conv2D(filters=36, kernel_size=(5, 5), strides=(2, 2),
                     padding='valid', data_format='channels_last', activation='relu'))
    model.add(Conv2D(filters=48, kernel_size=(5, 5), strides=(2, 2),
                     padding='valid', data_format='channels_last', activation='relu'))

    model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1),
                     padding='valid', data_format='channels_last', activation='relu'))
    model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1),
                     padding='valid', data_format='channels_last', activation='relu'))

    model.add(Flatten(data_format='channels_last'))
    model.add(Dense(units=1164, activation='relu'))
    model.add(Dense(units=100, activation='relu'))
    model.add(Dense(units=50, activation='relu'))
    model.add(Dense(units=10, activation='relu'))
    # The output angle lies in (-pi/2, pi/2), so the output activation needs to be
    # chosen carefully; a plain linear output is used here
    model.add(Dense(units=1, activation='linear'))
    # compile: 1. specify the optimizer  2. specify the loss function
    model.compile(optimizer=Adam(learning_rate=0.001), loss='mse', metrics=['mae'])

    return model
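# A possible alternative for the output layer (sketch only, not used in this run):
# since the target angle is bounded in (-pi/2, pi/2), the final linear Dense layer
# in get_model() could be replaced by a scaled tanh so predictions stay in range:
#
#     from keras.layers import Lambda
#     model.add(Dense(units=1))
#     model.add(Lambda(lambda t: (np.pi / 2) * K.tanh(t)))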


# Data augmentation: generate additional training data from the existing samples
def random_brightness(img, degree):
    """
    Randomly adjust the brightness of the input image, with intensity between
    0.1 (much darker) and 1 (unchanged).
    :param img: input image
    :param degree: steering angle associated with the image
    :return:
    """
    # Not implemented yet; returns the input unchanged
    return (img, degree)


def horizontal_flip(img, degree):
    """
    Flip the image horizontally with a probability of 50%.
    :param img: input image
    :param degree: steering angle associated with the image
    :return:
    """
    # Not implemented yet; returns the input unchanged
    return (img, degree)


def left_right_random_swap(img_address, degree, degree_corr=1.0 / 4):
    """
    Randomly pick one of the left / center / right camera images and adjust the
    steering angle accordingly.
    :param img_address: file path of the center image
    :param degree: steering angle of the center image
    :param degree_corr: correction applied to the steering angle
    :return:
    """
    # Not implemented yet; returns the input unchanged
    return (img_address, degree)


def discard_zero_steering(degrees, rate):
    """
    Randomly select a subset of the indices whose steering angle is 0 and return
    them, so that the caller can drop those samples.
    :param degrees: array of steering angles
    :param rate: discard rate; rate=0.8 means 80% of the zero-angle indices are returned
    :return: indices of the zero-angle samples to be deleted
    """
    # The original stub returned `degrees` itself, which np.delete() would
    # misinterpret as indices; return actual zero-angle indices instead.
    zero_idx = np.where(np.asarray(degrees) == 0)[0]
    n_discard = int(len(zero_idx) * rate)
    return np.random.choice(zero_idx, size=n_discard, replace=False)


def image_transformation(img_address, degree, data_dir):
    # img_address, degree = left_right_random_swap(img_address, degree)
    # skimage.io.imread is used here (the image path contains Chinese characters);
    # note it returns RGB, whereas cv2.imread would return BGR
    # print('Reading image from (path contains Chinese):', data_dir + img_address)
    img = iio.imread(data_dir + img_address)
    # cvtColor: colour-space conversion (applied identically to training and validation data)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # img, degree = random_brightness(img, degree)
    # img, degree = horizontal_flip(img, degree)
    return (img, degree)


def batch_generator(x, y, batch_size, shape, training=True, data_dir='data/', discard_rate=0.95):
    """
    Generator that yields batches of data, so reading stays efficient and the
    augmented images never have to be written to disk beforehand (which would
    cost disk space and extra file-read time).
    :param x: list of image file paths
    :param y: steering angles
    :param batch_size: number of samples per batch
    :param shape: input image shape (HWC)
    :param training: True to generate training data, False to generate validation data
    :param data_dir: data directory containing an IMG folder
    :param discard_rate: fraction of zero-angle training samples to discard
    :return:
    """
    if training:
        x, y = shuffle(x, y)
        rand_zero_idx = discard_zero_steering(y, rate=discard_rate)
        new_x = np.delete(x, rand_zero_idx, axis=0)
        new_y = np.delete(y, rand_zero_idx, axis=0)
    else:
        new_x = x
        new_y = y
    offset = 0
    while True:
        X = np.empty(shape=(batch_size, *shape))
        Y = np.empty(shape=(batch_size, 1))
        for example in range(batch_size):
            img_address, img_steering = new_x[example + offset], new_y[example + offset]
            if training:
                img, img_steering = image_transformation(img_address, img_steering, data_dir)
            else:
                img = iio.imread(data_dir + img_address)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Crop the road region first, then resize, then normalise to [-0.5, 0.5]
            # (cv2.resize takes (width, height); identical here because the target is square)
            # X is 4-D, [N, H, W, C]
            X[example, :, :, :] = cv2.resize(img[80:140, 0:320], (shape[0], shape[1])) / 255 - 0.5
            Y[example] = img_steering

            if (example + 1) + offset > len(new_y) - 1:
                # Reached the end of the data; reshuffle and start again from the beginning
                x, y = shuffle(x, y)
                rand_zero_idx = discard_zero_steering(y, rate=discard_rate)
                new_x = x
                new_y = y
                new_x = np.delete(new_x, rand_zero_idx, axis=0)
                new_y = np.delete(new_y, rand_zero_idx, axis=0)
                offset = 0
        yield (X, Y)
        offset = offset + batch_size
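# Quick sanity check for the generator (a sketch; x_ and y_ come from
# load_train_datasets(), and data_dir should point at the directory that contains IMG/):
#
#     gen = batch_generator(x_, y_, batch_size=32, shape=(128, 128, 3),
#                           training=True, data_dir='data/')
#     X_batch, Y_batch = next(gen)
#     print(X_batch.shape, Y_batch.shape)   # expected: (32, 128, 128, 3) (32, 1)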


def load_train_datasets(data_path):
    with open(data_path + 'driving_log.csv', 'r') as csvfile:
        file_reader = csv.reader(csvfile, delimiter=',')
        log = []
        for row in file_reader:
            log.append(row)
    log = np.array(log)
    # 2-D array of strings, shape (8037, 7)
    print(log.shape)
    print(log.dtype)
    print(log[:3, :])
    # Drop the first row (the CSV header)
    log = log[1:, :]
    ls_imgs = glob.glob(data_path + 'IMG/*.jpg')
    # 24108 images in total = 8036 * 3
    print('Total of %d images (center, left and right cameras)' % len(ls_imgs))
    assert len(ls_imgs) == len(log) * 3, 'number of images does not match!'
    # File paths of the center-camera images (str)
    x_ = log[:, 0]
    # steering: str -> float
    y_ = log[:, 3].astype(float)
    return x_, y_
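# The 7 columns of driving_log.csv (matching the (8037, 7) shape printed above)
# presumably follow the Udacity simulator layout:
#   center_img, left_img, right_img, steering, throttle, brake, speed
# which is why column 0 is the center-image path and column 3 is the steering angle.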


def plot_learning_curve(history):
    # pd.DataFrame requires every column to have the same length, i.e. every key in
    # history.history must map to a value list of equal length
    # (otherwise: "ValueError: arrays must all be same length")
    df_history = pd.DataFrame(data=history.history)
    print(df_history)
    # print(df_history.index)
    print(df_history.columns)
    # print(df_history.dtypes)
    df_history.plot(figsize=(8, 5))
    plt.grid(True)
    # The x axis is the DataFrame index (the epoch number)
    plt.ylim(0, 1.5)
    plt.show()


def train(model, input_shape, X_train, y_train, X_validation, y_validation, data_path):
    batch_size = 32
    # Make the validation set size an exact multiple of batch_size
    # (training set: 6428 samples, validation set: 1608 samples)
    num_validation_samples = len(y_validation) - len(y_validation) % batch_size
    # num_validation_samples: 1600
    print('num_validation_samples size:', num_validation_samples)
    cb_save_best_model = callbacks.ModelCheckpoint(filepath='best_model.h5', monitor='val_loss',
                                                   save_best_only=True, mode='min')
    # Stop early if val_loss has not improved for patience=5 consecutive epochs;
    # an epoch only counts as an improvement if (last_loss - current_loss) > min_delta
    cb_early_stop = callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=5)
    cb_tensorboard = callbacks.TensorBoard(log_dir='./TB_Graph')
    callbacks_list = [cb_save_best_model, cb_early_stop, cb_tensorboard]
    # Start training (training set: 6428, validation set: 1608)
    # steps_per_epoch = 6428 // batch_size = 200
    history = model.fit_generator(
        generator=batch_generator(X_train, y_train, batch_size, input_shape, training=True, data_dir=data_path),
        steps_per_epoch=200, validation_steps=num_validation_samples // batch_size,
        validation_data=batch_generator(X_validation, y_validation, batch_size, input_shape, training=False, data_dir=data_path),
        epochs=20, verbose=1, callbacks=callbacks_list)

    # list all data in history
    print(history.history.keys())

    fig, ax_array = plt.subplots(1, 2)
    ax1, ax2 = ax_array

    ax1.set_title('model metrics')
    ax1.plot(history.history['mae'])
    ax1.plot(history.history['val_mae'])
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('mae')
    ax1.legend(['train', 'validation'], loc='upper left')

    ax2.set_title('model loss')
    ax2.plot(history.history['loss'])
    ax2.plot(history.history['val_loss'])
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('loss')
    ax2.legend(['train', 'validation'], loc='upper left')

    plt.show()

    plot_learning_curve(history)

    with open('./trainHistoryDict.pickle', 'wb') as file_pickle:
        pickle.dump(history.history, file_pickle)

    print('Done!')
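# Note: in newer Keras / TensorFlow 2.x, Model.fit_generator() is deprecated and
# Model.fit() accepts Python generators directly, so the call above could become
# (all other arguments unchanged):
#
#     history = model.fit(
#         batch_generator(X_train, y_train, batch_size, input_shape,
#                         training=True, data_dir=data_path),
#         steps_per_epoch=200, validation_steps=num_validation_samples // batch_size,
#         validation_data=batch_generator(X_validation, y_validation, batch_size, input_shape,
#                                         training=False, data_dir=data_path),
#         epochs=20, verbose=1, callbacks=callbacks_list)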


def train_model():
    data_path = 'F:\\AI_Study_dfy\\项目:自动驾驶之方向盘转动角度预测\\data\\'
    x_, y_ = load_train_datasets(data_path)
    # Hold out 20% of the data as the validation set
    test_ratio = 0.2
    x_, y_ = shuffle(x_, y_)
    X_train, X_validation, y_train, y_validation = train_test_split(x_, y_, random_state=SEED, test_size=test_ratio)
    # Training set: 6428, validation set: 1608 (8036 samples in total)
    print('training set size: {}, validation set size: {}'.format(len(X_train), len(X_validation)))

    # X_train_norm = preprocess_features_dfy(X_train)
    # y_train = utils.to_categorical(y_train, n_classes)
    input_shape = (128, 128, 3)
    model = get_model(input_shape)
    # Print a summary of the model architecture
    model.summary()
    # Save the model architecture to a JSON file
    with gfile.GFile('model_structure_dfy.json', 'w') as f:
        f.write(model.to_json())
    train(model, input_shape, X_train, y_train, X_validation, y_validation, data_path)


if __name__ == '__main__':

    train_model()

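After training, ModelCheckpoint leaves the best weights in best_model.h5, and train_model() writes the architecture to model_structure_dfy.json. A minimal single-frame inference sketch, assuming the same preprocessing as batch_generator (the image path below is a placeholder, not a file from this run):

import cv2
import numpy as np
import skimage.io as iio
from keras.models import load_model

model = load_model('best_model.h5')            # best checkpoint saved during training

# placeholder path: substitute any center-camera frame from the IMG folder
img = iio.imread('data/IMG/center_example.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)     # same channel swap as batch_generator

# same preprocessing as training: crop the road region, resize to 128x128, scale to [-0.5, 0.5]
x = cv2.resize(img[80:140, 0:320], (128, 128)) / 255 - 0.5
x = np.expand_dims(x, axis=0)                  # add the batch dimension -> (1, 128, 128, 3)

steering = float(model.predict(x)[0, 0])
print('predicted steering angle:', steering)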


Attachment: 01.png
Let everyone in the world learn artificial intelligence! The future of AI is bright!

Reply #1 (posted 2020-02-03 15:41:38):
Thank you, teacher, for sharing the materials.