东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 3625|回复: 1

[课堂笔记] TensorFlow实现图像风格转换 V1算法

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14429
QQ
发表于 2019-2-5 21:08:24 | 显示全部楼层 |阅读模式



TensorFlow实现图像风格转换 V1算法实现


  1. # -*- coding: utf-8 -*-
  2. __author__ = 'dongfangyao'
  3. __date__ = '2019/2/4 下午3:33'
  4. __product__ = 'PyCharm'
  5. __filename__ = '10_image_style_conversion'
  6. """
  7.    TensorFlow实现图像风格转换 V1算法
  8.    v1算法的缺点:每次需要随机初始化图像的变量 使用GD下降 运行多次 效率低
  9.    v1算法训练的是 图像本身
  10.    v2算法训练的是 网络(Image Transform Net)
  11.    v3算法:重新定义了风格损失的计算方法 放弃了Gram矩阵计算相似度 用了最match小块分割方法(patch为单位)
  12. """
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

  19. VGG_MEAN = [103.939, 116.779, 123.68]


  20. class VGGNet:
  21.     """
  22.     构建VGG16的网络结构 并从预处理模型中加载训练好的参数
  23.     """
  24.     def __init__(self, data_dict):
  25.         self.data_dict = data_dict

  26.     def get_conv_kernel(self, name):
  27.         # 卷积核的参数 w:0 bias:1
  28.         return tf.constant(self.data_dict[name][0], name='conv')

  29.     def get_fc_weight(self, name):
  30.         return tf.constant(self.data_dict[name][0], name='fc')

  31.     def get_bias(self, name):
  32.         return tf.constant(self.data_dict[name][1], name='bias')

  33.     def conv_layer(self, inputs, name):
  34.         """构建一个卷积计算层"""
  35.         # 多使用name_scope的好处:1、防止参数命名冲突 2、可视化的显示规整
  36.         with tf.name_scope(name):
  37.             conv_w = self.get_conv_kernel(name)
  38.             conv_b = self.get_bias(name)
  39.             # tf.layers.conv2d() 里面没有参数的 不能用了
  40.             result = tf.nn.conv2d(inputs, conv_w, [1, 1, 1, 1], padding='SAME')
  41.             result = tf.nn.bias_add(result, conv_b)
  42.             result = tf.nn.relu(result)
  43.             return result

  44.     def pooling_layer(self, inputs, name):
  45.         """构建一个池化层 tf.layers.max_pooling2d()"""
  46.         result = tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name=name)
  47.         return result

  48.     def fc_layer(self, inputs, name, activation=tf.nn.relu):
  49.         """构建全连接层的计算"""
  50.         with tf.name_scope(name):
  51.             fc_w = self.get_fc_weight(name)
  52.             fc_b = self.get_bias(name)
  53.             result = tf.nn.bias_add(tf.matmul(inputs, fc_w), fc_b)
  54.             if activation is None:
  55.                 return result
  56.             else:
  57.                 return activation(result)

  58.     def flatten_op(self, inputs, name):
  59.         """展平操作 tf.layers.flatten()"""
  60.         with tf.name_scope(name):
  61.             # [N H W C]---> [N H*W*C]
  62.             x_shape = inputs.get_shape().as_list()
  63.             dim = 1
  64.             for d in x_shape[1:]:
  65.                 dim *= d
  66.             inputs = tf.reshape(inputs, shape=[-1, dim])
  67.             return inputs

  68.     def build(self, input_rgb):
  69.         """构建vgg16网络 提取特征 FP过程
  70.            参数:
  71.            --input_rgb: [1, 224, 224, 3]
  72.         """
  73.         start_time = time.time()
  74.         print('building start...')

  75.         # r, g, b = tf.split(value=input_rgb, num_or_size_splits=[1, 1, 1], axis=3)
  76.         r, g, b = tf.split(value=input_rgb, num_or_size_splits=3, axis=3)
  77.         # 输入vgg的图像是bgr的顺序(跟opencv一样 倒序的)而不是rgb
  78.         x_bgr = tf.concat(values=[
  79.             b - VGG_MEAN[0],
  80.             g - VGG_MEAN[1],
  81.             r - VGG_MEAN[2]
  82.         ], axis=3)

  83.         assert x_bgr.get_shape().as_list()[1:] == [224, 224, 3]
  84.         # stage 1
  85.         self.conv1_1 = self.conv_layer(x_bgr, 'conv1_1')
  86.         self.conv1_2 = self.conv_layer(self.conv1_1, 'conv1_2')
  87.         self.pool1 = self.pooling_layer(self.conv1_2, 'pool1')
  88.         # stage 2
  89.         self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
  90.         self.conv2_2 = self.conv_layer(self.conv2_1, 'conv2_2')
  91.         self.pool2 = self.pooling_layer(self.conv2_2, 'pool2')

  92.         # stage 3
  93.         self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
  94.         self.conv3_2 = self.conv_layer(self.conv3_1, 'conv3_2')
  95.         self.conv3_3 = self.conv_layer(self.conv3_2, 'conv3_3')
  96.         self.pool3 = self.pooling_layer(self.conv3_3, 'pool3')
  97.         # stage 4
  98.         self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
  99.         self.conv4_2 = self.conv_layer(self.conv4_1, 'conv4_2')
  100.         self.conv4_3 = self.conv_layer(self.conv4_2, 'conv4_3')
  101.         self.pool4 = self.pooling_layer(self.conv4_3, 'pool4')
  102.         # stage 5
  103.         self.conv5_1 = self.conv_layer(self.pool4, 'conv5_1')
  104.         self.conv5_2 = self.conv_layer(self.conv5_1, 'conv5_2')
  105.         self.conv5_3 = self.conv_layer(self.conv5_2, 'conv5_3')
  106.         self.pool5 = self.pooling_layer(self.conv5_3, 'pool5')
  107.         # stage 6
  108.         # self.flatten5 = self.flatten_op(self.pool5, 'flatten5')
  109.         # self.fc6 = self.fc_layer(self.flatten5, 'fc6')
  110.         # self.fc7 = self.fc_layer(self.fc6, 'fc7')
  111.         # self.fc8 = self.fc_layer(self.fc7, 'fc8', activation=None)
  112.         # self.prob = tf.nn.softmax(self.fc8, name='prob')

  113.         # print(self.prob.shape)

  114.         print('buliding finished 耗时:%4ds' % (time.time() - start_time))
  115.         pass

  116. vgg16_npy_path = './vgg16.npy'
  117. # vgg16_data = np.load(vgg16_npy_path, encoding='latin1')
  118. # data_dict = vgg16_data.item()
  119. #
  120. # print(data_dict.keys())
  121. #
  122. # vgg16_for_result = VGGNet(data_dict)
  123. #
  124. # image_rgb = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='image_rgb')
  125. #
  126. # vgg16_for_result.build(image_rgb)

  127. # 224 * 224
  128. content_img_path = './img/content.jpeg'
  129. style_img_path = './img/style.jpeg'

  130. num_steps = 100
  131. learning_rate = 10

  132. lambda_c = 0.1
  133. lambda_s = 500

  134. output_dir = './img/output_img'

  135. if not os.path.exists(output_dir):
  136.     os.mkdir(output_dir)


  137. def initial_image(shape, mean, stddev):
  138.     initial_image = tf.truncated_normal(shape=shape, mean=mean, stddev=stddev, dtype=tf.float32)
  139.     return tf.Variable(initial_image)


  140. initial_image_result = initial_image([1, 224, 224, 3], mean=127.5, stddev=20)


  141. def read_img(image_name):
  142.     img = Image.open(image_name)
  143.     # [224 224 3]
  144.     np_img = np.array(img)
  145.     # np_img = tf.reshape(np_img, shape=[1, 224, 224, 3])
  146.     # [1 224 224 3]
  147.     # np_img = np.reshape(np_img, newshape=[1, 224, 224, 3])
  148.     np_img = np.asarray([np_img], dtype=np.float32)
  149.     print(np_img.shape)
  150.     return np_img

  151. # plt.imshow(read_img(style_img_path)[0])
  152. # plt.show()


  153. content_img_arr_val = read_img(content_img_path)
  154. style_img_arr_val = read_img(style_img_path)

  155. content_img = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='content_img')
  156. style_img = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='style_img')

  157. # initial_image_result content_img style_img 三张图片进入vgg网络 提取特征
  158. vgg16_data = np.load(vgg16_npy_path, encoding='latin1')
  159. data_dict = vgg16_data.item()

  160. print(data_dict.keys())

  161. vgg16_for_initial_result = VGGNet(data_dict)
  162. vgg16_for_content_img = VGGNet(data_dict)
  163. vgg16_for_style_img = VGGNet(data_dict)

  164. vgg16_for_content_img.build(content_img)
  165. vgg16_for_style_img.build(style_img)
  166. vgg16_for_initial_result.build(initial_image_result)

  167. # 定义提取哪些层的特征 cnn的

  168. # 内容特征 越低层越精细
  169. content_features = [
  170.     vgg16_for_content_img.conv1_2,
  171.     vgg16_for_content_img.conv2_2,
  172.     # vgg16_for_content_img.conv3_3,
  173.     # vgg16_for_content_img.conv4_3,
  174.     # vgg16_for_content_img.conv5_3
  175. ]

  176. # 结果的内容特征必须与内容特征一致
  177. result_content_features = [
  178.     vgg16_for_initial_result.conv1_2,
  179.     vgg16_for_initial_result.conv2_2,
  180.     # vgg16_for_initial_result.conv3_3,
  181.     # vgg16_for_initial_result.conv4_3,
  182.     # vgg16_for_initial_result.conv5_3
  183. ]

  184. # 风格特征 越高层越抽象
  185. style_features = [
  186.     # vgg16_for_style_img.conv1_2,
  187.     # vgg16_for_style_img.conv2_2,
  188.     # vgg16_for_style_img.conv3_3,
  189.     vgg16_for_style_img.conv4_3,
  190.     vgg16_for_style_img.conv5_3
  191. ]

  192. # 结果的风格特征必须与风格特征一致
  193. result_style_features = [
  194.     # vgg16_for_initial_result.conv1_2,
  195.     # vgg16_for_initial_result.conv2_2,
  196.     # vgg16_for_initial_result.conv3_3,
  197.     vgg16_for_initial_result.conv4_3,
  198.     vgg16_for_initial_result.conv5_3
  199. ]

  200. # 开始计算损失
  201. content_loss = tf.zeros(shape=1, dtype=tf.float32)
  202. # zip([1, 2], [3, 4]) ---> [(1, 3), (2, 4)] 两个数组变成 一个数组
  203. # c与c_的shape(卷积激励之后):[1 height width channel] 在通道axis=[1, 2, 3]求平均
  204. # loss = mse 平方差损失函数
  205. for c, c_ in zip(content_features, result_content_features):
  206.     content_loss += tf.reduce_mean(tf.square(c - c_), axis=[1, 2, 3])
  207.     pass

  208. # 计算Gram矩阵
  209. def gram_matrix(x):
  210.     """Gram矩阵计算 k个feature_map 两两之间的关联性 相似度 k*k的矩阵
  211.     Args:
  212.     ---x feature_map from Conv层 shape:[1 height width channels]    """
  213.     b, h ,w, ch = x.get_shape().as_list()
  214.     features = tf.reshape(x, shape=[b, h*w, ch])
  215.     # [ch ch] = [ch h*w] 矩阵相乘 [h*w ch]
  216.     # 除以维度相乘是为了防止 值过大 除以统一的数
  217.     # features是三维的 导致gram矩阵也是三维的
  218.     gram = tf.matmul(features, features, adjoint_a=True) / tf.constant(h*w*ch, tf.float32)
  219.     # features是2维的 导致gram矩阵也是2维的
  220.     # gram = tf.matmul(tf.matrix_transpose(features[0]), features[0]) / tf.constant(h*w*ch, tf.float32)
  221.     return gram
  222.     pass


  223. style_gram_matrix = [gram_matrix(feature) for feature in style_features]
  224. result_style_gram_matrix = [gram_matrix(feature) for feature in result_style_features]

  225. style_loss = tf.zeros(shape=1, dtype=tf.float32)
  226. for s, s_ in zip(style_gram_matrix, result_style_gram_matrix):
  227.     # loss = mse 平方差损失函数
  228.     # axis=[0, 1, 2, 3] 这里会报错  gram矩阵输出是三维的
  229.     style_loss += tf.reduce_mean(tf.square(s - s_), axis=[1, 2])
  230.     # gram矩阵输出是二维的
  231.     # style_loss += tf.reduce_mean(tf.square(s - s_), axis=[0, 1])
  232.     pass

  233. loss = content_loss * lambda_c + style_loss * lambda_s

  234. with tf.name_scope('train'):
  235.     train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)


  236. init_op = tf.global_variables_initializer()

  237. with tf.Session() as sess:
  238.     sess.run(init_op)
  239.     for step in range(num_steps):
  240.         loss_value, content_loss_value, style_loss_value, _ = \
  241.         sess.run(fetches=[loss, content_loss, style_loss, train_op], feed_dict={
  242.             content_img: content_img_arr_val,
  243.             style_img: style_img_arr_val
  244.         })
  245.         print('step:%d loss_value:%8.4f content_loss_value:%8.4f style_loss_value:%8.4f'
  246.               % (step+1, loss_value[0], content_loss_value[0], style_loss_value[0]))
  247.         # result_image:shape [224, 224, 3]
  248.         result_image = initial_image_result.eval(sess)[0]
  249.         # np.clip值裁剪 小于0的变为0 大于255的变为255
  250.         result_image = np.clip(result_image, 0, 255)
  251.         result_image = np.asarray(result_image, dtype=np.uint8)
  252.         # np_img = np.asarray([np_img], dtype=np.float32)
  253.         img = Image.fromarray(result_image)
  254.         result_image_path = os.path.join(output_dir, 'result-%05d.jpg' % (step + 1))
  255.         img.save(result_image_path)

  256.     pass



复制代码


让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

117

帖子

258

积分

中级会员

Rank: 3Rank: 3

积分
258
QQ
发表于 2020-2-3 15:45:34 | 显示全部楼层
谢谢老师提供的资料。
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-18 19:30 , Processed in 2.569997 second(s), 19 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表