东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 2918|回复: 0

[课堂笔记] 03、图像风格转换的模型模块分析_笔记

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14427
QQ
发表于 2019-3-13 10:30:25 | 显示全部楼层 |阅读模式


03、图像风格转换的模型模块分析_笔记
模块(一个一个实现来写代码):
   1、定义输入文件与输出目录
   2、管理模型的超参
   3、数据的提供(内容图像 风格图像 随机初始化的图像)
   4、构建计算图(数据流图、定义loss、train_op)
   5、训练执行过程(会话中执行 设备:cpu或gpu或tpu)

图像大小要求:224*224的(vgg net的要求) 原本应该是224*224*3的 由于我是用qq的截图的图片 默认是224*224*4(4通道的 RGBA)所以需要进行处理 请看我的视频操作

  1. # -*- coding: utf-8 -*-
  2. __author__ = 'dongfangyao'
  3. __date__ = '2019/3/13 上午10:33'
  4. __product__ = 'PyCharm'
  5. __filename__ = '2_image_style_conver'

  6. import tensorflow as tf
  7. from tensorflow import logging
  8. import os
  9. from tensorflow import gfile
  10. from PIL import Image
  11. import time
  12. import numpy as np


  13. logging.set_verbosity(logging.INFO)

  14. # logging.info('dfy_88888')
  15. # vgg net 中写死的 归一化的数据预处理
  16. VGG_MEAN = [103.939, 116.779, 123.68]


  17. class VGGNet:
  18.     """
  19.     构建VGG16的网络结构 并从预训练好的模型提取参数 加载
  20.     """
  21.     def __init__(self, data_dict):
  22.         self.data_dict = data_dict

  23.     def get_conv_kernel(self, name):
  24.         # 卷积核的参数:w 0  b 1
  25.         return tf.constant(self.data_dict[name][0], name='conv')

  26.     def get_fc_weight(self, name):
  27.         return tf.constant(self.data_dict[name][0], name='fc')

  28.     def get_bias(self, name):
  29.         return tf.constant(self.data_dict[name][1], name='bias')

  30.     def conv_layer(self, inputs, name):
  31.         """
  32.         构建一个卷积计算层
  33.         :param inputs: 输入的feature_map
  34.         :param name: 卷积层的名字 也是获得参数的key 不能出错
  35.         :return:
  36.         """
  37.         with tf.name_scope(name):
  38.             """
  39.             多使用name_scope的好处:1、防止参数命名冲突 2、tensorboard可视化时很规整
  40.             如果scope里面有变量需要训练时则用tf.variable_scope
  41.             """
  42.             conv_w = self.get_conv_kernel(name)
  43.             conv_b = self.get_bias(name)
  44.             # tf.layers.conv2d() 这是一个封装更高级的api
  45.             # 里面并没有提供接口来输入卷积核参数 这里不能用 平时训练cnn网络时非常好用
  46.             result = tf.nn.conv2d(input=inputs, filter=conv_w, strides=[1, 1, 1, 1], padding='SAME', name=name)
  47.             result = tf.nn.bias_add(result, conv_b)
  48.             result = tf.nn.relu(result)
  49.             return result

  50.     def pooling_layer(self, inputs, name):
  51.         # tf.layers.max_pooling2d()
  52.         # tf.nn.max_pool 这里的池化层没有参数 两套api都可以用
  53.         return tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name=name)

  54.     def fc_layer(self, inputs, name, activation=tf.nn.relu):
  55.         """
  56.         构建全连接层
  57.         :param inputs: 输入
  58.         :param name:
  59.         :param activation: 是否有激活函数的封装
  60.         :return:
  61.         """
  62.         with tf.name_scope(name):
  63.             fc_w = self.get_fc_weight(name)
  64.             fc_b = self.get_bias(name)
  65.             # fc: wx+b 线性变换
  66.             result = tf.nn.bias_add(tf.matmul(inputs, fc_w), fc_b)
  67.             if activation is None:
  68.                 # vgg16的最后是不需relu激活的
  69.                 return result
  70.             else:
  71.                 return activation(result)

  72.     def flatten_op(self, inputs, name):
  73.         # 展平操作 为了后续的fc层必须将维度展平
  74.         with tf.name_scope(name):
  75.             # [NHWC]---> [N, H*W*C]

  76.             x_shape = inputs.get_shape().as_list()
  77.             dim = 1
  78.             for d in x_shape[1:]:
  79.                 dim *= d
  80.             inputs = tf.reshape(inputs, shape=[-1, dim])
  81.             # 直接用现成api也是可以的
  82.             # return tf.layers.flatten(inputs)
  83.             return inputs

  84.     def build(self, input_rgb):
  85.         """
  86.         构建vgg16网络结构 抽取特征 FP过程
  87.         :param input_rgb: [1, 224, 224, 3] [NHWC]
  88.         :return:
  89.         """
  90.         start_time = time.time()
  91.         logging.info('building start...')

  92.         # 在通道维度上分离 深度可分离卷积中也需要用到这个api
  93.         r, g, b = tf.split(input_rgb, num_or_size_splits=3, axis=3)
  94.         # 在通道维度上拼接
  95.         # 输入vgg网络的图像是bgr的(与OpenCV一样 倒序的)而不是rgb
  96.         x_bgr = tf.concat(values=[
  97.             b - VGG_MEAN[0],
  98.             g - VGG_MEAN[1],
  99.             r - VGG_MEAN[2],
  100.         ], axis=3)

  101.         assert x_bgr.get_shape().as_list()[1:] == [224, 224, 3]

  102.         # 构建网络
  103.         # stage 1
  104.         self.conv1_1 = self.conv_layer(x_bgr, 'conv1_1')
  105.         self.conv1_2 = self.conv_layer(self.conv1_1, 'conv1_2')
  106.         self.pool1 = self.pooling_layer(self.conv1_2, 'pool1')

  107.         # stage 2
  108.         self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
  109.         self.conv2_2 = self.conv_layer(self.conv2_1, 'conv2_2')
  110.         self.pool2 = self.pooling_layer(self.conv2_2, 'pool2')

  111.         # stage 3
  112.         self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
  113.         self.conv3_2 = self.conv_layer(self.conv3_1, 'conv3_2')
  114.         self.conv3_3 = self.conv_layer(self.conv3_2, 'conv3_3')
  115.         self.pool3 = self.pooling_layer(self.conv3_3, 'pool3')

  116.         # stage 4
  117.         self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
  118.         self.conv4_2 = self.conv_layer(self.conv4_1, 'conv4_2')
  119.         self.conv4_3 = self.conv_layer(self.conv4_2, 'conv4_3')
  120.         self.pool4 = self.pooling_layer(self.conv4_3, 'pool4')

  121.         # stage 5
  122.         self.conv5_1 = self.conv_layer(self.pool4, 'conv5_1')
  123.         self.conv5_2 = self.conv_layer(self.conv5_1, 'conv5_2')
  124.         self.conv5_3 = self.conv_layer(self.conv5_2, 'conv5_3')
  125.         self.pool5 = self.pooling_layer(self.conv5_3, 'pool5')

  126.         # flatten_op
  127.         # self.flatten = self.flatten_op(self.pool5, 'flatten_op')
  128.         #
  129.         # # fc
  130.         # self.fc6 = self.fc_layer(self.flatten, 'fc6')
  131.         # self.fc7 = self.fc_layer(self.fc6, 'fc7')
  132.         # self.fc8 = self.fc_layer(self.fc7, 'fc8', activation=None)
  133.         # self.logits = tf.nn.softmax(self.fc8, name='logits')
  134.         logging.info('building end... 耗时%3d秒' % (time.time() - start_time))




  135. #
  136. # vgg16_for_result = VGGNet(data_dict)
  137. # image_rgb = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='image_rgb')
  138. # vgg16_for_result.build(image_rgb)
  139. # print(vgg16_for_result.conv1_1)
  140. # print(vgg16_for_result.flatten)
  141. # print(vgg16_for_result.fc6)

  142. """
  143. 模块(一个一个实现来写代码):
  144.    1、定义输入文件与输出目录
  145.    2、管理模型的超参
  146.    3、数据的提供(内容图像 风格图像 随机初始化的图像)
  147.    4、构建计算图(数据流图、定义loss、train_op)
  148.    5、训练执行过程(会话中执行 设备:cpu或gpu或tpu)
  149. """
  150. vgg16_npy_path = './vgg16.npy'
  151. content_img_path = './dfy_88888.png'
  152. style_img_path = './style.png'
  153. output_dir = './output_imgs'


  154. if not gfile.Exists(output_dir):
  155.     gfile.MakeDirs(output_dir)


  156. def get_default_params():
  157.     return tf.contrib.training.HParams(
  158.         learning_rate=10,
  159.         lambda_content_loss=0.05,
  160.         lambda_style_loss=2000,

  161.     )

  162. hps = get_default_params()
  163. print(hps.learning_rate)
  164. print(hps.lambda_content_loss)


  165. def read_img(image_name):
  166.     img = Image.open(image_name)
  167.     np_img = np.array(img)
  168.     # [224, 224, 4]
  169.     print('np_img shape:', np_img.shape, image_name)
  170.     # RGBA--->RGB [224, 224, 3]
  171.     np_img = np_img[:, :, 0:3]
  172.     np_img = np.asarray([np_img], dtype=np.float32)
  173.     print('np_img shape: ', np_img.shape)
  174.     # (1, 224, 224, 3)
  175.     return np_img


  176. # read_img(content_img_path)
  177. content_img_arr_val = read_img(content_img_path)
  178. style_img_arr_val = read_img(style_img_path)


  179. def initial_image(shape, mean, stddev):
  180.     # 截断的随机的正态分布 数据产生
  181.     initial_img = tf.truncated_normal(shape=shape, mean=mean, stddev=stddev, dtype=tf.float32)
  182.     return tf.Variable(initial_value=initial_img, trainable=True)

  183. result_img_val = initial_image([1, 224, 224, 3], mean=255//2, stddev=20)
  184. # 用占位符 具体的值在sess中通过feed_dict喂养 后面执行阶段会有
  185. content_img = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='content_img')
  186. style_img = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='style_img')


  187. # 提取图像的卷积层的特征

  188. vgg16_data = np.load(vgg16_npy_path, encoding='latin1')
  189. data_dict = vgg16_data.item()

  190. vgg16_for_result_img = VGGNet(data_dict)
  191. vgg16_for_content_img = VGGNet(data_dict)
  192. vgg16_for_style_img = VGGNet(data_dict)

  193. # 结果图像vgg16的构建
  194. vgg16_for_result_img.build(result_img_val)
  195. # 内容图像vgg16的构建
  196. vgg16_for_content_img.build(content_img)
  197. # 风格图像vgg16的构建
  198. vgg16_for_style_img.build(style_img)


  199. # 定义需要提取哪些层的特征了 cnn

  200. # 内容图像的内容特征抽取 越低层效果越好
  201. content_features = [
  202.     vgg16_for_content_img.conv1_1,
  203.     vgg16_for_content_img.conv2_1,
  204.     # vgg16_for_content_img.conv3_1,
  205.     # vgg16_for_content_img.conv3_2,
  206.     # vgg16_for_content_img.conv5_1,
  207.     # vgg16_for_content_img.conv5_3,
  208. ]

  209. # 结果图像的内容特征抽取 必须一致
  210. result_content_features = [
  211.     vgg16_for_result_img.conv1_1,
  212.     vgg16_for_result_img.conv2_1,
  213.     # vgg16_for_result_img.conv3_1,
  214.     # vgg16_for_result_img.conv3_2,
  215.     # vgg16_for_result_img.conv5_1,
  216.     # vgg16_for_result_img.conv5_3,
  217. ]

  218. # 风格图像的风格特征抽取 越高层越好
  219. style_features = [
  220.     # vgg16_for_style_img.conv1_1,
  221.     # vgg16_for_style_img.conv2_1,
  222.     # vgg16_for_style_img.conv3_1,
  223.     # vgg16_for_style_img.conv4_2,
  224.     vgg16_for_style_img.conv4_3,
  225.     vgg16_for_style_img.conv5_3,
  226. ]

  227. # 结果图像的风格特征抽取 必须一致
  228. result_style_features = [
  229.     # vgg16_for_result_img.conv1_1,
  230.     # vgg16_for_result_img.conv2_1,
  231.     # vgg16_for_result_img.conv3_1,
  232.     # vgg16_for_result_img.conv4_2,
  233.     vgg16_for_result_img.conv4_3,
  234.     vgg16_for_result_img.conv5_3,
  235. ]

  236. # loss = loss_content + loss_style






























复制代码



style.png
让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-17 01:55 , Processed in 0.195697 second(s), 22 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表