东方耀 AI Technology Sharing

Title: 04. Model Construction: Wrapping the Discriminator (Notes)

Author: 东方耀    Time: 2019-4-25 14:44




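```python
import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.variable_scope)


def conv2d(inputs, out_channel, name, training):
    """5x5 stride-2 convolution -> batch normalization -> leaky ReLU."""
    def leaky_relu(x, leak=0.2):
        # max(x, leak * x) equals leaky ReLU when 0 < leak < 1
        return tf.maximum(x, x * leak, name='leaky_relu')

    with tf.variable_scope(name_or_scope=name):
        # stride 2 with padding='SAME' halves the spatial size (rounding up)
        conv2d_output = tf.layers.conv2d(inputs, out_channel, [5, 5], (2, 2), padding='SAME')
        # training=True normalizes with batch statistics; the moving-average
        # updates go into tf.GraphKeys.UPDATE_OPS and must run with the train op
        bn = tf.layers.batch_normalization(conv2d_output, training=training)
        return leaky_relu(bn)


class Discriminator:
    def __init__(self, channels):
        # channels: list of output channel counts, one per conv2d block
        self._channels = channels
        # variables are created on the first call and reused afterwards
        self._reuse = False

    def __call__(self, inputs, training):
        # inputs shape: [N, 32, 32, 1]
        inputs = tf.convert_to_tensor(inputs)
        with tf.variable_scope(name_or_scope='discriminator', reuse=self._reuse):
            conv2d_inputs = inputs
            # stack one conv2d block per entry in self._channels
            for i in range(len(self._channels)):
                conv2d_inputs = conv2d(conv2d_inputs, self._channels[i], 'conv2d-%d' % i, training)
            fc_inputs = conv2d_inputs
            with tf.variable_scope(name_or_scope='fc'):
                flatten = tf.layers.flatten(fc_inputs, name='flatten')
                # two-class logits (real / fake); no softmax is applied here
                logits = tf.layers.dense(flatten, units=2, name='logits')
        # later calls (e.g. on generated images) share the same weights
        self._reuse = True
        # the discriminator's trainable variables, e.g. for an optimizer's var_list
        self.variables = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        return logits
```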

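As a quick sanity check on what the `flatten` and `logits` layers receive, here is a minimal pure-Python sketch (not part of the original notes) that traces the tensor shape through the stride-2, `padding='SAME'` conv2d blocks. It relies on TensorFlow's documented SAME rule, out = ceil(in / stride), which does not depend on the kernel size. It also mirrors the `leaky_relu` helper on plain numbers. The names `same_conv_output_size` and `trace_discriminator_shapes` and the channel list `[32, 64, 128, 256]` are hypothetical, chosen only for illustration.

```python
import math


def same_conv_output_size(size, stride=2):
    # TensorFlow SAME padding: output = ceil(input / stride), kernel size irrelevant
    return math.ceil(size / stride)


def trace_discriminator_shapes(height, width, in_channels, channels):
    """Return the (H, W, C) shape after each conv2d block and the flattened size.

    `channels` plays the role of the list passed to Discriminator(channels);
    this helper is hypothetical and only does the shape arithmetic.
    """
    h, w, c = height, width, in_channels
    shapes = []
    for out_c in channels:
        h, w, c = same_conv_output_size(h), same_conv_output_size(w), out_c
        shapes.append((h, w, c))
    # tf.layers.flatten turns [N, H, W, C] into [N, H * W * C]
    return shapes, h * w * c


def leaky_relu(x, leak=0.2):
    # Scalar version of the helper: max(x, leak * x) returns x for x >= 0
    # and leak * x for x < 0, provided 0 <= leak <= 1
    return max(x, x * leak)


if __name__ == '__main__':
    # hypothetical channel list for a [N, 32, 32, 1] input
    shapes, flat = trace_discriminator_shapes(32, 32, 1, [32, 64, 128, 256])
    print(shapes)
    print(flat)
```

With four blocks on a 32 x 32 input, the spatial size goes 32 → 16 → 8 → 4 → 2, so the dense `logits` layer sees 2 × 2 × 256 = 1024 features per image under these assumed channel counts.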



Author: 豆豆888    Time: 2019-4-29 15:12
Impressive, bro!




Welcome to 东方耀 AI Technology Sharing (http://www.ai111.vip/) Powered by Discuz! X3.4