|
02、手写数字图片数据集的封装_笔记
- # -*- coding: utf-8 -*-
- __author__ = u'东方耀 微信:dfy_88888'
- __date__ = '2019/4/25 11:07'
- __product__ = 'PyCharm'
- __filename__ = '01_dcgan'
- """
- dcgan mnist 图片生成问题
- 1、data provider 封装dataset (图片 random vector) next_batch(batch_size)
- 2、模型构建 数据流图 G D DCGAN汇总G与D
- 3、执行训练代码 sess.run 将前面两部融合的过程
- """
- import os
- import sys
- import tensorflow as tf
- from tensorflow import logging
- from tensorflow import gfile
- import numpy as np
- import pprint
- from PIL import Image
- from tensorflow.examples.tutorials.mnist import input_data
- logging.set_verbosity(logging.INFO)
- # 只用到里面的图像 labels无所谓 也就是one_hot随便的
- mnist = input_data.read_data_sets(train_dir='mnist_data', one_hot=True)
- output_dir = './mnist_data/dcgan_train_output'
- if not gfile.Exists(output_dir):
- gfile.MakeDirs(output_dir)
- # 样本数 55000
- logging.info(mnist.train.num_examples)
- # images 特征矩阵 784个特征属性 是28*28像素的
- logging.info(mnist.train.images.shape)
- # labels 目标属性10个 刚好是10个数字 进行了哑编码one_hot=True
- logging.info(mnist.train.labels.shape)
- def get_default_params():
- return tf.contrib.training.HParams(
- # 随机向量的长度
- random_vector_size=100,
- # 随机向量变成卷积层的输出格式[NHWC]时候的 图片大小(类似)H=W=4
- init_conv_size=4,
- # 生成器的各个反卷积层的通道数目
- # 为什么最后是1通道而不是3通道 因为使用的数据集是黑白图片是1通道的 而彩色图片才是3通道的
- # g中128的通道是用在随机向量上的 64 32 1就是三个反卷积的通道
- # 经过三个反卷积后 初始的init_conv_size=4 就变成了32的大小了 图片大小 32 = 4*2*2*2 乘3个2即可
- g_channels=[128, 64, 32, 1],
- d_channels=[32, 64, 128, 256],
- batch_size=128,
- learning_rate=0.002,
- # 这是Adam优化器的
- beta1=0.5,
- # 要生成图像的大小 32*32
- img_size=32,
- keep_prob_train=0.8,
- keep_prob_test=1.0,
- # log打印的频率
- sample_log_frequency=50,
- )
- hps = get_default_params()
- logging.info(hps.img_size)
- logging.info(hps.g_channels)
- class Mnist_DataSet:
- def __init__(self, mnist_data, random_vector_size, img_size, need_shuffle=True):
- self._img_data = mnist_data
- self._img_data_num = len(self._img_data)
- self._need_shuffle = need_shuffle
- # 为每个图片随机生成对应的 随机向量
- self._random_vector2img_data = np.random.standard_normal(size=(self._img_data_num, random_vector_size))
- self._indicator = 0
- # 需要把训练的图片数据大小变成目标图像的大小 28*28变成32*32
- self._resize_img_data(img_size)
- if self._need_shuffle:
- self._random_shuffle()
- def _resize_img_data(self, img_size):
- """
- 调整训练图片的大小到 目标图片的大小
- self._img_data是一个numpy的矩阵 不能直接图片大小缩放
- 不能直接将矩阵当图像处理
- 1、numpy matrix ----> PIL image 28*28
- 2、pil image ----> resize 28--->32
- 3、PIL image ----> numpy matrix
- :param img_size: 需要调整的目标图片的大小
- :return:
- """
- # (55000, 784)
- # self._img_data已经作了归一化在0-1之间
- img_data = np.asarray(self._img_data * 255, dtype=np.uint8)
- # 5.5w*784 ---> 5.5w*28*28
- img_data = np.reshape(img_data, newshape=(self._img_data_num, 28, 28))
- new_img_data = []
- for i in range(self._img_data_num):
- # img 28*28 numpy matrx
- img = img_data[i]
- # 1、numpy matrix ----> PIL image 28*28
- img = Image.fromarray(img)
- # 2、pil image ----> resize 28--->32
- img = img.resize(size=(img_size, img_size))
- # 3、PIL image ----> numpy matrix
- img = np.asarray(img)
- # [HWC]
- img = np.reshape(img, newshape=(img_size, img_size, 1))
- new_img_data.append(img)
- # 将列表变成numpy的矩阵 大矩阵 5.5w*32*32*1 [NHWC]格式符合CNN网络处理
- # dtype=np.float32 方便下面的归一化
- new_img_data = np.asarray(new_img_data, dtype=np.float32)
- # 0-2 [-1, 1] shape:[5.5w, 32, 32, 1]
- # 归一化到(-1, 1)之间 与生成器的相契合 G最后是用tanH激活的
- new_img_data = new_img_data / 127.5 - 1
- self._img_data = new_img_data
- def _random_shuffle(self):
- p = np.random.permutation(self._img_data_num)
- self._img_data = self._img_data[p]
- self._random_vector2img_data = self._random_vector2img_data[p]
- def next_batch(self, batch_size):
- end_indicator = self._indicator + batch_size
- if end_indicator > self._img_data_num:
- self._indicator = 0
- self._random_shuffle()
- end_indicator = self._indicator + batch_size
- if end_indicator > self._img_data_num:
- raise Exception('batch_size is too large!')
- batch_img_data = self._img_data[self._indicator: end_indicator]
- batch_random_vector = self._random_vector2img_data[self._indicator: end_indicator]
- # 千万不要忘记了更新指针
- self._indicator = end_indicator
- return batch_img_data, batch_random_vector
- mnist_data = Mnist_DataSet(mnist.train.images, hps.random_vector_size, hps.img_size, True)
- batch_img_data2, batch_random_vector2 = mnist_data.next_batch(2)
- logging.info(batch_img_data2.shape)
- pprint.pprint(batch_img_data2[0][16])
- logging.info(batch_random_vector2.shape)
- pprint.pprint(batch_random_vector2)
复制代码
|
|