前几篇博客写了如何处理数据,如何把用自己的数据训练VGG-16,如何把训练好的模型保存。而在实际应用中,并不是所有的操作都是为了分类的,有时候需要提取图像的特征,那么怎么利用已经保存的模型提取特征呢?

   “桃叶儿尖上尖,柳叶儿就遮满了天”

    测试数据转换成tfrecords,教程:点击打开链接
<https://blog.csdn.net/gybheroin/article/details/79800679>

    保存训练好的VGG-16模型,教程:点击打开链接
<https://blog.csdn.net/gybheroin/article/details/79806096>

1、读取测试数据

      首先把测试数据转换成tfrecords,然后读取出来,代码和前面博客写的一致:      
#读取文件 def read_and_decode(filename,batch_size): #根据文件名生成一个队列 filename_queue =
tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _,
serialized_example = reader.read(filename_queue) #返回文件名和文件 features =
tf.parse_single_example(serialized_example, features={ 'label':
tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([],
tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img =
tf.reshape(img, [300, 300, 3]) #图像归一化大小 # img = tf.cast(img, tf.float32) * (1.
/ 255) - 0.5 #图像减去均值处理 label = tf.cast(features['label'], tf.int32) #特殊处理
img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=
batch_size, num_threads=64, capacity=2000, min_after_dequeue=1500) return
img_batch, tf.reshape(label_batch,[batch_size])
2、调取保存的训练好的VGG-16模型



    最核心的部分是使用saver类中的restore方法,核心代码如下:
saver = tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta") #注意路径
saver.restore(sess, "./model/checkpoint/model.ckpt") #保存模型的路径
3、把测试数据传进去模型提取特征

   
 利用的是graph.get_tensor_by_name(“名字”),则首先获取模型中占位符,然后将测试数据传进去,这是最核心的地方,想要提取特征也是通过名字获取张量,比如要提取fc7的特征,则fc7_features=graph.get_tensor_by_name("fc7:0")
。核心代码如下:
graph = tf.get_default_graph() #获取恢复模型的图模型 x_holder =
graph.get_tensor_by_name("x_holder:0") # 获取占位符
fc7_features=graph.get_tensor_by_name("fc7:0") #获取要提取的特征,用该层的名字
keep_prob=graph.get_tensor_by_name("keep_prob:0") #同上 # 通过张量的名称来获取张量
print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout}))
#给占位符重新赋值,则可以提取输入图像的特征
4、完整的代码

   
 整个过程,博主用了好几天的时间才调通,中间的心酸历程就不多说了,直接放完整的提取特征代码吧,如果想用保存的模型做分类,而不是提特征,则举一反三,我觉得并不难,修改一下即可:

完整代码:
# -*- coding: utf-8 -*- """ Created on Mon Apr 2 17:12:00 2018 @author: Heroin
高永标,upc """ import tensorflow as tf #读取文件 def
read_and_decode(filename,batch_size): #根据文件名生成一个队列 filename_queue =
tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _,
serialized_example = reader.read(filename_queue) #返回文件名和文件 features =
tf.parse_single_example(serialized_example, features={ 'label':
tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([],
tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img =
tf.reshape(img, [300, 300, 3]) #图像归一化大小 # img = tf.cast(img, tf.float32) * (1.
/ 255) - 0.5 #图像减去均值处理 label = tf.cast(features['label'], tf.int32) #特殊处理
img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=
batch_size, num_threads=64, capacity=2000, min_after_dequeue=1500) return
img_batch, tf.reshape(label_batch,[batch_size]) batch_size=4 dropout=1.0
tfrecords_file = 'train.tfrecords' #保存的测试数据 BATCH_SIZE = 4 image_batch,
label_batch = read_and_decode(tfrecords_file,BATCH_SIZE) #print(image_batch)
#sess=tf.InteractiveSession() with tf.Session() as sess: coord =
tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess,coord
= coord) image,label=sess.run([image_batch,label_batch]) saver =
tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta") #保存的模型路径
saver.restore(sess, "./model/checkpoint/model.ckpt") graph =
tf.get_default_graph() x_holder = graph.get_tensor_by_name("x_holder:0") #
获取占位符 fc7_features=graph.get_tensor_by_name("fc7:0") #获取要提取的特征,用名字
keep_prob=graph.get_tensor_by_name("keep_prob:0") # 通过张量的名称来获取张量
print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout}))
#给占位符重新赋值 sess.close()

友情链接
KaDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:[email protected]
QQ群:637538335
关注微信