|
04、模型构建之Discriminator判别器的封装_笔记
def _leaky_relu(x, leak=0.2):
    """Return the element-wise leaky ReLU ``max(x, leak * x)``.

    This equals ``x if x > 0 else leak * x`` only for ``0 <= leak <= 1``.
    """
    return tf.maximum(x, x * leak, name='leaky_relu')


def conv2d(inputs, out_channel, name, training, *,
           kernel_size=(5, 5), strides=(2, 2), leak=0.2):
    """Build one conv -> batch-norm -> leaky-ReLU stage under variable scope ``name``.

    Args:
        inputs: 4-D input tensor (tf.layers.conv2d default is channels_last / NHWC).
        out_channel: number of convolution filters (output channels).
        name: variable-scope name for this stage's variables.
        training: passed to tf.layers.batch_normalization; selects batch
            statistics (True) versus moving statistics (False).
        kernel_size: convolution kernel size; default matches the original 5x5.
        strides: convolution strides; default matches the original 2x2 downsampling.
        leak: negative-slope coefficient of the leaky ReLU.

    Returns:
        The activated output tensor of this stage.

    NOTE: according to the tf.layers.batch_normalization docs, the ops that
    update the moving mean/variance are added to tf.GraphKeys.UPDATE_OPS.
    The training op must depend on them, or the inference-time statistics
    are never updated.
    """
    with tf.variable_scope(name_or_scope=name):
        # 'SAME' padding with stride s produces ceil(H / s) x ceil(W / s) outputs.
        conv2d_output = tf.layers.conv2d(inputs, out_channel, kernel_size, strides, padding='SAME')
        bn = tf.layers.batch_normalization(conv2d_output, training=training)
        return _leaky_relu(bn, leak)
class Discriminator:
    """Convolutional discriminator that maps an image batch to 2-unit logits.

    The network is one ``conv2d`` stage per entry in ``channels``, then a
    flatten and a dense layer with 2 output units. All variables live under
    the variable scope ``'discriminator'``. The first call creates them and
    every later call reuses them.
    """

    def __init__(self, channels):
        """Store the per-stage output channel counts.

        Args:
            channels: sequence of filter counts, one conv stage per element.
        """
        self._channels = channels
        self._reuse = False
        # Trainable variables of this network. This is an empty list until
        # the first __call__ builds the graph.
        self.variables = []

    def __call__(self, inputs, training):
        """Build (or reuse) the discriminator graph on ``inputs``.

        Args:
            inputs: image batch; the original note says [N, 32, 32, 1]
                (TODO confirm for other datasets).
            training: forwarded to each stage's batch normalization.

        Returns:
            Logits tensor produced by a dense layer with 2 units.
        """
        inputs = tf.convert_to_tensor(inputs)
        with tf.variable_scope(name_or_scope='discriminator', reuse=self._reuse):
            conv2d_inputs = inputs
            for i, out_channel in enumerate(self._channels):
                conv2d_inputs = conv2d(conv2d_inputs, out_channel, 'conv2d-%d' % i, training)
            with tf.variable_scope(name_or_scope='fc'):
                flatten = tf.layers.flatten(conv2d_inputs, name='flatten')
                logits = tf.layers.dense(flatten, units=2, name='logits')
        # Later calls share the variables created by this first call.
        self._reuse = True
        # get_collection filters names with re.match. The trailing '/' stops
        # the filter from also matching sibling scopes like 'discriminator_aux/'.
        self.variables = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator/')
        return logits
|
|