[Study Notes] 03. Implementing an Artificial Neural Network in Keras to Solve the Iris Classification Problem

Posted on 2019-10-21 16:12:46


Accuracy reached 98%.


What this lesson covers:

- Introduction to the classic dataset for neural-network classification (the iris dataset) and to Keras, a Python neural-network library
- Reading the iris dataset with Pandas; encoding the class labels with LabelEncoder
- Building a neural network for iris classification with Keras
- Training the iris classification network, reading the training logs, and understanding how to evaluate the network's performance
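The label-preparation steps above can be sketched on their own, before the full pipeline (the species names below are the standard iris classes; the one-hot step mimics what keras `np_utils.to_categorical` does):

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# class labels as they appear in the CSV
labels = np.array(['setosa', 'versicolor', 'virginica', 'setosa'])

encoder = LabelEncoder()
y = encoder.fit_transform(labels)   # strings -> integers 0, 1, 2
print(y)                            # [0 1 2 0]

# one-hot encode with plain numpy (equivalent to to_categorical(y, num_classes=3))
y_onehot = np.eye(3)[y]
print(y_onehot)
```

Once encoded this way, each row of `y_onehot` lines up with the 3-unit softmax output of the network.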




Two errors you may hit along the way:

ValueError: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].
TypeError: 'int' object is not iterable
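The ValueError above comes from calling `precision_score` or `recall_score` on a 3-class target with the default `average='binary'`. A minimal reproduction and fix (the label arrays below are made up for illustration):

```python
import numpy as np
from sklearn.metrics import precision_score

y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 2, 2, 2, 1, 0])

# precision_score(y_true, y_pred)  # raises the ValueError: default average='binary'
#                                  # only makes sense for 2-class targets

# fix: pick an explicit multiclass averaging scheme, e.g. 'micro'
p_micro = precision_score(y_true, y_pred, average='micro')
print(p_micro)  # 5 of 6 predictions correct -> 0.8333...
```

With single-label multiclass data, `average='micro'` pools all predictions, so micro precision equals plain accuracy; `'macro'` instead averages the per-class scores.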


The loss function does not use the test/validation data to measure the network's performance.
The loss function guides the training process: it drives the network's parameters to change in the direction that lowers the loss.
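That point can be illustrated with a one-parameter toy loss rather than the full network (the loss and learning rate below are invented for the illustration): each gradient step moves the parameter in the direction that lowers the loss.

```python
import numpy as np

# toy loss L(w) = (w - 3)^2, with gradient dL/dw = 2 * (w - 3)
w = 0.0
lr = 0.1
for _ in range(50):
    grad = 2 * (w - 3)
    w -= lr * grad      # step opposite the gradient: the loss-lowering direction

print(w)  # converges toward 3, the minimiser of the toy loss
```

This is exactly what `SGD` does for every weight in the Keras model, with the gradient supplied by backpropagation.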


Loss functions for classification: hinge loss (used in SVMs) and cross-entropy loss.


One-hot encoding (also called dummy encoding)
Non-probabilistic interpretation: hinge loss
Probabilistic interpretation: convert the outputs into probabilities with softmax
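The probabilistic interpretation works as follows: softmax turns the raw network outputs (logits) into a probability distribution over the 3 classes. A minimal sketch (the logit values are illustrative):

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])   # raw outputs of the last layer
probs = softmax(logits)
print(probs)         # largest logit gets the largest probability
print(probs.sum())   # probabilities sum to 1
```

This is what `activation='softmax'` on the final `Dense(units=3)` layer computes for every sample.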


Example: manually computing accuracy, cross-entropy loss, and MSE (mean squared error) for a 3-class problem.
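A worked version of that 3-class example (the predicted probabilities below are made up for illustration):

```python
import numpy as np

# 3 samples, 3 classes; each row is a predicted probability distribution
y_prob = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4]])
y_true = np.array([0, 1, 2])        # true class indices
y_onehot = np.eye(3)[y_true]

# accuracy: fraction of rows whose argmax matches the true class
accuracy = (y_prob.argmax(axis=1) == y_true).mean()

# categorical cross-entropy: -mean(log probability assigned to the true class)
cross_entropy = -np.mean(np.log(y_prob[np.arange(3), y_true]))

# mean squared error against the one-hot targets
mse = np.mean((y_prob - y_onehot) ** 2)

print(accuracy)        # 1.0: every argmax is correct
print(cross_entropy)   # -(ln 0.7 + ln 0.8 + ln 0.4) / 3
print(mse)
```

Note that accuracy is already perfect here while cross-entropy is still positive: the loss keeps pushing the true-class probabilities toward 1 even after every prediction is correct.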


Teacher Dongfang's AI site: http://www.ai111.vip
For any questions, contact Teacher Dongfang on WeChat: dfy_88888
[WeChat QR code image]
# -*- coding: utf-8 -*-
__author__ = u'东方耀 微信:dfy_88888'
__date__ = '2019/10/21 9:38'
__product__ = 'PyCharm'
__filename__ = '人工神经网络解决鸢尾花分类'

import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam, SGD
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import cross_val_score
# k-fold cross-validation
from sklearn.model_selection import KFold
# turn the class strings into the integers 0, 1, 2
from sklearn.preprocessing import LabelEncoder
# save the model as JSON, then reload the JSON file for prediction
from keras.models import model_from_json
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import mean_squared_error, mean_absolute_error

SEED = 666
np.random.seed(SEED)
# load data
# header=0 (the default) means the first row is the header; header=None means the file has no header row
df = pd.read_csv(filepath_or_buffer='datasets/iris.csv', header=0)
# iris = datasets.load_iris()
# print(iris.keys())
X = df.values[:, 1:5].astype(float)
y = df.values[:, 5]

print(X.shape, y.ndim)
print(X[:3])
print(y[:10])
encoder = LabelEncoder()
y = encoder.fit_transform(y)
print(y, y.ndim)
Y_onehot = np_utils.to_categorical(y, num_classes=3)
print(Y_onehot.shape, Y_onehot.ndim)
print(Y_onehot[:5, :])

# baseline comparison with logistic regression:
# X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)
# log_reg = LogisticRegressionCV(cv=5, n_jobs=-1, verbose=0, multi_class='ovr')
# log_reg = LogisticRegression()
# log_reg.fit(X_train, y_train)
# y_test_predict = log_reg.predict(X_test)
# print(log_reg.score(X_test, y_test))
# print(accuracy_score(y_test, y_test_predict))  # 0.9473684210526315
# print(confusion_matrix(y_test, y_test_predict))
# print(precision_score(y_test, y_test_predict, average='micro'))
# print(recall_score(y_test, y_test_predict, average='micro'))


# define the network architecture
def baseline_model():
    model = Sequential(name='baseline_model_dfy')
    # input_dim=4 and input_shape=(4,) are equivalent here
    model.add(Dense(units=7, activation='tanh', input_dim=4))
    model.add(Dense(units=3, activation='softmax'))
    # classification losses: cross-entropy 'categorical_crossentropy' (accuracy 92%) or hinge loss
    # the regression loss 'mean_squared_error' only reached 71.3% accuracy here
    # raising epochs from 20 to 200 lifted accuracy from 92% to 98%
    # widening the hidden layer from units=7 to 128 left accuracy at 98%, so it was reverted to 7
    model.compile(optimizer=SGD(learning_rate=0.1), loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# from keras.wrappers.scikit_learn import KerasClassifier
estimator = KerasClassifier(build_fn=baseline_model, epochs=20, batch_size=10, verbose=1)

# evaluate
kfold = KFold(n_splits=10, shuffle=True, random_state=SEED)
result = cross_val_score(estimator, X, Y_onehot, cv=kfold)
print(baseline_model().summary())
print('Accuracy of cv:', result)
print('Accuracy of CV mean:%.2f, std:%.3f' % (result.mean(), result.std()))

# train model
estimator.fit(X, Y_onehot)
print('model accuracy score:', estimator.score(X, Y_onehot))


# save model
model_json = estimator.model.to_json()
with open('model.json', 'w') as json_file:
    json_file.write(model_json)
estimator.model.save_weights('model_weights.h5')
print('Saved model to disk! Done')

# load model and use it for prediction
with open('model.json', 'r') as file:
    loaded_model_json = file.read()

loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights('model_weights.h5')
print('Loaded model from disk! Done')

predicted = loaded_model.predict(X)
print('predicted probability:', predicted)
predicted_label = loaded_model.predict_classes(X)
print('predicted label:\n' + str(predicted_label))
print('actual label:\n' + str(y))
print('prediction accuracy:', accuracy_score(y, predicted_label))
print(confusion_matrix(y, predicted_label))
print('prediction precision:', precision_score(y, predicted_label, average='micro'))
print('prediction recall:', recall_score(y, predicted_label, average='micro'))
print('prediction F1 score:', f1_score(y, predicted_label, average='micro'))
# print('roc_auc(area under curve):', roc_auc_score(y, y_score=?))


[15 attached screenshots: tanh activation plot and softmax illustrations (01tanh.png – 15.png)]

Attachment: iris.csv (4.86 KB)

Let everyone in the world learn artificial intelligence! The future of AI is bright!