很多人入门tensorflow时,都会做mnist,cifar10的分类实例,但无法用自己的训练集做分类,本教程用tfhub工具完成图像分类的训练和测试。
1. 安装tensorflow和tfhub
$ pip install "tensorflow>=1.7.0" $ pip install tensorflow-hub
你需要安装最新的tensorflow 1.7版本,并安装hub,因为hub要求tensorflow版本高于1.7。
2.训练需要的tensorflow-hub模型
该模块实际上是一个保存的模型。它包含预先训练的权重和图形。它是可重复使用的,可重新加工的。它将算法以图形和权重的形式进行打包。
您可以找到所有新发布的图像模块的列表 <https://www.tensorflow.org/hub/modules/image>。
其中一些包括分类层,其中一些删除它们只是提供一个特征向量作为输出。我们将选择一个特征向量模块Inception V1。
3.创建你自己的训练数据集
在开始训练之前,您需要把每一类的图片分别存在不同的文件夹中。为了使训练更好地工作,您应该至少收集您想要识别的每个类别的一百张照片。
整理training set
您有一个包含类名子文件夹的文件夹,每个文件夹中的每个文件夹都有完整的图像。 示例文件夹水果应该有这样的结构:
以下是文件夹结构:
~/fruits/apple/photo1.jpg ~/fruits/orange/photo2.jpg ...
~/fruits/banana/photo77.jpg ... ~/fruits/apple/someone.jpg
4.开始训练
用我们自己的数据集去训练hub的module,最后能得到graph和txt。
python retrain.py --image_dir your data_dir \ --saved_model_dir your
saved_model_dir \ --bottleneck_dir your bottleneck_dir \
--how_many_training_steps 4000 \ --output_labels output/output_labels.txt \
--output_graph output/retrain_graph.pb
该脚本加载预先训练的模块并在水果照片上retrain一个新的分类器。您可以用任何包含子文件夹的文件夹替换image_dir参数
图片。 每个图像的标签都取自它所在的子文件夹的名称。
5.使用tensorboard观察
我们可以将图表和统计数据可视化,例如训练期间权重或准确度如何变化。
在训练期间或之后运行tensorboard命令。
tensorboard --logdir /tmp/retrain_logs
在浏览器中输入:”你的主机:6006”
6.测试
retrain.py
<https://github.com/tensorflow/hub/tree/master/examples/image_retraining>
这个脚本将在您的类别上训练过的新模型写入/tmp/output_graph.pb(这个是default,你可以修改这个文件的名字和地址,如上面的retrain_graph.pb),并将包含标签的文本文件写入/tmp/output_labels.txt。
新模型包含新的分类层。
由于您已替换顶层,因此您需要在脚本中指定新名称。
以下是如何在训练图中运行label_image示例的示例,这个脚本在github里的tensorflow
<https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image>
python label_image.py \--graph=/output/retrain_graph.pb \--labels=/output/
output_labels.txt \--input_layer=Placeholder \--output_layer=final_result \--
image=~ yourtest.jpg_dir
7.更换训练模型
该脚本默认使用Inception V3模型体系结构。如果要在移动平台上部署,可以尝试-tfhub_module命令使用Mobilenet模型。
python retrain.py --image_dir ~/fruit_photos \ --tfhub_module
https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/feature_vector/1
参考文献
https://www.tensorflow.org/tutorials/image_retraining
热门工具 换一换