东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 2852|回复: 0
打印 上一主题 下一主题

[课堂笔记] 03、模型构建之Generator生成器的封装_笔记

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14429
QQ
跳转到指定楼层
楼主
发表于 2019-4-25 13:59:35 | 只看该作者 |只看大图 回帖奖励 |倒序浏览 |阅读模式


03、模型构建之Generator生成器的封装_笔记



  1. def conv2d_transpose(inputs, out_channel, name, training, with_bn_relu=True):
  2.     with tf.variable_scope(name_or_scope=name):
  3.         conv2d_trans = tf.layers.conv2d_transpose(inputs=inputs, filters=out_channel,
  4.                                                   kernel_size=[5, 5],
  5.                                                   strides=[2, 2], padding='SAME')
  6.         if with_bn_relu:
  7.             bn = tf.layers.batch_normalization(conv2d_trans, training=training)
  8.             relu = tf.nn.relu(bn)
  9.             return relu
  10.         else:
  11.             return conv2d_trans




  12. class Generator:
  13.     def __init__(self, channels, init_conv_size):
  14.         self._channels = channels
  15.         self._init_conv_size = init_conv_size
  16.         self._reuse = False

  17.     def __call__(self, inputs, training):
  18.         # 让类的实例化对象可以 函数一样调用
  19.         # eg: g = Generator(XXX)   g(inputs, training)

  20.         inputs = tf.convert_to_tensor(inputs)
  21.         with tf.variable_scope(name_or_scope='generator', reuse=self._reuse):
  22.             with tf.variable_scope(name_or_scope='inputs_fc'):
  23.                 # inputs shape: [N, 100]
  24.                 fc = tf.layers.dense(inputs=inputs, units=self._init_conv_size * self._init_conv_size * self._channels[0])
  25.                 conv0 = tf.reshape(fc, shape=[-1, self._init_conv_size, self._init_conv_size, self._channels[0]])
  26.                 bn0 = tf.layers.batch_normalization(conv0, training=training)
  27.                 relu0 = tf.nn.relu(bn0)
  28.             # shape: [N 4 4 128]
  29.             conv2d_trans_inputs = relu0
  30.             # g_channels=[128, 64, 32, 1],
  31.             # range(1, 4) ---> 1 2 3
  32.             for i in range(1, len(self._channels)):
  33.                 if i == len(self._channels) - 1:
  34.                     with_bn_relu = False
  35.                 else:
  36.                     with_bn_relu = True
  37.                 conv2d_trans_inputs = conv2d_transpose(conv2d_trans_inputs,
  38.                                                        self._channels[i],
  39.                                                        'conv2d-trans-%d' % i,
  40.                                                        training,
  41.                                                        with_bn_relu)
  42.             # shape: [N 32 32 1]
  43.             image_inputs = conv2d_trans_inputs
  44.             with tf.variable_scope(name_or_scope='generator_image'):
  45.                 # [-1 1]
  46.                 imgs_outputs = tf.nn.tanh(image_inputs, name='imgs_outputs')

  47.         self._reuse = True

  48.         self.variables = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')

  49.         # tf.trainable_variables

  50.         return imgs_outputs

复制代码


4.png (335.88 KB, 下载次数: 307)

4.png

44.png (329.52 KB, 下载次数: 314)

44.png
让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-20 04:55 , Processed in 0.187115 second(s), 21 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表