东方耀AI技术分享

标题: 03、模型构建之Generator生成器的封装_笔记 [打印本页]

作者: 东方耀    时间: 2019-4-25 13:59
标题: 03、模型构建之Generator生成器的封装_笔记


03、模型构建之Generator生成器的封装_笔记



  1. def conv2d_transpose(inputs, out_channel, name, training, with_bn_relu=True):
  2.     with tf.variable_scope(name_or_scope=name):
  3.         conv2d_trans = tf.layers.conv2d_transpose(inputs=inputs, filters=out_channel,
  4.                                                   kernel_size=[5, 5],
  5.                                                   strides=[2, 2], padding='SAME')
  6.         if with_bn_relu:
  7.             bn = tf.layers.batch_normalization(conv2d_trans, training=training)
  8.             relu = tf.nn.relu(bn)
  9.             return relu
  10.         else:
  11.             return conv2d_trans




  12. class Generator:
  13.     def __init__(self, channels, init_conv_size):
  14.         self._channels = channels
  15.         self._init_conv_size = init_conv_size
  16.         self._reuse = False

  17.     def __call__(self, inputs, training):
  18.         # 让类的实例化对象可以 函数一样调用
  19.         # eg: g = Generator(XXX)   g(inputs, training)

  20.         inputs = tf.convert_to_tensor(inputs)
  21.         with tf.variable_scope(name_or_scope='generator', reuse=self._reuse):
  22.             with tf.variable_scope(name_or_scope='inputs_fc'):
  23.                 # inputs shape: [N, 100]
  24.                 fc = tf.layers.dense(inputs=inputs, units=self._init_conv_size * self._init_conv_size * self._channels[0])
  25.                 conv0 = tf.reshape(fc, shape=[-1, self._init_conv_size, self._init_conv_size, self._channels[0]])
  26.                 bn0 = tf.layers.batch_normalization(conv0, training=training)
  27.                 relu0 = tf.nn.relu(bn0)
  28.             # shape: [N 4 4 128]
  29.             conv2d_trans_inputs = relu0
  30.             # g_channels=[128, 64, 32, 1],
  31.             # range(1, 4) ---> 1 2 3
  32.             for i in range(1, len(self._channels)):
  33.                 if i == len(self._channels) - 1:
  34.                     with_bn_relu = False
  35.                 else:
  36.                     with_bn_relu = True
  37.                 conv2d_trans_inputs = conv2d_transpose(conv2d_trans_inputs,
  38.                                                        self._channels[i],
  39.                                                        'conv2d-trans-%d' % i,
  40.                                                        training,
  41.                                                        with_bn_relu)
  42.             # shape: [N 32 32 1]
  43.             image_inputs = conv2d_trans_inputs
  44.             with tf.variable_scope(name_or_scope='generator_image'):
  45.                 # [-1 1]
  46.                 imgs_outputs = tf.nn.tanh(image_inputs, name='imgs_outputs')

  47.         self._reuse = True

  48.         self.variables = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')

  49.         # tf.trainable_variables

  50.         return imgs_outputs

复制代码







欢迎光临 东方耀AI技术分享 (http://www.ai111.vip/) Powered by Discuz! X3.4