东方耀AI技术分享 (Dongfangyao AI Tech Sharing)
Title:
04. Model Construction: Encapsulating the Discriminator (Notes)
Author:
东方耀
Time:
2019-4-25 14:44
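```python
import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.variable_scope)


def conv2d(inputs, out_channel, name, training):
    """One discriminator block: 5x5 conv (stride 2, 'SAME') -> batch norm -> leaky ReLU."""
    def leaky_relu(x, leak=0.2):
        # max(x, leak * x) equals leaky ReLU whenever leak <= 1
        return tf.maximum(x, x * leak, name='leaky_relu')

    with tf.variable_scope(name_or_scope=name):
        conv2d_output = tf.layers.conv2d(inputs, out_channel, [5, 5], (2, 2), padding='SAME')
        # when training, the ops in tf.GraphKeys.UPDATE_OPS must also be run
        # so that the batch-norm moving mean/variance get updated
        bn = tf.layers.batch_normalization(conv2d_output, training=training)
        return leaky_relu(bn)


class Discriminator:
    def __init__(self, channels):
        # channels: number of output channels of each conv block
        self._channels = channels
        self._reuse = False  # first call creates the variables, later calls reuse them

    def __call__(self, inputs, training):
        # inputs shape: [N, 32, 32, 1]
        inputs = tf.convert_to_tensor(inputs)
        with tf.variable_scope(name_or_scope='discriminator', reuse=self._reuse):
            conv2d_inputs = inputs
            for i in range(len(self._channels)):
                # each block halves height and width (stride 2, 'SAME' padding)
                conv2d_inputs = conv2d(conv2d_inputs, self._channels[i], 'conv2d-%d' % i, training)
            fc_inputs = conv2d_inputs
            with tf.variable_scope(name_or_scope='fc'):
                flatten = tf.layers.flatten(fc_inputs, name='flatten')
                # two logits: one score per class (real vs. fake)
                logits = tf.layers.dense(flatten, units=2, name='logits')
        self._reuse = True
        # trainable variables under 'discriminator', so the optimizer can update only this network
        self.variables = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        return logits
```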
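To see what each block does to the tensor shape, here is a small pure-Python trace of the discriminator's layer shapes and trainable parameter count. It is only a sketch, not part of the original notes: the helper names `same_conv_size` and `trace_discriminator` are hypothetical, the channel list [32, 64, 128] is an assumed example (the post does not fix one), and the parameter count assumes the tf.layers defaults used above (conv and dense layers with a bias, batch norm with trainable gamma and beta).

```python
import math


def same_conv_size(size, stride=2):
    """Output length of a 'SAME'-padded conv along one axis: ceil(size / stride).

    With 'SAME' padding the kernel size (5x5 here) does not affect the output size.
    """
    return math.ceil(size / stride)


def trace_discriminator(height, width, in_channels, channels, kernel=5, num_classes=2):
    """Return per-layer shapes (batch dim omitted), flattened size and trainable params.

    Hypothetical helper mirroring Discriminator.__call__: one conv2d block per entry
    in `channels` (kernel x kernel conv with bias, stride 2, 'SAME', then batch norm
    with trainable gamma/beta), then flatten and a dense layer with `num_classes` units.
    """
    shapes = [(height, width, in_channels)]
    params = 0
    h, w, c_in = height, width, in_channels
    for c_out in channels:
        h, w = same_conv_size(h), same_conv_size(w)
        params += kernel * kernel * c_in * c_out + c_out  # conv kernel + bias
        params += 2 * c_out  # batch-norm gamma and beta (moving stats are not trainable)
        shapes.append((h, w, c_out))
        c_in = c_out
    flat = h * w * c_in
    params += flat * num_classes + num_classes  # dense kernel + bias
    return shapes, flat, params


# Assumed example: 32x32x1 input (as in the code comment) and three conv blocks
shapes, flat, params = trace_discriminator(32, 32, 1, [32, 64, 128])
for s in shapes:
    print(s)
print('flattened features:', flat, 'trainable params:', params)
```

With a 32x32x1 input, three blocks bring the feature map down to 4x4x128, so the dense layer sees 2048 features.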
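The `leaky_relu` helper relies on the identity max(x, leak * x) = leaky ReLU, which holds exactly when leak <= 1 (for x >= 0 we need x >= leak * x, and for x < 0 we need leak * x >= x). The scalar check below uses Python's built-in `max` as a stand-in for the element-wise `tf.maximum` and compares it with the piecewise definition; both function names are hypothetical and the check is only an illustration.

```python
def leaky_relu(x, leak=0.2):
    """Leaky ReLU written the same way as in the post: max(x, leak * x)."""
    return max(x, x * leak)


def leaky_relu_piecewise(x, leak=0.2):
    """Textbook definition: x if x >= 0 else leak * x."""
    return x if x >= 0 else leak * x


# With the default leak = 0.2 (<= 1) both forms agree
for v in [-3.0, -0.5, 0.0, 0.5, 3.0]:
    assert leaky_relu(v) == leaky_relu_piecewise(v)

# With leak > 1 the max() trick breaks: for positive x it returns leak * x instead of x
print(leaky_relu(2.0, leak=1.5), leaky_relu_piecewise(2.0, leak=1.5))  # prints 3.0 2.0
```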
Author:
豆豆888
Time:
2019-4-29 15:12
Impressive, bro!
Welcome to 东方耀AI技术分享 (http://www.ai111.vip/)