人工智能视频教程 ai vip技术 人工智能数学基础 爬虫 python机器学习 tensorflow深度学习 20+个企业AI实战项目

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 240|回复: 1

[课堂笔记] 02、训练数据与基于Keras的模型原始架构源码

[复制链接]

1040

主题

1327

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
11448
QQ
发表于 2019-11-4 20:18:15 | 显示全部楼层 |阅读模式
02、训练数据与基于Keras的模型原始架构源码


训练数据下载地址:链接:https://pan.baidu.com/s/1eiWnp3BVZnT8q7T24vpBkA
提取码:nqm0


对数据集的分析,分布等


原始框架源码(可以跑起来,但准确率非常低,只有 5% 左右,而且模型不收敛,需要继续优化、调整):
# -*- coding: utf-8 -*-
# Module metadata filled in by the author's IDE file template; these names are
# informational only and are not read anywhere else in this script.
__author__ = u'东方耀 微信:dfy_88888'
__date__ = '2019/11/4 20:06'
__product__ = 'PyCharm'
__filename__ = 'train_original'
  6. from keras.layers import Flatten
  7. from keras.layers import Conv2D
  8. from keras.layers import MaxPooling2D
  9. from keras.models import Model
  10. from keras.layers import Input
  11. from keras.layers import Dense
  12. from keras.layers import Dropout
  13. import cv2
  14. from sklearn.model_selection import train_test_split
  15. from keras.preprocessing.image import ImageDataGenerator
  16. from keras.callbacks import ModelCheckpoint, EarlyStopping
  17. from keras import utils
  18. import numpy as np
  19. import pandas as pd
  20. import matplotlib.pyplot as plt
  21. import pickle
  22. from tensorflow import gfile


  23. """
  24. 要提高识别率,至少有两个大方向:
  25. 1、数据本身 ()
  26. 2、神经网络结构
  27. """


  28. np.random.seed(666)


  29. def preprocess_features(X):
  30.     # convert from RGB to YUV
  31.     X = np.array([np.expand_dims(cv2.cvtColor(rgb_img, cv2.COLOR_RGB2YUV)[:, :, 0], 2) for rgb_img in X])
  32.     return X


  33. def show_samples_from_generator(image_datagen, X_train, y_train):
  34.     # take a random image from the training set
  35.     img_rgb = X_train[10]

  36.     # plot the original image
  37.     plt.figure(figsize=(4, 4))
  38.     plt.imshow(img_rgb)
  39.     plt.title('Example of RGB image (class = {})'.format(y_train[10]))
  40.     plt.show()

  41.     # plot some randomly augmented images  data augmentation 数据增强
  42.     rows, cols = 4, 12
  43.     fig, ax_array = plt.subplots(rows, cols)
  44.     for ax in ax_array.ravel():
  45.         # y_train[0]  结果是:0
  46.         # y_train[0:1]  结果是:array([0], dtype=uint8)
  47.         # np.expand_dims(img_rgb, axis=0) 后的shape (1, 32, 32, 3)
  48.         augmented_img, _ = image_datagen.flow(np.expand_dims(img_rgb, axis=0), y_train[10:11]).next()
  49.         # np.squeeze 挤压
  50.         ax.imshow(np.uint8(np.squeeze(augmented_img)))
  51.     plt.setp([a.get_xticklabels() for a in ax_array.ravel()], visible=False)
  52.     plt.setp([a.get_yticklabels() for a in ax_array.ravel()], visible=False)
  53.     plt.suptitle('Random examples of data augmentation (starting from the previous image)')
  54.     plt.show()


  55. def get_image_generator():
  56.     # create the generator to perform online data augmentation
  57.     image_datagen = ImageDataGenerator(rotation_range=15.)
  58.     return image_datagen


  59. def get_model(dropout_rate = 0.0):
  60.     input_shape = (32, 32, 1)
  61.     input = Input(shape=input_shape)
  62.     cv2d_1 = Conv2D(64, (3, 3), padding='same', activation='relu')(input)
  63.     pool_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(cv2d_1)
  64.     dropout_1 = Dropout(dropout_rate)(pool_1)
  65.     flatten_1 = Flatten()(dropout_1)

  66.     dense_1 = Dense(64, activation='relu')(flatten_1)
  67.     output = Dense(43, activation='softmax')(dense_1)
  68.     model = Model(inputs=input, outputs=output, name='model_original')
  69.     # compile model
  70.     model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
  71.     # summarize model
  72.     model.summary()
  73.     return model


  74. def plot_learning_curve(history):
  75.     # ValueError: arrays must all be same length
  76.     # 表格型数据 要求每一列的len一致 这里即:history.history字典里每个key对应的value长度一致
  77.     df_history = pd.DataFrame(data=history.history)
  78.     print(df_history)
  79.     # print(df_history.index)
  80.     print(df_history.columns)
  81.     # print(df_history.dtypes)
  82.     df_history.plot(figsize=(8, 5))
  83.     plt.grid(True)
  84.     # x就是DataFrame的索引
  85.     plt.ylim(0, 1.5)
  86.     plt.show()


  87. def train(model, image_datagen, x_train, y_train, x_validation, y_validation):
  88.     # checkpoint
  89.     filepath = "weights.best.hdf5"
  90.     cb_save_best_model = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=0, save_best_only=True, mode='max')
  91.     cb_early_stop = EarlyStopping(monitor='val_accuracy', min_delta=1e-2, patience=5)

  92.     callbacks_list = [cb_save_best_model, cb_early_stop]
  93.     image_datagen.fit(x_train, seed=666)
  94.     print('训练数据集个数:', len(x_train))
  95.     history = model.fit_generator(image_datagen.flow(x_train, y_train, batch_size=128),
  96.                         steps_per_epoch=500,
  97.                         validation_data=(x_validation, y_validation),
  98.                         epochs=10,
  99.                         callbacks=callbacks_list,
  100.                         verbose=1)

  101.     # list all data in history
  102.     print(history.history.keys())
  103.     fig, ax_array = plt.subplots(1, 2)
  104.     ax1, ax2 = ax_array

  105.     ax1.set_title('model accuracy')
  106.     ax1.plot(history.history['accuracy'])
  107.     ax1.plot(history.history['val_accuracy'])
  108.     ax1.set_xlabel('epoch')
  109.     ax1.set_ylabel('accuracy')
  110.     ax1.legend(['train', 'validation'], loc='upper left')

  111.     ax2.set_title('model loss')
  112.     ax2.plot(history.history['loss'])
  113.     ax2.plot(history.history['val_loss'])
  114.     ax2.set_xlabel('epoch')
  115.     ax2.set_ylabel('loss')
  116.     ax2.legend(['train', 'validation'], loc='upper left')

  117.     plt.show()

  118.     plot_learning_curve(history)

  119.     with open('/trainHistoryDict.p', 'wb') as file_pi:
  120.         pickle.dump(history.history, file_pi)
  121.     return history


  122. def evaluate(model, X_test, y_test):
  123.     score = model.evaluate(X_test, y_test, verbose=1)
  124.     accuracy = score[1]
  125.     return accuracy


  126. def load_traffic_sign_data(training_file):
  127.     with open(training_file, mode='rb') as f:
  128.         train = pickle.load(f)

  129.     X_train, y_train = train['features'], train['labels']

  130.     return X_train, y_train


  131. def train_model():
  132.     X_train, y_train = load_traffic_sign_data('./traffic-signs-data/train.p')

  133.     # Number of examples
  134.     n_train = X_train.shape[0]

  135.     # What's the shape of an traffic sign image?
  136.     image_shape = X_train[0].shape

  137.     # How many classes?
  138.     n_classes = np.unique(y_train).shape[0]

  139.     print("训练数据集的数据个数 =", n_train)
  140.     print("图像尺寸  =", image_shape)
  141.     print("类别数量 =", n_classes)
  142.     print('X_train.shape:', X_train.shape)
  143.     print('y_train.shape:', y_train.shape)

  144.     X_train_norm = preprocess_features(X_train)
  145.     # one-hot
  146.     y_train = utils.to_categorical(y_train, n_classes)

  147.     # split into train and validation
  148.     VAL_RATIO = 0.2
  149.     X_train_norm, X_val_norm, y_train, y_val = train_test_split(X_train_norm, y_train,
  150.                                                                 test_size=VAL_RATIO,
  151.                                                                 random_state=666)

  152.     model = get_model(0.0)
  153.     # 模型保存结构文件
  154.     with gfile.GFile('model_structure.json', 'w') as f:
  155.         f.write(model.to_json())
  156.     image_generator = get_image_generator()
  157.     train(model, image_generator, X_train_norm, y_train, X_val_norm, y_val)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    train_model()
复制代码






01.png
02.png
03.png
让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

96

帖子

202

积分

中级会员

Rank: 3Rank: 3

积分
202
发表于 2019-12-12 18:21:09 | 显示全部楼层
This is a good idea.
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备18018285号-1 )

GMT+8, 2020-6-6 09:28 , Processed in 0.238388 second(s), 22 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表