|
02、训练数据与基于Keras的模型原始架构源码
训练数据下载地址:链接:https://pan.baidu.com/s/1eiWnp3BVZnT8q7T24vpBkA
提取码:nqm0
对数据集的分析,分布等
原始框架源码(可以跑起来,但准确率非常低,仅 5% 左右,且模型不收敛,需要继续优化调整):
- # -*- coding: utf-8 -*-
- __author__ = u'东方耀 微信:dfy_88888'
- __date__ = '2019/11/4 20:06'
- __product__ = 'PyCharm'
- __filename__ = 'train_original'
- from keras.layers import Flatten
- from keras.layers import Conv2D
- from keras.layers import MaxPooling2D
- from keras.models import Model
- from keras.layers import Input
- from keras.layers import Dense
- from keras.layers import Dropout
- import cv2
- from sklearn.model_selection import train_test_split
- from keras.preprocessing.image import ImageDataGenerator
- from keras.callbacks import ModelCheckpoint, EarlyStopping
- from keras import utils
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- import pickle
- from tensorflow import gfile
- """
- 要提高识别率,至少有两个大方向:
- 1、数据本身 ()
- 2、神经网络结构
- """
- np.random.seed(666)
def preprocess_features(X):
    """Convert RGB images to the Y (luma) channel of YUV and standardize them.

    Args:
        X: array-like of RGB images, each of shape (H, W, 3), values 0-255.

    Returns:
        np.ndarray of shape (n, H, W, 1), dtype float32, standardized to
        zero mean / unit variance per pixel position.

    The original version returned raw 0-255 uint8 luma values; feeding
    un-normalized pixels into the relu network is a major contributor to
    the non-convergence noted at the top of this file, so the images are
    now cast to float32 and standardized before being returned.
    """
    # Keep only the luminance channel: traffic-sign shape matters more
    # than color, and it shrinks the input to a single channel.
    X = np.array([np.expand_dims(cv2.cvtColor(rgb_img, cv2.COLOR_RGB2YUV)[:, :, 0], 2) for rgb_img in X])
    X = np.float32(X)
    # Per-pixel standardization across the whole set; eps guards against
    # division by zero for constant pixels.
    mean_img = np.mean(X, axis=0)
    std_img = np.std(X, axis=0) + np.finfo('float32').eps
    X -= mean_img
    X /= std_img
    return X
def show_samples_from_generator(image_datagen, X_train, y_train):
    """Show one training image, then a 4x12 grid of augmented variants of it."""
    sample_img = X_train[10]
    # First, the untouched sample for reference.
    plt.figure(figsize=(4, 4))
    plt.imshow(sample_img)
    plt.title('Example of RGB image (class = {})'.format(y_train[10]))
    plt.show()
    # Then a grid of randomly augmented copies produced by the generator.
    n_rows, n_cols = 4, 12
    fig, axes = plt.subplots(n_rows, n_cols)
    # flow() wants a batch axis: (1, H, W, C); the matching label slice
    # y_train[10:11] keeps the array form (y_train[10] would be a scalar).
    batch_x = np.expand_dims(sample_img, axis=0)
    batch_y = y_train[10:11]
    for axis in axes.ravel():
        aug_img, _ = image_datagen.flow(batch_x, batch_y).next()
        # squeeze drops the batch axis again before display
        axis.imshow(np.uint8(np.squeeze(aug_img)))
    plt.setp([a.get_xticklabels() for a in axes.ravel()], visible=False)
    plt.setp([a.get_yticklabels() for a in axes.ravel()], visible=False)
    plt.suptitle('Random examples of data augmentation (starting from the previous image)')
    plt.show()
def get_image_generator():
    """Return the generator used for online data augmentation.

    Only random rotation (up to 15 degrees) is enabled here.
    """
    return ImageDataGenerator(rotation_range=15.)
def get_model(dropout_rate = 0.0):
    """Build and compile the original (deliberately shallow) baseline classifier.

    Args:
        dropout_rate: dropout probability applied after the pooling layer.

    Returns:
        A compiled keras Model expecting (32, 32, 1) inputs and producing
        a 43-way softmax (one unit per traffic-sign class).
    """
    inputs = Input(shape=(32, 32, 1))
    x = Conv2D(64, (3, 3), padding='same', activation='relu')(inputs)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = Dropout(dropout_rate)(x)
    x = Flatten()(x)
    x = Dense(64, activation='relu')(x)
    outputs = Dense(43, activation='softmax')(x)
    model = Model(inputs=inputs, outputs=outputs, name='model_original')
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Print the layer-by-layer summary so runs log the architecture used.
    model.summary()
    return model
def plot_learning_curve(history):
    """Plot every metric recorded in history.history on a single chart.

    Note: pandas requires every list in history.history to have the same
    length, otherwise DataFrame construction raises
    "ValueError: arrays must all be same length".
    """
    metrics = pd.DataFrame(history.history)
    print(metrics)
    print(metrics.columns)
    # The x axis is the DataFrame index, i.e. the epoch number.
    metrics.plot(figsize=(8, 5))
    plt.grid(True)
    plt.ylim(0, 1.5)
    plt.show()
def train(model, image_datagen, x_train, y_train, x_validation, y_validation):
    """Fit the model on augmented batches and visualize/persist the history.

    Args:
        model: compiled keras Model.
        image_datagen: fitted-on-demand ImageDataGenerator for augmentation.
        x_train, y_train: training images and one-hot labels.
        x_validation, y_validation: held-out validation split.

    Returns:
        The keras History object from training.

    Side effects: saves best weights to weights.best.hdf5, shows two
    matplotlib figures, and pickles history.history to ./trainHistoryDict.p.
    """
    # Checkpoint only the best validation accuracy; stop early once
    # val_accuracy plateaus (improvement < 1e-2 for 5 epochs).
    filepath = "weights.best.hdf5"
    cb_save_best_model = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=0, save_best_only=True, mode='max')
    cb_early_stop = EarlyStopping(monitor='val_accuracy', min_delta=1e-2, patience=5)
    callbacks_list = [cb_save_best_model, cb_early_stop]
    image_datagen.fit(x_train, seed=666)
    print('训练数据集个数:', len(x_train))
    batch_size = 128
    # BUG FIX: the original hard-coded steps_per_epoch=500 (i.e. 64000
    # samples per "epoch" regardless of dataset size); derive it from the
    # actual training-set size so one epoch is one pass over the data.
    steps_per_epoch = max(1, len(x_train) // batch_size)
    history = model.fit_generator(image_datagen.flow(x_train, y_train, batch_size=batch_size),
                                  steps_per_epoch=steps_per_epoch,
                                  validation_data=(x_validation, y_validation),
                                  epochs=10,
                                  callbacks=callbacks_list,
                                  verbose=1)
    print(history.history.keys())
    # Side-by-side accuracy and loss curves for train vs validation.
    fig, ax_array = plt.subplots(1, 2)
    ax1, ax2 = ax_array
    ax1.set_title('model accuracy')
    ax1.plot(history.history['accuracy'])
    ax1.plot(history.history['val_accuracy'])
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('accuracy')
    ax1.legend(['train', 'validation'], loc='upper left')
    ax2.set_title('model loss')
    ax2.plot(history.history['loss'])
    ax2.plot(history.history['val_loss'])
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('loss')
    ax2.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plot_learning_curve(history)
    # BUG FIX: the original wrote to '/trainHistoryDict.p' (the filesystem
    # root), which raises PermissionError on most systems; write to the
    # working directory instead.
    with open('./trainHistoryDict.p', 'wb') as file_pi:
        pickle.dump(history.history, file_pi)
    return history
def evaluate(model, X_test, y_test):
    """Return the test-set accuracy reported by model.evaluate.

    model.evaluate returns [loss, accuracy] (per the metrics configured in
    get_model); only the accuracy entry is of interest here.
    """
    loss_and_metrics = model.evaluate(X_test, y_test, verbose=1)
    return loss_and_metrics[1]
def load_traffic_sign_data(training_file):
    """Unpickle a traffic-sign dataset file and return (features, labels).

    Args:
        training_file: path to a pickle file containing a dict with at
            least the keys 'features' and 'labels'.

    Returns:
        Tuple (X, y) taken straight from those two keys.
    """
    with open(training_file, mode='rb') as f:
        data = pickle.load(f)
    return data['features'], data['labels']
def train_model():
    """End-to-end training entry point: load, inspect, preprocess, split, fit."""
    X_train, y_train = load_traffic_sign_data('./traffic-signs-data/train.p')
    # Basic dataset statistics, logged before any preprocessing.
    n_train = X_train.shape[0]
    image_shape = X_train[0].shape
    n_classes = np.unique(y_train).shape[0]
    print("训练数据集的数据个数 =", n_train)
    print("图像尺寸 =", image_shape)
    print("类别数量 =", n_classes)
    print('X_train.shape:', X_train.shape)
    print('y_train.shape:', y_train.shape)
    X_train_norm = preprocess_features(X_train)
    # One-hot encode the integer class labels.
    y_train = utils.to_categorical(y_train, n_classes)
    # Hold out 20% of the data for validation; fixed seed for repeatability.
    VAL_RATIO = 0.2
    X_train_norm, X_val_norm, y_train, y_val = train_test_split(
        X_train_norm, y_train, test_size=VAL_RATIO, random_state=666)
    model = get_model(0.0)
    # Persist the architecture as JSON alongside the weight checkpoints.
    with gfile.GFile('model_structure.json', 'w') as f:
        f.write(model.to_json())
    image_generator = get_image_generator()
    train(model, image_generator, X_train_norm, y_train, X_val_norm, y_val)


if __name__ == "__main__":
    train_model()
复制代码
|
|