东方耀AI技术分享

标题: 05、除Sequential外模型构建方法:函数式API与子类API [打印本页]

作者: 东方耀    时间: 2019-10-30 11:06
标题: 05、除Sequential外模型构建方法:函数式API与子类API
05、除Sequential外模型构建方法:函数式API与子类API


Wide与Deep模型实战


子类API
函数式API(Functional API)
多输入与多输出


东方老师AI官网:http://www.ai111.vip
有任何问题可联系东方老师微信:dfy_88888
【微信二维码图片】


  1. # -*- coding: utf-8 -*-
  2. __author__ = u'东方耀 微信:dfy_88888'
  3. __date__ = '2019/10/30 10:10'
  4. __product__ = 'PyCharm'
  5. __filename__ = 'tf-keras-regression-wide&deep'


  6. import numpy as np
  7. import matplotlib as mpl
  8. import matplotlib.pyplot as plt
  9. import sklearn
  10. import pandas as pd
  11. import os
  12. import sys
  13. import time
  14. import tensorflow as tf
  15. from tensorflow import keras

  16. print(sys.version_info)
  17. for module in mpl, np, pd, sklearn, tf, keras:
  18.     print(module.__name__, module.__version__)

  19. from sklearn.datasets import fetch_california_housing

  20. housing = fetch_california_housing()
  21. # print(housing.DESCR)
  22. print(housing.data.shape)
  23. print(housing.target.shape)

  24. from sklearn.model_selection import train_test_split

  25. x_train_all, x_test, y_train_all, y_test = train_test_split(
  26.     housing.data, housing.target, random_state=7)
  27. x_train, x_valid, y_train, y_valid = train_test_split(
  28.     x_train_all, y_train_all, random_state=11)
  29. print(x_train.shape, y_train.shape)
  30. print(x_valid.shape, y_valid.shape)
  31. print(x_test.shape, y_test.shape)

  32. from sklearn.preprocessing import StandardScaler

  33. scaler = StandardScaler()
  34. x_train_scaled = scaler.fit_transform(x_train)
  35. x_valid_scaled = scaler.transform(x_valid)
  36. x_test_scaled = scaler.transform(x_test)


  37. def create_model_function():
  38.     # 第一种方法:函数式api实现wide&deep模型 像使用函数一样
  39.     input_1 = keras.layers.Input(shape=x_train.shape[1])
  40.     # 复合函数的形式
  41.     hidden_1 = keras.layers.Dense(30, activation='relu')(input_1)
  42.     hidden_2 = keras.layers.Dense(30, activation='relu')(hidden_1)
  43.     # wide model
  44.     concat = keras.layers.concatenate(inputs=[input_1, hidden_2])
  45.     output = keras.layers.Dense(1)(concat)

  46.     model = keras.models.Model(inputs=[input_1], outputs=[output], name='wide_deep_model')
  47.     return model


  48. class WideDeepModel(keras.models.Model):
  49.     # 第二种方法:子类api实现wide & deep模型
  50.     def __init__(self):
  51.         super(WideDeepModel, self).__init__()
  52.         # 定义模型的层
  53.         self.hidden1_layer = keras.layers.Dense(30, activation='relu')
  54.         self.hidden2_layer = keras.layers.Dense(30, activation='relu')
  55.         self.output_layer = keras.layers.Dense(1)

  56.     def call(self, input):
  57.         # 完成模型的正向计算
  58.         hidden1 = self.hidden1_layer(input)
  59.         hidden2 = self.hidden2_layer(hidden1)
  60.         concat = keras.layers.concatenate([input, hidden2])
  61.         output = self.output_layer(concat)
  62.         return output


  63. # 第一种方法
  64. # model = create_model_function()
  65. # 第二种方法(类的实例化对象 build方式指定input_shape)
  66. model = WideDeepModel()
  67. model.build(input_shape=(None, x_train.shape[1]))

  68. model.summary()

  69. model.compile(optimizer=keras.optimizers.Adam(0.001), loss='mean_squared_error')
  70. callbacks = [
  71.     keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)
  72. ]
  73. history = model.fit(x_train_scaled, y_train, validation_data=(x_valid_scaled, y_valid), epochs=50,
  74.                     callbacks=callbacks)
  75. print(type(history))
  76. print(history.history)


  77. def plot_learning_curve(history):
  78.     # ValueError: arrays must all be same length
  79.     # 表格型数据 要求每一列的len一致 这里即:history.history字典里每个key对应的value长度一致
  80.     df_history = pd.DataFrame(data=history.history)
  81.     print(df_history)
  82.     print(df_history.index)
  83.     print(df_history.columns)
  84.     print(df_history.dtypes)
  85.     df_history.plot(figsize=(8, 5))
  86.     plt.grid(True)
  87.     # x就是DataFrame的索引
  88.     plt.ylim(0, 1)
  89.     plt.show()


  90. plot_learning_curve(history)
  91. # Returns the loss value & metrics values for the model in test mode.
  92. print(model.evaluate(x_test_scaled, y_test))
复制代码








欢迎光临 东方耀AI技术分享 (http://www.ai111.vip/) Powered by Discuz! X3.4