东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 51581|回复: 125
打印 上一主题 下一主题

[课堂笔记] 02、VGG16预训练好的参数加载与模型网络构建_笔记

  [复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14429
QQ
跳转到指定楼层
楼主
发表于 2019-3-12 16:44:19 | 只看该作者 |只看大图 回帖奖励 |倒序浏览 |阅读模式


02、VGG16预训练好的参数加载与模型网络构建_笔记

vgg16与vgg19预训练好的参数文件下载【回复本帖可见】:
游客,如果您要查看本帖隐藏内容请回复

  1. # -*- coding: utf-8 -*-
  2. __author__ = 'dongfangyao'
  3. __date__ = '2019/3/12 下午5:19'
  4. __product__ = 'PyCharm'
  5. __filename__ = '1_vgg16_content'


  6. import numpy as np

  7. vgg16_data = np.load('vgg16.npy', encoding='latin1')

  8. print(type(vgg16_data))
  9. data_dict = vgg16_data.item()
  10. print(type(data_dict))
  11. print(data_dict.keys())
  12. print(len(data_dict))

  13. conv1_1 = data_dict['conv1_1']
  14. print(len(conv1_1))
  15. w, b = conv1_1
  16. print(w.shape)
  17. print(b.shape)

  18. conv3_1 = data_dict['conv3_1']
  19. print(len(conv3_1))
  20. w, b = conv3_1
  21. print(w.shape)
  22. print(b.shape)

  23. fc6 = data_dict['fc6']
  24. w, b = fc6
  25. print(w.shape)
  26. print(b.shape)

  27. fc8 = data_dict['fc8']
  28. w, b = fc8
  29. print(w.shape)
  30. print(b.shape)




复制代码


  1. # -*- coding: utf-8 -*-
  2. __author__ = 'dongfangyao'
  3. __date__ = '2019/3/12 下午5:33'
  4. __product__ = 'PyCharm'
  5. __filename__ = '2_image_style_con'

  6. import tensorflow as tf
  7. from tensorflow import logging
  8. import os
  9. import time
  10. import numpy as np


  11. logging.set_verbosity(logging.INFO)

  12. # logging.info('dfy_88888')
  13. # vgg net 中写死的 归一化的数据预处理
  14. VGG_MEAN = [103.939, 116.779, 123.68]


  15. class VGGNet:
  16.     """
  17.     构建VGG16的网络结构 并从预训练好的模型提取参数 加载
  18.     """
  19.     def __init__(self, data_dict):
  20.         self.data_dict = data_dict

  21.     def get_conv_kernel(self, name):
  22.         # 卷积核的参数:w 0  b 1
  23.         return tf.constant(self.data_dict[name][0], name='conv')

  24.     def get_fc_weight(self, name):
  25.         return tf.constant(self.data_dict[name][0], name='fc')

  26.     def get_bias(self, name):
  27.         return tf.constant(self.data_dict[name][1], name='bias')

  28.     def conv_layer(self, inputs, name):
  29.         """
  30.         构建一个卷积计算层
  31.         :param inputs: 输入的feature_map
  32.         :param name: 卷积层的名字 也是获得参数的key 不能出错
  33.         :return:
  34.         """
  35.         with tf.name_scope(name):
  36.             """
  37.             多使用name_scope的好处:1、防止参数命名冲突 2、tensorboard可视化时很规整
  38.             如果scope里面有变量需要训练时则用tf.variable_scope
  39.             """
  40.             conv_w = self.get_conv_kernel(name)
  41.             conv_b = self.get_bias(name)
  42.             # tf.layers.conv2d() 这是一个封装更高级的api
  43.             # 里面并没有提供接口来输入卷积核参数 这里不能用 平时训练cnn网络时非常好用
  44.             result = tf.nn.conv2d(input=inputs, filter=conv_w, strides=[1, 1, 1, 1], padding='SAME', name=name)
  45.             result = tf.nn.bias_add(result, conv_b)
  46.             result = tf.nn.relu(result)
  47.             return result

  48.     def pooling_layer(self, inputs, name):
  49.         # tf.layers.max_pooling2d()
  50.         # tf.nn.max_pool 这里的池化层没有参数 两套api都可以用
  51.         return tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name=name)

  52.     def fc_layer(self, inputs, name, activation=tf.nn.relu):
  53.         """
  54.         构建全连接层
  55.         :param inputs: 输入
  56.         :param name:
  57.         :param activation: 是否有激活函数的封装
  58.         :return:
  59.         """
  60.         with tf.name_scope(name):
  61.             fc_w = self.get_fc_weight(name)
  62.             fc_b = self.get_bias(name)
  63.             # fc: wx+b 线性变换
  64.             result = tf.nn.bias_add(tf.matmul(inputs, fc_w), fc_b)
  65.             if activation is None:
  66.                 # vgg16的最后是不需relu激活的
  67.                 return result
  68.             else:
  69.                 return activation(result)

  70.     def flatten_op(self, inputs, name):
  71.         # 展平操作 为了后续的fc层必须将维度展平
  72.         with tf.name_scope(name):
  73.             # [NHWC]---> [N, H*W*C]

  74.             x_shape = inputs.get_shape().as_list()
  75.             dim = 1
  76.             for d in x_shape[1:]:
  77.                 dim *= d
  78.             inputs = tf.reshape(inputs, shape=[-1, dim])
  79.             # 直接用现成api也是可以的
  80.             # return tf.layers.flatten(inputs)
  81.             return inputs

  82.     def build(self, input_rgb):
  83.         """
  84.         构建vgg16网络结构 抽取特征 FP过程
  85.         :param input_rgb: [1, 224, 224, 3]
  86.         :return:
  87.         """
  88.         start_time = time.time()
  89.         logging.info('building start...')

  90.         # 在通道维度上分离 深度可分离卷积中也需要用到这个api
  91.         r, g, b = tf.split(input_rgb, num_or_size_splits=3, axis=3)
  92.         # 在通道维度上拼接
  93.         # 输入vgg网络的图像是bgr的(与OpenCV一样 倒序的)而不是rgb
  94.         x_bgr = tf.concat(values=[
  95.             b - VGG_MEAN[0],
  96.             g - VGG_MEAN[1],
  97.             r - VGG_MEAN[2],
  98.         ], axis=3)

  99.         assert x_bgr.get_shape().as_list()[1:] == [224, 224, 3]

  100.         # 构建网络
  101.         # stage 1
  102.         self.conv1_1 = self.conv_layer(x_bgr, 'conv1_1')
  103.         self.conv1_2 = self.conv_layer(self.conv1_1, 'conv1_2')
  104.         self.pool1 = self.pooling_layer(self.conv1_2, 'pool1')

  105.         # stage 2
  106.         self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
  107.         self.conv2_2 = self.conv_layer(self.conv2_1, 'conv2_2')
  108.         self.pool2 = self.pooling_layer(self.conv2_2, 'pool2')

  109.         # stage 3
  110.         self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
  111.         self.conv3_2 = self.conv_layer(self.conv3_1, 'conv3_2')
  112.         self.conv3_3 = self.conv_layer(self.conv3_2, 'conv3_3')
  113.         self.pool3 = self.pooling_layer(self.conv3_3, 'pool3')

  114.         # stage 4
  115.         self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
  116.         self.conv4_2 = self.conv_layer(self.conv4_1, 'conv4_2')
  117.         self.conv4_3 = self.conv_layer(self.conv4_2, 'conv4_3')
  118.         self.pool4 = self.pooling_layer(self.conv4_3, 'pool4')

  119.         # stage 5
  120.         self.conv5_1 = self.conv_layer(self.pool4, 'conv5_1')
  121.         self.conv5_2 = self.conv_layer(self.conv5_1, 'conv5_2')
  122.         self.conv5_3 = self.conv_layer(self.conv5_2, 'conv5_3')
  123.         self.pool5 = self.pooling_layer(self.conv5_3, 'pool5')

  124.         # flatten_op
  125.         self.flatten = self.flatten_op(self.pool5, 'flatten_op')

  126.         # fc
  127.         self.fc6 = self.fc_layer(self.flatten, 'fc6')
  128.         self.fc7 = self.fc_layer(self.fc6, 'fc7')
  129.         self.fc8 = self.fc_layer(self.fc7, 'fc8', activation=None)
  130.         self.logits = tf.nn.softmax(self.fc8, name='logits')
  131.         logging.info('building end... 耗时%3d秒' % (time.time() - start_time))


  132. vgg16_npy_path = './vgg16.npy'
  133. vgg16_data = np.load(vgg16_npy_path, encoding='latin1')

  134. print(type(vgg16_data))
  135. data_dict = vgg16_data.item()

  136. vgg16_for_result = VGGNet(data_dict)
  137. image_rgb = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='image_rgb')
  138. vgg16_for_result.build(image_rgb)
  139. print(vgg16_for_result.conv1_1)
  140. print(vgg16_for_result.fc6)





复制代码







画板 1.png (643.72 KB, 下载次数: 341)

画板 1.png
让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

36

帖子

100

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
100
沙发
发表于 2019-3-13 19:58:24 | 只看该作者
fdsafsadf
回复

使用道具 举报

0

主题

243

帖子

796

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
796
板凳
发表于 2019-3-19 22:49:38 | 只看该作者
fsdfasdfdsa
回复

使用道具 举报

0

主题

12

帖子

60

积分

注册会员

Rank: 2

积分
60
地板
发表于 2019-4-14 22:03:47 | 只看该作者
ssssssssssss
回复

使用道具 举报

0

主题

2

帖子

6

积分

新手上路

Rank: 1

积分
6
5#
发表于 2019-4-17 09:39:04 | 只看该作者
看看大神莫i噢噢噢噢噢噢噢噢
回复

使用道具 举报

0

主题

242

帖子

810

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
810
6#
发表于 2019-4-23 09:58:46 | 只看该作者
飞得高奋斗过
回复

使用道具 举报

0

主题

1

帖子

6

积分

新手上路

Rank: 1

积分
6
7#
发表于 2019-4-29 08:58:59 | 只看该作者
跟着东方老师学习
回复

使用道具 举报

0

主题

205

帖子

460

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
460
8#
发表于 2019-5-26 17:40:17 | 只看该作者
谢谢楼主分享
回复

使用道具 举报

0

主题

364

帖子

864

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
864
9#
发表于 2019-5-29 14:54:42 | 只看该作者
111
回复

使用道具 举报

0

主题

266

帖子

586

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
586
10#
发表于 2019-7-17 12:04:39 | 只看该作者
12121
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-20 10:18 , Processed in 0.198123 second(s), 21 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表