前言:学习tensorflow和深度学习有一段时间了,一直停留在运行别人的代码和跑mnsit和cifar10数据集上,决定从简单的动漫头像生成着手代码,经过无数的debug后终于完成大概,此间主要参考的有以下两个代码,一个是别人写的DCGAN动漫头像生成,另一个是pix2pix的tensorflow实现代码。

动漫头像生成:
https://blog.csdn.net/sinat_33741547/article/details/77871170?locationNum=5&fps=1阿城

<https://blog.csdn.net/sinat_33741547/article/details/77871170?locationNum=5&fps=1%E9%98%BF%E5%9F%8E>

pix2pix代码:
https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py
<https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py>





说明:本部分是数据是数据处理部分,采用的数据是别人提取好的动漫头像,共50000多张,将这些图片转化为tensorflow官方的标准数据TFrecord格式,这个格式的在tensorflow处理的时侯读取速度会快不少

数据来源


百度网盘
<https://pan.baidu.com/s/1eSifHcA?errno=0&errmsg=Auth%20Login%20Sucess&&bduss=&ssnerror=0&traceid=>
  密码:g5qa




代码
#!/usr/bin/env python2 # -*- coding: utf-8 -*- '''
读取图片数据并转化为tensorflow官方的TFrecord格式 ''' import tensorflow as tf import os import
sys import time def _int64_feature(value): return
tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def
_bytes_feature(value): return
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def get_TF():
train_dir = "./faces/" #定义读取图片的路径 data = [] for file in os.listdir(train_dir):
#将图片的路径存储到data list中 data.append(train_dir+file)
stdi,stdo,stde=sys.stdin,sys.stdout,sys.stderr #如果没有这部分会提示编码错误 reload(sys)
#python3的reload在其他包中 sys.setdefaultencoding('utf-8')
sys.stdin,sys.stdout,sys.stderr=stdi,stdo,stde #改正reload之后print输出不了的问题
sess=tf.Session() file_at = 0 start_time = time.time() for i in
range(len(data)): image_path = data[i] #枚举每个图片的路径 image_raw_data =
tf.gfile.FastGFile(image_path,'r').read() img_data =
tf.image.decode_jpeg(image_raw_data,channels=3) #将读取到的图片按照jpeg的格式解压成tensor的形式
img_data = img_data.eval(session=sess) image_raw = img_data.tobytes()
#将图片的tensor变成字符串 example =
tf.train.Example(features=tf.train.Features(feature={ #构造TFrecord形式的example
'height':_int64_feature(img_data.shape[0]),
'width':_int64_feature(img_data.shape[1]),
'channel':_int64_feature(img_data.shape[2]),
'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))
#之后需要的只有'image_raw',其他可以不定义 })) if i % 500 == 0: #500个example存储为一个TFrecord文件
file_at += 1 filename = ("./TFrecord/data-tfrecords-%.5d" % file_at) if i>0:
writer.close() writer = tf.python_io.TFRecordWriter(filename) print("%d
steps,using time %f" % (i,time.time()-start_time)) start_time =time.time()
writer.write(example.SerializeToString()) #将examples写入TFrecord文件 writer.close()
get_TF()
在程序实际运行的时候,一开始处理很快,但是后来生成一个TFrecord文件就越运行越慢,查了资料没发现其他人有出现这个问题,没有解决。当然,也可以直接读取原图片训练。

友情链接
KaDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:[email protected]
QQ群:637538335
关注微信