东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 6280|回复: 18
打印 上一主题 下一主题

[课堂笔记] 14、解决OverFitting的方案:L1正则、L2正则、弹性网络算法_...

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14439
QQ
跳转到指定楼层
楼主
发表于 2018-4-8 16:57:01 | 只看该作者 |只看大图 回帖奖励 |正序浏览 |阅读模式


14、解决OverFitting的方案:L1正则、L2正则、弹性网络算法_笔记

  1. # -*- coding: utf-8 -*-
  2. __author__ = 'dongfangyao'
  3. __date__ = '2018/4/8 下午3:47'
  4. __product__ = 'PyCharm'
  5. __filename__ = 'overfitting1'

  6. # 为了解决过拟合问题:我们可以选择在损失函数中加入正则项(惩罚项),对于系数过大的惩罚
  7. # 对于系数过多也有一定的惩罚能力 主要分为L1-norm 与 L2-norm
  8. # LASSO 可以产生稀疏解  主要用于特征选择(去掉冗余与无用的特征属性) 而且速度比较快

  9. import numpy as np
  10. import matplotlib as mpl
  11. import matplotlib.pyplot as plt
  12. import pandas as pd
  13. import warnings
  14. import sklearn

  15. from sklearn.linear_model import LinearRegression, LassoCV, RidgeCV, ElasticNetCV

  16. from sklearn.preprocessing import PolynomialFeatures
  17. from sklearn.pipeline import Pipeline
  18. from sklearn.linear_model.coordinate_descent import ConvergenceWarning

  19. ## 设置字符集,防止中文乱码
  20. mpl.rcParams['font.sans-serif']=[u'simHei']
  21. mpl.rcParams['axes.unicode_minus']=False
  22. ## 拦截异常
  23. warnings.filterwarnings(action = 'ignore', category=ConvergenceWarning)

  24. ## 创建模拟数据
  25. np.random.seed(100)
  26. #显示方式设置,每行的字符数用于插入换行符,是否使用科学计数法
  27. np.set_printoptions(linewidth=1000, suppress=True)

  28. N = 10
  29. x = np.linspace(0, 6, N) + np.random.randn(N)
  30. y = 1.8*x**3 + x**2 - 14*x - 7 + np.random.randn(N)
  31. ## 将其设置为矩阵
  32. x.shape = -1, 1
  33. y.shape = -1, 1

  34. ## RidgeCV和Ridge的区别是:前者可以进行交叉验证
  35. models = [
  36.     Pipeline([
  37.             ('Poly', PolynomialFeatures(include_bias=False)),
  38.             ('Linear', LinearRegression(fit_intercept=False))
  39.         ]),
  40.     Pipeline([
  41.             ('Poly', PolynomialFeatures(include_bias=False)),
  42.             # alpha给定的是Ridge算法中,L2正则项的权重值,也就是ppt中的兰姆达
  43.             # alphas是给定CV交叉验证过程中,Ridge算法的alpha参数值的取值的范围
  44.             ('Linear', RidgeCV(alphas=np.logspace(-3,2,50), fit_intercept=False))
  45.         ]),
  46.     Pipeline([
  47.             ('Poly', PolynomialFeatures(include_bias=False)),
  48.             ('Linear', LassoCV(alphas=np.logspace(0,1,10), fit_intercept=False))
  49.         ]),
  50.     Pipeline([
  51.             ('Poly', PolynomialFeatures(include_bias=False)),
  52.             # l1_ratio:给定EN算法中L1正则项在整个惩罚项中的比例,这里给定的是一个列表;
  53.             # l1_ratio:也就是ppt中的p  p的范围是[0, 1]
  54.             # alphas也就是ppt中的兰姆达
  55.             # alphas表示的是在CV交叉验证的过程中,EN算法L1正则项的权重比例的可选值的范围
  56.             ('Linear', ElasticNetCV(alphas=np.logspace(0,1,10), l1_ratio=[.1, .5, .7, .9, .95, 1], fit_intercept=False))
  57.         ])
  58. ]

  59. ## 线性模型过拟合图形识别
  60. plt.figure(facecolor='w')
  61. degree = np.arange(1, N, 4)  # 阶
  62. dm = degree.size
  63. colors = []  # 颜色
  64. for c in np.linspace(16711680, 255, dm):
  65.     colors.append('#%06x' % int(c))

  66. model = models[0]
  67. for i, d in enumerate(degree):
  68.     plt.subplot(int(np.ceil(dm / 2.0)), 2, i + 1)
  69.     plt.plot(x, y, 'ro', ms=10, zorder=N)

  70.     # 设置阶数
  71.     model.set_params(Poly__degree=d)
  72.     # 模型训练
  73.     model.fit(x, y.ravel())

  74.     lin = model.get_params('Linear')['Linear']
  75.     output = u'%d阶,系数为:' % (d)
  76.     # 判断lin对象中是否有对应的属性
  77.     if hasattr(lin, 'alpha_'):
  78.         idx = output.find(u'系数')
  79.         output = output[:idx] + (u'alpha=%.6f, ' % lin.alpha_) + output[idx:]
  80.     if hasattr(lin, 'l1_ratio_'):
  81.         idx = output.find(u'系数')
  82.         output = output[:idx] + (u'l1_ratio=%.6f, ' % lin.l1_ratio_) + output[idx:]
  83.     print(output, lin.coef_.ravel())

  84.     x_hat = np.linspace(x.min(), x.max(), num=100)  ## 产生模拟数据
  85.     x_hat.shape = -1, 1
  86.     y_hat = model.predict(x_hat)
  87.     s = model.score(x, y)

  88.     z = N - 1 if (d == 2) else 0
  89.     label = u'%d阶, 正确率=%.3f' % (d, s)
  90.     plt.plot(x_hat, y_hat, color=colors[i], lw=2, alpha=0.75, label=label, zorder=z)

  91.     plt.legend(loc='upper left')
  92.     plt.grid(True)
  93.     plt.xlabel('X', fontsize=16)
  94.     plt.ylabel('Y', fontsize=16)

  95. plt.tight_layout(1, rect=(0, 0, 1, 0.95))
  96. plt.suptitle(u'线性回归过拟合显示', fontsize=22)
  97. plt.show()

  98. ## 线性回归、Lasso回归、Ridge回归、ElasticNet比较
  99. plt.figure(facecolor='w')
  100. degree = np.arange(1, N, 2)  # 阶, 多项式扩展允许给定的阶数
  101. dm = degree.size
  102. colors = []  # 颜色
  103. for c in np.linspace(16711680, 255, dm):
  104.     colors.append('#%06x' % int(c))
  105. titles = [u'线性回归', u'Ridge回归', u'Lasso回归', u'ElasticNet']

  106. for t in range(4):
  107.     model = models[t]  # 选择了模型--具体的pipeline(线性、Lasso、Ridge、EN)
  108.     plt.subplot(2, 2, t + 1)  # 选择具体的子图
  109.     plt.plot(x, y, 'ro', ms=10, zorder=N)  # 在子图中画原始数据点; zorder:图像显示在第几层

  110.     # 遍历不同的多项式的阶,看不同阶的情况下,模型的效果
  111.     for i, d in enumerate(degree):
  112.         # 设置阶数(多项式)
  113.         model.set_params(Poly__degree=d)
  114.         # 模型训练
  115.         model.fit(x, y.ravel())

  116.         # 获取得到具体的算法模型
  117.         # model.get_params()方法返回的其实是一个dict对象,后面的Linear其实是dict对应的key
  118.         # 也是我们在定义Pipeline的时候给定的一个名称值
  119.         lin = model.get_params()['Linear']
  120.         # 打印数据
  121.         output = u'%s:%d阶,系数为:' % (titles[t], d)
  122.         # 判断lin对象中是否有对应的属性
  123.         if hasattr(lin, 'alpha_'):  # 判断lin这个模型中是否有alpha_这个属性
  124.             idx = output.find(u'系数')
  125.             output = output[:idx] + (u'alpha=%.6f, ' % lin.alpha_) + output[idx:]
  126.         if hasattr(lin, 'l1_ratio_'):  # 判断lin这个模型中是否有l1_ratio_这个属性
  127.             idx = output.find(u'系数')
  128.             output = output[:idx] + (u'l1_ratio=%.6f, ' % lin.l1_ratio_) + output[idx:]
  129.         # line.coef_:获取线性模型的参数列表,也就是我们ppt中的theta值,ravel()将结果转换为1维数据
  130.         print(output, lin.coef_.ravel())

  131.         # 产生模拟数据
  132.         x_hat = np.linspace(x.min(), x.max(), num=100)  ## 产生模拟数据
  133.         x_hat.shape = -1, 1
  134.         # 数据预测
  135.         y_hat = model.predict(x_hat)
  136.         # 计算准确率
  137.         s = model.score(x, y)

  138.         # 当d等于5的时候,设置为N-1层,其它设置0层;将d=5的这条线凸显出来
  139.         z = N + 1 if (d == 5) else 0
  140.         label = u'%d阶, 正确率=%.3f' % (d, s)
  141.         plt.plot(x_hat, y_hat, color=colors[i], lw=2, alpha=0.75, label=label, zorder=z)

  142.     plt.legend(loc='upper left')
  143.     plt.grid(True)
  144.     plt.title(titles[t])
  145.     plt.xlabel('X', fontsize=16)
  146.     plt.ylabel('Y', fontsize=16)
  147. plt.tight_layout(1, rect=(0, 0, 1, 0.95))
  148. plt.suptitle(u'各种不同线性回归过拟合显示', fontsize=22)
  149. plt.show()


复制代码

1.png (95.04 KB, 下载次数: 241)

1.png

2.png (201.75 KB, 下载次数: 244)

2.png

解决OverFitting的方案1.png (78.43 KB, 下载次数: 248)

解决OverFitting的方案1.png

解决OverFitting的方案2.png (44.76 KB, 下载次数: 238)

解决OverFitting的方案2.png

解决OverFitting的方案3.png (35.52 KB, 下载次数: 244)

解决OverFitting的方案3.png

解决OverFitting的方案4.png (46.97 KB, 下载次数: 239)

解决OverFitting的方案4.png

解决OverFitting的方案5.png (121.97 KB, 下载次数: 240)

解决OverFitting的方案5.png

解决OverFitting的方案6.png (107.12 KB, 下载次数: 239)

解决OverFitting的方案6.png

解决OverFitting的方案7.png (55.45 KB, 下载次数: 240)

解决OverFitting的方案7.png

解决OverFitting的方案8.png (355.1 KB, 下载次数: 240)

解决OverFitting的方案8.png

画板 1.png (980.22 KB, 下载次数: 238)

画板 1.png

画板 2.png (419.69 KB, 下载次数: 242)

画板 2.png

画板 3.png (1.12 MB, 下载次数: 237)

画板 3.png
让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
19#
发表于 2021-8-12 12:50:50 | 只看该作者
7777777777777777777777777
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
18#
发表于 2021-8-12 12:50:04 | 只看该作者
99999999999999999999999
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
17#
发表于 2021-8-12 12:49:46 | 只看该作者
8888888888888888888
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
16#
发表于 2021-8-12 12:49:29 | 只看该作者
7777777777777777777
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
15#
发表于 2021-8-12 12:49:11 | 只看该作者
6666666666666666666
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
14#
发表于 2021-8-12 12:48:54 | 只看该作者
555555555555555555555555
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
13#
发表于 2021-8-12 12:48:37 | 只看该作者
4444444444444444444444444444444
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
12#
发表于 2021-8-12 12:47:24 | 只看该作者
33333333333333333333333333
回复

使用道具 举报

0

主题

98

帖子

200

积分

中级会员

Rank: 3Rank: 3

积分
200
11#
发表于 2021-8-12 12:47:04 | 只看该作者
2222222222222222222222
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-5-19 04:38 , Processed in 0.207613 second(s), 22 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表