东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 2909|回复: 0
打印 上一主题 下一主题

[课堂笔记] 02、手写数字图片数据集的封装_笔记

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14429
QQ
跳转到指定楼层
楼主
发表于 2019-4-25 11:44:26 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式


02、手写数字图片数据集的封装_笔记


  1. # -*- coding: utf-8 -*-
  2. __author__ = u'东方耀 微信:dfy_88888'
  3. __date__ = '2019/4/25 11:07'
  4. __product__ = 'PyCharm'
  5. __filename__ = '01_dcgan'

  6. """
  7. dcgan mnist 图片生成问题
  8. 1、data provider 封装dataset (图片 random vector) next_batch(batch_size)
  9. 2、模型构建 数据流图 G D DCGAN汇总G与D
  10. 3、执行训练代码 sess.run 将前面两部融合的过程
  11. """

  12. import os
  13. import sys
  14. import tensorflow as tf
  15. from tensorflow import logging
  16. from tensorflow import gfile
  17. import numpy as np
  18. import pprint
  19. from PIL import Image
  20. from tensorflow.examples.tutorials.mnist import input_data


  21. logging.set_verbosity(logging.INFO)

  22. # 只用到里面的图像 labels无所谓 也就是one_hot随便的
  23. mnist = input_data.read_data_sets(train_dir='mnist_data', one_hot=True)

  24. output_dir = './mnist_data/dcgan_train_output'

  25. if not gfile.Exists(output_dir):
  26.     gfile.MakeDirs(output_dir)

  27. # 样本数 55000
  28. logging.info(mnist.train.num_examples)
  29. # images 特征矩阵 784个特征属性  是28*28像素的
  30. logging.info(mnist.train.images.shape)
  31. # labels 目标属性10个 刚好是10个数字 进行了哑编码one_hot=True
  32. logging.info(mnist.train.labels.shape)


  33. def get_default_params():
  34.     return tf.contrib.training.HParams(
  35.         # 随机向量的长度
  36.         random_vector_size=100,
  37.         # 随机向量变成卷积层的输出格式[NHWC]时候的 图片大小(类似)H=W=4
  38.         init_conv_size=4,
  39.         # 生成器的各个反卷积层的通道数目
  40.         # 为什么最后是1通道而不是3通道 因为使用的数据集是黑白图片是1通道的 而彩色图片才是3通道的
  41.         # g中128的通道是用在随机向量上的 64 32 1就是三个反卷积的通道
  42.         # 经过三个反卷积后 初始的init_conv_size=4 就变成了32的大小了 图片大小 32 = 4*2*2*2 乘3个2即可
  43.         g_channels=[128, 64, 32, 1],
  44.         d_channels=[32, 64, 128, 256],
  45.         batch_size=128,
  46.         learning_rate=0.002,
  47.         # 这是Adam优化器的
  48.         beta1=0.5,
  49.         # 要生成图像的大小 32*32
  50.         img_size=32,

  51.         keep_prob_train=0.8,
  52.         keep_prob_test=1.0,
  53.         # log打印的频率
  54.         sample_log_frequency=50,

  55.     )


  56. hps = get_default_params()
  57. logging.info(hps.img_size)
  58. logging.info(hps.g_channels)


  59. class Mnist_DataSet:
  60.     def __init__(self, mnist_data, random_vector_size, img_size, need_shuffle=True):
  61.         self._img_data = mnist_data
  62.         self._img_data_num = len(self._img_data)
  63.         self._need_shuffle = need_shuffle
  64.         # 为每个图片随机生成对应的 随机向量
  65.         self._random_vector2img_data = np.random.standard_normal(size=(self._img_data_num, random_vector_size))
  66.         self._indicator = 0
  67.         # 需要把训练的图片数据大小变成目标图像的大小 28*28变成32*32
  68.         self._resize_img_data(img_size)
  69.         if self._need_shuffle:
  70.             self._random_shuffle()

  71.     def _resize_img_data(self, img_size):
  72.         """
  73.         调整训练图片的大小到 目标图片的大小
  74.         self._img_data是一个numpy的矩阵 不能直接图片大小缩放
  75.         不能直接将矩阵当图像处理
  76.         1、numpy matrix ----> PIL image 28*28
  77.         2、pil image ----> resize  28--->32
  78.         3、PIL image ----> numpy matrix
  79.         :param img_size: 需要调整的目标图片的大小
  80.         :return:
  81.         """
  82.         # (55000, 784)
  83.         # self._img_data已经作了归一化在0-1之间
  84.         img_data = np.asarray(self._img_data * 255, dtype=np.uint8)
  85.         # 5.5w*784 --->  5.5w*28*28
  86.         img_data = np.reshape(img_data, newshape=(self._img_data_num, 28, 28))
  87.         new_img_data = []
  88.         for i in range(self._img_data_num):
  89.             # img 28*28 numpy matrx
  90.             img = img_data[i]
  91.             # 1、numpy matrix ----> PIL image 28*28
  92.             img = Image.fromarray(img)
  93.             # 2、pil image ----> resize  28--->32
  94.             img = img.resize(size=(img_size, img_size))
  95.             # 3、PIL image ----> numpy matrix
  96.             img = np.asarray(img)
  97.             # [HWC]
  98.             img = np.reshape(img, newshape=(img_size, img_size, 1))
  99.             new_img_data.append(img)
  100.         # 将列表变成numpy的矩阵 大矩阵 5.5w*32*32*1 [NHWC]格式符合CNN网络处理
  101.         # dtype=np.float32 方便下面的归一化
  102.         new_img_data = np.asarray(new_img_data, dtype=np.float32)
  103.         # 0-2  [-1, 1]  shape:[5.5w, 32, 32, 1]
  104.         # 归一化到(-1, 1)之间  与生成器的相契合 G最后是用tanH激活的
  105.         new_img_data = new_img_data / 127.5 - 1
  106.         self._img_data = new_img_data

  107.     def _random_shuffle(self):
  108.         p = np.random.permutation(self._img_data_num)
  109.         self._img_data = self._img_data[p]
  110.         self._random_vector2img_data = self._random_vector2img_data[p]

  111.     def next_batch(self, batch_size):
  112.         end_indicator = self._indicator + batch_size
  113.         if end_indicator > self._img_data_num:
  114.             self._indicator = 0
  115.             self._random_shuffle()
  116.             end_indicator = self._indicator + batch_size
  117.         if end_indicator > self._img_data_num:
  118.             raise Exception('batch_size is too large!')
  119.         batch_img_data = self._img_data[self._indicator: end_indicator]
  120.         batch_random_vector = self._random_vector2img_data[self._indicator: end_indicator]
  121.         # 千万不要忘记了更新指针
  122.         self._indicator = end_indicator
  123.         return batch_img_data, batch_random_vector


  124. mnist_data = Mnist_DataSet(mnist.train.images, hps.random_vector_size, hps.img_size, True)

  125. batch_img_data2, batch_random_vector2 = mnist_data.next_batch(2)
  126. logging.info(batch_img_data2.shape)
  127. pprint.pprint(batch_img_data2[0][16])

  128. logging.info(batch_random_vector2.shape)
  129. pprint.pprint(batch_random_vector2)



复制代码


让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-19 15:33 , Processed in 0.159717 second(s), 18 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表