详解tensorflow实现迁移学习实例
作者:疯女孩爱飞 发布时间:2022-02-06 01:43:22
本文主要是总结利用tensorflow实现迁移学习的基本步骤。
所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。
持久化
首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。
保存图
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess, "model.ckpt")
加载图
saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess, "model.ckpt")
迁移学习
第一步: 读取加载已经训练好的模型
在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。
def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})
# 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
bottlenect_values = np.squeeze(bottlenect_values)
return bottlenect_values
该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。
第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)
以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持脚本之家。
来源:http://blog.csdn.net/ustbfym/article/details/78201575
猜你喜欢
- 有很多应用项目, 刚起步的时候用MYSQL数据库基本上能实现各种功能需求,随着应用用户的增多,数据量的增加,MYSQL渐渐地出现不堪重负的情
- 前言大家在使用pymysql的时候,通过fetchall()或fetchone()可以获得查询结果,但这个返回数据是不包含字段信息的(不如p
- 代码如下: '排序 Function Sort1(ary) Dim KeepChecking,I,FirstValue,Second
- 最近越来越多在博客上写些UX相关的内容作为分享,就涉及到跟普通博文不一样的文章建构问题。文章内容固然很重要,但排版、组织也是提高可读性和用户
- 请问,如何在ACCESS数据库和SQL SERVER数据库中查询?
- 内容摘要:有很多朋友虽然安装好了mysql但却不知如何使用它。在这篇文章中我们就从连接mysql、修改密码、增加用户等方面来学习一些mysq
- 刚才运行了一段代码,来查看Request.ServerVariables里面有多少值,看了一下,共50个!代码<%=Request.S
- 导语日常开发中,定位程序异常,追溯事件发生场景都需要通过日志记录的方式。可以说一个好的开发日志设计可以让开发人员在后续项目维护的过程中节省时
- 自定义模板403<!DOCTYPE html><html lang="en"><head&
- 接下来我利用一点空余时间发一个函数里面包含和添加和删除功能。实验的架构可以使用IIS.5WEB服务器ACCESS数据库。这个我其实不用说的很
- 也许自己真的就是有手残的毛病,你说好端端的环境配置好了,自己还在那里瞎鼓捣,我最不想看到的就是在安装一个别的模块的时候,自动卸载了本地的其他
- 本文实例讲述了php多进程中的阻塞与非阻塞操作。分享给大家供大家参考,具体如下:我们通过pcntl_fork来创建子进程,使用pcntl_w
- 实例如下所示:#!/usr/bin/python# -*- coding: UTF-8 -*-import reimport urllib,
- 很多小伙伴在学习Django的时候,总是搞不定版本的问题,下面来一起看一张表,轻松解决Python版本和Django版本的兼容问题。Djan
- 我的操作系统为centos6.51 首先选择django要使用什么数据库。django1.10默认数据库为sqlite3,本人想
- 一、安装redis:1.下载:wget http://download.redis.io/releases/redis-3.2.8.tar.
- os.path包os.path包主要用于处理字符串路径,比如'/home/zikong/doc/file.doc',提取出有
- 很实用的过滤重复数据的asp代码,函数如下:<%'**************************************
- 代码如下:<% function GetBot() '查询蜘蛛 dim s_
- icech: 在制作网页的时候,常常要遇到制作虚线表格的问题,下面的文章就能解决这个问题。方法一:作一个1X2的图。半黑半白,再利用表格作成