|
- # -*- coding: utf-8 -*-
- __author__ = u'东方耀 微信:dfy_88888'
- __date__ = '2019/10/28 下午3:37'
- __product__ = 'PyCharm'
- __filename__ = 'tf_keras_classify_model'
- import os
- import sys
- import time
- import tensorflow as tf
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- import sklearn
- from tensorflow import keras
- # tf的keras 和 单独的keras
- # pip3 install keras
- import keras
# Log interpreter and library versions so a run can be reproduced later.
# NOTE(review): at this point `keras` is the standalone package bound by
# `import keras`, which rebinds the earlier `from tensorflow import keras`;
# the version of tf.keras itself is therefore not what gets printed here.
print(tf.__version__)
print(sys.version_info)
for lib in (mpl, np, pd, sklearn, tf, keras):
    print(f'{lib.__name__} {lib.__version__}')
# Fashion-MNIST loader docs:
# https://tensorflow.google.cn/api_docs/python/tf/keras/datasets/fashion_mnist/load_data
fashion_mnist = keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
# Shapes of the full training split, before the validation hold-out.
print(X_train.shape, y_train.shape)
# Test-set shapes (label text: "test set").
print('测试数据集:', X_test.shape, y_test.shape)

# Hold out the first 5000 training samples for validation:
# np.split(arr, [5000]) yields (arr[:5000], arr[5000:]).
X_valid, X_train = np.split(X_train, [5000])
y_valid, y_train = np.split(y_train, [5000])
# Label texts: "validation set" and "training set".
print('验证数据集:', X_valid.shape, y_valid.shape)
print('训练数据集:', X_train.shape, y_train.shape)
# NOTE(review): pixel values are not rescaled anywhere in this script;
# normalizing them (e.g. dividing by 255) usually helps training -- consider it.
def show_single_image(img_arr):
    """Render a single image with matplotlib's 'binary' colormap and show it.

    Args:
        img_arr: any array ``plt.imshow`` accepts; in this script, one
            28x28 Fashion-MNIST sample.
    """
    plt.imshow(X=img_arr, cmap='binary')
    plt.show()


# Example usage: show_single_image(X_train[11])
def show_multi_images(n_rows, n_cols, X_data, y_data, class_names):
    """Draw the first ``n_rows * n_cols`` samples as a titled image grid.

    Args:
        n_rows: number of grid rows.
        n_cols: number of grid columns.
        X_data: images, aligned index-for-index with ``y_data``.
        y_data: integer labels, used to index ``class_names`` for each title.
        class_names: display name for each label value.

    Raises:
        AssertionError: if ``X_data`` and ``y_data`` differ in length, or if
            fewer than ``n_rows * n_cols`` samples are supplied.
    """
    # Message text means: "features and labels must have the same length".
    assert len(X_data) == len(y_data), '样本的特征与标签长度一致'
    assert n_rows*n_cols <= len(X_data)
    # figsize is (width, height) in inches: ~1.4in per column, ~1.6in per row.
    plt.figure(figsize=(n_cols*1.4, n_rows*1.6))
    # Row-major sample indices; subplot positions are the same index + 1.
    cells = (row*n_cols + col for row in range(n_rows) for col in range(n_cols))
    for idx in cells:
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(X_data[idx], cmap='binary', interpolation='nearest')
        plt.axis('off')
        plt.title(label=class_names[y_data[idx]])
    plt.show()
# Display names for the ten label values; list position == integer label.
class_names = [
    't-shirt',
    'trouser',
    'pullover',
    'dress',
    'coat',
    'sandal',
    'shirt',
    'sneaker',
    'bag',
    'ankle boot',
]
# Example usage: show_multi_images(3, 5, X_train[:20], y_train[:20], class_names)
# Build the model with the Sequential API.
# NOTE(review): because `import keras` rebinds the name after
# `from tensorflow import keras`, `keras` here is the standalone Keras package.
# Equivalent layer-by-layer construction:
#   model = keras.Sequential(name='tf2.0_model')
#   model.add(keras.layers.Flatten(input_shape=(28, 28)))
#   model.add(keras.layers.Dense(units=300, activation='relu'))
#   model.add(keras.layers.Dense(units=100, activation='relu'))
#   model.add(keras.layers.Dense(10, activation='softmax'))
model = keras.Sequential(layers=[
    keras.layers.Flatten(input_shape=(28, 28)),  # 28x28 image -> 784-vector
    keras.layers.Dense(units=300, activation='relu'),
    keras.layers.Dense(units=100, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),  # one probability per class
], name='tf2.0_model_2')
# Labels are integer class ids, hence the *sparse* categorical cross-entropy.
# 'mse' is intentionally not tracked. It compares each integer class id with
# a 10-way probability vector, so it does not measure classification
# quality. In the sample output below it sits near 27.7 and stays flat while
# accuracy rises.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
# Optional GPU check: print(tf.test.is_gpu_available())

# Train. validation_freq is measured in epochs: validate after every Nth epoch.
history = model.fit(x=X_train, y=y_train, epochs=10, verbose=1,
                    validation_data=(X_valid, y_valid), validation_freq=1)
print(type(history))
# Per-epoch metric values recorded during training (metric name -> list).
print(history.history)
# Sample output from an earlier 4-epoch run with validation_freq=2 (and the
# former 'mse' metric). The val_* lists hold one entry per validation run,
# so they are shorter than the training lists:
# {'loss': [2.276887303126942, 0.5237461835861206, 0.45976711897850037, 0.4327715507962487],
#  'accuracy': [0.75996363, 0.8186182, 0.8376727, 0.8458727],
#  'mse': [27.68526, 27.679394, 27.680891, 27.681673],
#  'val_loss': [0.5112990405797958, 0.45343393814563754],
#  'val_accuracy': [0.823199987411499, 0.843999981880188],
#  'val_mse': [27.655563354492188, 27.657318115234375]}
def _validation_epochs(n_epochs, validation_freq):
    """Return the 0-based epoch indices on which validation ran."""
    if isinstance(validation_freq, int):
        # Validation after epochs k, 2k, ... (1-based) -> k-1, 2k-1, ... (0-based).
        return list(range(validation_freq - 1, n_epochs, validation_freq))
    # A container lists the 1-based epochs explicitly; keep those that ran.
    return [epoch - 1 for epoch in validation_freq if 1 <= epoch <= n_epochs]


def plot_learning_curve(history, ylim=(0, 1.3), validation_freq=1):
    """Plot every metric recorded in a Keras ``History`` against the epoch.

    ``pd.DataFrame(history.history)`` raises
    ``ValueError: arrays must all be same length`` when ``fit`` ran with a
    ``validation_freq`` other than 1. The ``val_*`` lists then hold one entry
    per validation run, not one per epoch. To avoid this, each metric is
    placed on the epoch axis explicitly, leaving NaN where it was not
    computed. When all series have the same length, the resulting frame is
    identical to ``pd.DataFrame(history.history)``.

    Args:
        history: the object returned by ``model.fit``; only its ``history``
            dict (metric name -> list of per-epoch values) is read.
        ylim: ``(bottom, top)`` passed to ``plt.ylim``; ``None`` keeps
            matplotlib's automatic limits.
        validation_freq: the value given to ``model.fit``. An int ``k`` means
            validation ran after every k-th epoch. A list/tuple gives the
            1-based epochs on which it ran. It is only used to position
            series shorter than the longest one.

    Raises:
        ValueError: if a shorter series cannot be matched to the epochs
            implied by ``validation_freq``.
    """
    metrics = history.history
    n_epochs = max((len(values) for values in metrics.values()), default=0)

    columns = {}
    for name, values in metrics.items():
        if len(values) == n_epochs:
            positions = list(range(n_epochs))
        else:
            positions = _validation_epochs(n_epochs, validation_freq)
        if len(positions) != len(values):
            raise ValueError(
                'cannot align {!r} ({} values) with validation_freq={!r} '
                'over {} epochs'.format(name, len(values), validation_freq,
                                        n_epochs))
        columns[name] = pd.Series(list(values), index=positions)
    df_history = pd.DataFrame(columns, index=pd.RangeIndex(n_epochs))
    print(df_history)
    print(df_history.index)
    print(df_history.columns)
    print(df_history.dtypes)

    # matplotlib breaks lines at NaN, which would hide sparsely validated
    # metrics. Interior gaps are interpolated so those points are still
    # connected. Full-length series have no NaN and are unchanged.
    df_history.interpolate(limit_area='inside').plot(figsize=(8, 5))
    plt.grid(True)
    # The x axis is the DataFrame index, i.e. the 0-based epoch number.
    if ylim is not None:
        plt.ylim(*ylim)
    plt.show()


plot_learning_curve(history)
复制代码
东方老师AI官网:http://www.ai111.vip
有任何问题可联系东方老师微信:dfy_88888
【微信二维码图片】
|
-
01.png
(226.43 KB, 下载次数: 81)
-
02.png
(313.73 KB, 下载次数: 85)
-
03.png
(199.38 KB, 下载次数: 82)
-
04.png
(223.92 KB, 下载次数: 88)
-
05.png
(255.58 KB, 下载次数: 84)
-
06.png
(259.61 KB, 下载次数: 84)
-
07.png
(119.84 KB, 下载次数: 89)
-
08.png
(92.93 KB, 下载次数: 86)
-
屏幕截图.png
(33.01 KB, 下载次数: 85)
-
01.png
(218.32 KB, 下载次数: 89)
-
02.png
(112.03 KB, 下载次数: 86)
-
03.png
(257.55 KB, 下载次数: 86)
|