东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 2838|回复: 1
打印 上一主题 下一主题

[课堂笔记] 58、TensorFlow中RNN(LSTM、GRU) API的使用_整体把控_笔记

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14439
QQ
跳转到指定楼层
楼主
发表于 2019-4-10 11:30:17 | 只看该作者 |只看大图 回帖奖励 |倒序浏览 |阅读模式


58、TensorFlow中RNN(LSTM、GRU) API的使用_整体把控_笔记


  1. # -*- coding: utf-8 -*-
  2. __author__ = u'东方耀 微信:dfy_88888'
  3. __date__ = '2019/4/10 11:40'
  4. __product__ = 'PyCharm'
  5. __filename__ = 'rnn_api'

  6. import tensorflow as tf

  7. """
  8. TensorFlow中和RNN相关的API主要位于两个package:
  9. tf.nn.rnn_cell(主要定义RNN的常见的几种细胞cell Dropout操作)、
  10. tf.nn(RNN相关的计算执行操作)
  11. """
  12. # rnn api
  13. # tf.nn.rnn_cell. 下面放了很多类型的cell
  14. # tf.nn.  辅助的rnn的计算工具

  15. # RNN的中的细胞Cell(BasicRNNCell RNNCell BasicLSTMCell LSTMCell GRUCell MultiRNNCell)
  16. # tf.nn.rnn_cell.BasicRNNCell()
  17. # tf.nn.rnn_cell.RNNCell()
  18. #
  19. # tf.nn.rnn_cell.BasicLSTMCell()
  20. # tf.nn.rnn_cell.LSTMCell()
  21. #
  22. # tf.nn.rnn_cell.GRUCell()
  23. #
  24. # tf.nn.rnn_cell.MultiRNNCell()
  25. #
  26. # tf.nn.rnn_cell.DropoutWrapper()
  27. #
  28. # tf.nn.dynamic_rnn()
  29. # tf.nn.bidirectional_dynamic_rnn()

  30. # 定义第一层的LSTM细胞
  31. cell_1 = tf.nn.rnn_cell.LSTMCell(num_units=128, use_peepholes=False, state_is_tuple=True, name='cell_1')
  32. # num_units:给定一个细胞中的各个神经层次中的神经元数目(状态维度和输出的数据维度和num_units一致)
  33. print(cell_1.output_size)
  34. print(cell_1.state_size)

  35. cell_1 = tf.nn.rnn_cell.DropoutWrapper(cell_1, 0.8)
  36. # 定义第二层的LSTM细胞 BasicLSTMCell与LSTMCell可以一起
  37. cell_2 = tf.nn.rnn_cell.LSTMCell(num_units=256, state_is_tuple=True, name='cell_2')
  38. cell_2 = tf.nn.rnn_cell.DropoutWrapper(cell_2, 0.7)

  39. cells = [cell_1, cell_2]
  40. # 将多层的LSTM合并封装成一个MultiRNNCell 作为单一的cell来处理
  41. cell = tf.nn.rnn_cell.MultiRNNCell(cells=cells, state_is_tuple=True)

  42. print(cell.state_size)
  43. print(cell.output_size)

  44. # 第一种情况:某个时刻(单个时刻)的LSTM计算
  45. # 100表示的是每个批次或每个时刻有100个样本输入 batch_size=100
  46. # 64表示每个样本具有64个特征属性 x1 x2 x3 .... x64
  47. inputs = tf.placeholder(shape=(100, 64), dtype=tf.float32)
  48. initial_state = cell.zero_state(batch_size=100, dtype=tf.float32)
  49. cur_inp, new_states = cell.call(inputs=inputs, state=initial_state)
  50. # cur_inp.shape [batch_size, 最后一层LSTM的细胞中神经元数量]
  51. print(cur_inp)
  52. print(new_states[0])
  53. print(new_states[1])

  54. # 第二种情况:多个时刻(序列式)的LSTM计算
  55. inputs = tf.placeholder(shape=(100, 10, 64), dtype=tf.float32)
  56. # inputs:一组序列(从t=0到t=T), 格式要求:[batch_size, time_steps, input_feature_num]
  57. # 100表示的是每个批次或每个时刻有100个样本输入 batch_size=100
  58. # 第二个维度10表示序列长度(时间长度) 一共有10个时刻  时间长度为10的序列 t1 t2 t3 .... t10
  59. # 64表示每个样本具有64个特征属性 x1 x2 x3 .... x64 每个样本的维度数量
  60. initial_state = cell.zero_state(batch_size=100, dtype=tf.float32)
  61. rnn_outputs, middle_hidden_state = tf.nn.dynamic_rnn(cell=cell, inputs=inputs, initial_state=initial_state)
  62. # rnn_outputs.shape: [batch_size, 时序数, 最后一层LSTM的细胞中神经元数量]
  63. print('rnn_outputs shape:', rnn_outputs.shape)
  64. print('middle_hidden_state[0]:', middle_hidden_state[0])
  65. print('middle_hidden_state[1]:', middle_hidden_state[1])
  66. last = rnn_outputs[:, -1, :]
  67. print('last shape: ', last.shape)


  68. """
  69. tf.nn.dynamic_rnn和tf.nn.static_rnn
  70. 一般用tf.nn.dynamic_rnn
  71. tf.nn.dynamic_rnn:表示在每个批次中动态的构建rnn执行结构,
  72. 可以允许在不同时刻传入的数据的特征维度不同,
  73. eg: 第一时刻传入的数据格式为:[batch_size, 10],
  74. 第二时刻传入的数据格式为:[batch_size, 12],
  75. 第三个时刻传入的数据格式为: [batch_size, 8]......;默认就是填0.
  76. tf.nn.static_rnn: 在网络执行前,就构建好rnn的执行结构,
  77. 要求传入的数据长度必须一致,而且传入的数据必须是tensor的list集合;
  78. 构建的时候比较慢,但是执行相对比较快。

  79. tf.nn.bidirectional_dynamic_rnn 双向动态rnn
  80. tf.nn.static_bidirectional_rnn 双向静态rnn
  81. """
  82. # tf.nn.static_rnn
  83. #
  84. # tf.nn.bidirectional_dynamic_rnn
  85. # tf.nn.static_bidirectional_rnn






复制代码


666.png (276.28 KB, 下载次数: 149)

666.png
让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

96

帖子

198

积分

注册会员

Rank: 2

积分
198
沙发
发表于 2019-7-2 14:06:55 | 只看该作者
多谢分享
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-5-8 01:09 , Processed in 0.187840 second(s), 21 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表