东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 2847|回复: 0

[课堂笔记] 04、定义图像生成问题的内容损失与风格损失_笔记

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14418
QQ
发表于 2019-3-13 16:40:28 | 显示全部楼层 |阅读模式


04、定义图像生成问题的内容损失与风格损失_笔记

  1. # -*- coding: utf-8 -*-
  2. __author__ = 'dongfangyao'
  3. __date__ = '2019/3/13 上午10:33'
  4. __product__ = 'PyCharm'
  5. __filename__ = '2_image_style_conver'

  6. import tensorflow as tf
  7. from tensorflow import logging
  8. import os
  9. from tensorflow import gfile
  10. from PIL import Image
  11. import time
  12. import numpy as np


  13. logging.set_verbosity(logging.INFO)

  14. # logging.info('dfy_88888')
  15. # vgg net 中写死的 归一化的数据预处理
  16. VGG_MEAN = [103.939, 116.779, 123.68]


  17. class VGGNet:
  18.     """
  19.     构建VGG16的网络结构 并从预训练好的模型提取参数 加载
  20.     """
  21.     def __init__(self, data_dict):
  22.         self.data_dict = data_dict

  23.     def get_conv_kernel(self, name):
  24.         # 卷积核的参数:w 0  b 1
  25.         return tf.constant(self.data_dict[name][0], name='conv')

  26.     def get_fc_weight(self, name):
  27.         return tf.constant(self.data_dict[name][0], name='fc')

  28.     def get_bias(self, name):
  29.         return tf.constant(self.data_dict[name][1], name='bias')

  30.     def conv_layer(self, inputs, name):
  31.         """
  32.         构建一个卷积计算层
  33.         :param inputs: 输入的feature_map
  34.         :param name: 卷积层的名字 也是获得参数的key 不能出错
  35.         :return:
  36.         """
  37.         with tf.name_scope(name):
  38.             """
  39.             多使用name_scope的好处:1、防止参数命名冲突 2、tensorboard可视化时很规整
  40.             如果scope里面有变量需要训练时则用tf.variable_scope
  41.             """
  42.             conv_w = self.get_conv_kernel(name)
  43.             conv_b = self.get_bias(name)
  44.             # tf.layers.conv2d() 这是一个封装更高级的api
  45.             # 里面并没有提供接口来输入卷积核参数 这里不能用 平时训练cnn网络时非常好用
  46.             result = tf.nn.conv2d(input=inputs, filter=conv_w, strides=[1, 1, 1, 1], padding='SAME', name=name)
  47.             result = tf.nn.bias_add(result, conv_b)
  48.             result = tf.nn.relu(result)
  49.             return result

  50.     def pooling_layer(self, inputs, name):
  51.         # tf.layers.max_pooling2d()
  52.         # tf.nn.max_pool 这里的池化层没有参数 两套api都可以用
  53.         return tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name=name)

  54.     def fc_layer(self, inputs, name, activation=tf.nn.relu):
  55.         """
  56.         构建全连接层
  57.         :param inputs: 输入
  58.         :param name:
  59.         :param activation: 是否有激活函数的封装
  60.         :return:
  61.         """
  62.         with tf.name_scope(name):
  63.             fc_w = self.get_fc_weight(name)
  64.             fc_b = self.get_bias(name)
  65.             # fc: wx+b 线性变换
  66.             result = tf.nn.bias_add(tf.matmul(inputs, fc_w), fc_b)
  67.             if activation is None:
  68.                 # vgg16的最后是不需relu激活的
  69.                 return result
  70.             else:
  71.                 return activation(result)

  72.     def flatten_op(self, inputs, name):
  73.         # 展平操作 为了后续的fc层必须将维度展平
  74.         with tf.name_scope(name):
  75.             # [NHWC]---> [N, H*W*C]

  76.             x_shape = inputs.get_shape().as_list()
  77.             dim = 1
  78.             for d in x_shape[1:]:
  79.                 dim *= d
  80.             inputs = tf.reshape(inputs, shape=[-1, dim])
  81.             # 直接用现成api也是可以的
  82.             # return tf.layers.flatten(inputs)
  83.             return inputs

  84.     def build(self, input_rgb):
  85.         """
  86.         构建vgg16网络结构 抽取特征 FP过程
  87.         :param input_rgb: [1, 224, 224, 3] [NHWC]
  88.         :return:
  89.         """
  90.         start_time = time.time()
  91.         logging.info('building start...')

  92.         # 在通道维度上分离 深度可分离卷积中也需要用到这个api
  93.         r, g, b = tf.split(input_rgb, num_or_size_splits=3, axis=3)
  94.         # 在通道维度上拼接
  95.         # 输入vgg网络的图像是bgr的(与OpenCV一样 倒序的)而不是rgb
  96.         x_bgr = tf.concat(values=[
  97.             b - VGG_MEAN[0],
  98.             g - VGG_MEAN[1],
  99.             r - VGG_MEAN[2],
  100.         ], axis=3)

  101.         assert x_bgr.get_shape().as_list()[1:] == [224, 224, 3]

  102.         # 构建网络
  103.         # stage 1
  104.         self.conv1_1 = self.conv_layer(x_bgr, 'conv1_1')
  105.         self.conv1_2 = self.conv_layer(self.conv1_1, 'conv1_2')
  106.         self.pool1 = self.pooling_layer(self.conv1_2, 'pool1')

  107.         # stage 2
  108.         self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
  109.         self.conv2_2 = self.conv_layer(self.conv2_1, 'conv2_2')
  110.         self.pool2 = self.pooling_layer(self.conv2_2, 'pool2')

  111.         # stage 3
  112.         self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
  113.         self.conv3_2 = self.conv_layer(self.conv3_1, 'conv3_2')
  114.         self.conv3_3 = self.conv_layer(self.conv3_2, 'conv3_3')
  115.         self.pool3 = self.pooling_layer(self.conv3_3, 'pool3')

  116.         # stage 4
  117.         self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
  118.         self.conv4_2 = self.conv_layer(self.conv4_1, 'conv4_2')
  119.         self.conv4_3 = self.conv_layer(self.conv4_2, 'conv4_3')
  120.         self.pool4 = self.pooling_layer(self.conv4_3, 'pool4')

  121.         # stage 5
  122.         self.conv5_1 = self.conv_layer(self.pool4, 'conv5_1')
  123.         self.conv5_2 = self.conv_layer(self.conv5_1, 'conv5_2')
  124.         self.conv5_3 = self.conv_layer(self.conv5_2, 'conv5_3')
  125.         self.pool5 = self.pooling_layer(self.conv5_3, 'pool5')

  126.         # flatten_op
  127.         # self.flatten = self.flatten_op(self.pool5, 'flatten_op')
  128.         #
  129.         # # fc
  130.         # self.fc6 = self.fc_layer(self.flatten, 'fc6')
  131.         # self.fc7 = self.fc_layer(self.fc6, 'fc7')
  132.         # self.fc8 = self.fc_layer(self.fc7, 'fc8', activation=None)
  133.         # self.logits = tf.nn.softmax(self.fc8, name='logits')
  134.         logging.info('building end... 耗时%3d秒' % (time.time() - start_time))




  135. #
  136. # vgg16_for_result = VGGNet(data_dict)
  137. # image_rgb = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='image_rgb')
  138. # vgg16_for_result.build(image_rgb)
  139. # print(vgg16_for_result.conv1_1)
  140. # print(vgg16_for_result.flatten)
  141. # print(vgg16_for_result.fc6)

  142. """
  143. 模块(一个一个实现来写代码):
  144.    1、定义输入文件与输出目录
  145.    2、管理模型的超参
  146.    3、数据的提供(内容图像 风格图像 随机初始化的图像)
  147.    4、构建计算图(数据流图、定义loss、train_op)
  148.    5、训练执行过程(会话中执行 设备:cpu或gpu或tpu)
  149. """
  150. vgg16_npy_path = './vgg16.npy'
  151. content_img_path = './dfy_88888.png'
  152. style_img_path = './style.png'
  153. output_dir = './output_imgs'


  154. if not gfile.Exists(output_dir):
  155.     gfile.MakeDirs(output_dir)


  156. def get_default_params():
  157.     return tf.contrib.training.HParams(
  158.         learning_rate=10,
  159.         lambda_content_loss=0.05,
  160.         lambda_style_loss=2000,

  161.     )

  162. hps = get_default_params()
  163. print(hps.learning_rate)
  164. print(hps.lambda_content_loss)


  165. def read_img(image_name):
  166.     img = Image.open(image_name)
  167.     np_img = np.array(img)
  168.     # [224, 224, 4]
  169.     print('np_img shape:', np_img.shape, image_name)
  170.     # RGBA--->RGB [224, 224, 3]
  171.     np_img = np_img[:, :, 0:3]
  172.     np_img = np.asarray([np_img], dtype=np.float32)
  173.     print('np_img shape: ', np_img.shape)
  174.     # (1, 224, 224, 3)
  175.     return np_img


  176. # read_img(content_img_path)
  177. content_img_arr_val = read_img(content_img_path)
  178. style_img_arr_val = read_img(style_img_path)


  179. def initial_image(shape, mean, stddev):
  180.     # 截断的随机的正态分布 数据产生
  181.     initial_img = tf.truncated_normal(shape=shape, mean=mean, stddev=stddev, dtype=tf.float32)
  182.     return tf.Variable(initial_value=initial_img, trainable=True)

  183. result_img_val = initial_image([1, 224, 224, 3], mean=255//2, stddev=20)
  184. # 用占位符 具体的值在sess中通过feed_dict喂养 后面执行阶段会有
  185. content_img = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='content_img')
  186. style_img = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='style_img')


  187. # 提取图像的卷积层的特征

  188. vgg16_data = np.load(vgg16_npy_path, encoding='latin1')
  189. data_dict = vgg16_data.item()

  190. vgg16_for_result_img = VGGNet(data_dict)
  191. vgg16_for_content_img = VGGNet(data_dict)
  192. vgg16_for_style_img = VGGNet(data_dict)

  193. # 结果图像vgg16的构建
  194. vgg16_for_result_img.build(result_img_val)
  195. # 内容图像vgg16的构建
  196. vgg16_for_content_img.build(content_img)
  197. # 风格图像vgg16的构建
  198. vgg16_for_style_img.build(style_img)


  199. # 定义需要提取哪些层的特征了 cnn

  200. # 内容图像的内容特征抽取 越低层效果越好
  201. # shape: [NWHC]  卷积层经过激励之后的输出 feature_map
  202. content_features = [
  203.     vgg16_for_content_img.conv1_1,
  204.     vgg16_for_content_img.conv2_1,
  205.     # vgg16_for_content_img.conv3_1,
  206.     # vgg16_for_content_img.conv3_2,
  207.     # vgg16_for_content_img.conv5_1,
  208.     # vgg16_for_content_img.conv5_3,
  209. ]

  210. # 结果图像的内容特征抽取 必须一致
  211. result_content_features = [
  212.     vgg16_for_result_img.conv1_1,
  213.     vgg16_for_result_img.conv2_1,
  214.     # vgg16_for_result_img.conv3_1,
  215.     # vgg16_for_result_img.conv3_2,
  216.     # vgg16_for_result_img.conv5_1,
  217.     # vgg16_for_result_img.conv5_3,
  218. ]

  219. # 风格图像的风格特征抽取 越高层越好
  220. style_features = [
  221.     # vgg16_for_style_img.conv1_1,
  222.     # vgg16_for_style_img.conv2_1,
  223.     # vgg16_for_style_img.conv3_1,
  224.     # vgg16_for_style_img.conv4_2,
  225.     vgg16_for_style_img.conv4_3,
  226.     vgg16_for_style_img.conv5_3,
  227. ]

  228. # 结果图像的风格特征抽取 必须一致
  229. result_style_features = [
  230.     # vgg16_for_result_img.conv1_1,
  231.     # vgg16_for_result_img.conv2_1,
  232.     # vgg16_for_result_img.conv3_1,
  233.     # vgg16_for_result_img.conv4_2,
  234.     vgg16_for_result_img.conv4_3,
  235.     vgg16_for_result_img.conv5_3,
  236. ]



  237. # loss = loss_content + loss_style


  238. content_loss = tf.zeros(shape=1, dtype=tf.float32)
  239. # shape [2, 2] ---> [[0, 0],[0, 0]]
  240. # shape 2  ---> [0, 0]
  241. # shape 1 ----> [0 ]
  242. # shape 0 ----> 0  标量
  243. # mse 平方差损失函数
  244. # zip([1, 2], [3, 4])  ---> [(1, 3), (2, 4)]
  245. for c, c_result in zip(content_features, result_content_features):
  246.     # c c_result [NHWC]
  247.     content_loss += tf.reduce_mean(tf.square(c - c_result), axis=[1, 2, 3])


  248. # Gram矩阵 得到关联性的度量
  249. def gram_matrix(x):
  250.     """
  251.     Gram矩阵的计算  k个feature_map 两两之间的关联性 相似度的计算 k*k的矩阵
  252.     k=channels
  253.     :param x: shape [NHWC] [1, height, width, channels]
  254.     :return:
  255.     """
  256.     batch_size, h, w, c = x.get_shape().as_list()
  257.     # features shape [b, h*w, c]
  258.     features = tf.reshape(x, shape=[batch_size, h*w, c])
  259.     # [c, c] = [c, h*w] 矩阵乘法 [h*w, c]
  260.     # features[0] shape : [h*w, c]
  261.     gram = tf.matmul(tf.matrix_transpose(features[0]), features[0]) / tf.constant(h*w*c, tf.float32)
  262.     # gram shape : [c, c]  [k, k]
  263.     return gram

  264. # 列表生成式
  265. style_gram_matrix = [gram_matrix(feature) for feature in style_features]
  266. result_style_gram_matrix = [gram_matrix(feature) for feature in result_style_features]

  267. style_loss = tf.zeros(shape=1, dtype=tf.float32)

  268. for s, s_result in zip(style_gram_matrix, result_style_gram_matrix):
  269.     style_loss += tf.reduce_mean(tf.square(s - s_result), axis=[0, 1])

  270. # 最终的损失函数loss由两部分组成:内容损失与风格损失的加权和
  271. loss = hps.lambda_content_loss * content_loss + hps.lambda_style_loss * style_loss

  272. with tf.name_scope('train_op'):
  273.     train_op = tf.train.AdamOptimizer(hps.learning_rate).minimize(loss)



复制代码

让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-3-29 01:34 , Processed in 0.176149 second(s), 19 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表