keras模型保存为tensorflow的二进制模型方式
作者:Eileng 发布时间:2022-12-21 07:20:08
最近需要将使用keras训练的模型移植到手机上使用, 因此需要转换到tensorflow的二进制模型。
折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:
# coding=utf-8
import sys
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a prunned computation graph.
Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
prunned so subgraphs that are not neccesary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
input_fld = sys.path[0]
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'
output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)
K.set_learning_phase(0)
net_model = load_model(weight_file_path)
print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))
上面代码实现保存到当前目录的tensor_model目录下。
验证:
import tensorflow as tf
import numpy as np
import PIL.Image as Image
import cv2
def recognize(jpg_path, pb_file_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_file_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tensors = tf.import_graph_def(output_graph_def, name="")
print tensors
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
op = sess.graph.get_operations()
for m in op:
print(m.values())
input_x = sess.graph.get_tensor_by_name("convolution2d_1_input:0") #具体名称看上一段代码的input.name
print input_x
out_softmax = sess.graph.get_tensor_by_name("activation_4/Softmax:0") #具体名称看上一段代码的output.name
print out_softmax
img = cv2.imread(jpg_path, 0)
img_out_softmax = sess.run(out_softmax,
feed_dict={input_x: 1.0 - np.array(img).reshape((-1,28, 28, 1)) / 255.0})
print "img_out_softmax:", img_out_softmax
prediction_labels = np.argmax(img_out_softmax, axis=1)
print "label:", prediction_labels
pb_path = 'tensorflow_model/constant_graph_weights.pb'
img = 'test/6/8_48.jpg'
recognize(img, pb_path)
补充知识:如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用
首先keras训练好的模型通过自带的model.save()保存下来是 .model (.h5) 格式的文件
模型载入是通过 my_model = keras . models . load_model( filepath )
要将该模型转换为.pb 格式的TensorFlow 模型,代码如下:
# -*- coding: utf-8 -*-
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers import Dropout
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential,load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
import collections
from collections import defaultdict
import jieba
import numpy as np
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'
output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)
K.set_learning_phase(0)
net_model = load_model(weight_file_path)
print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=True)
print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))
然后模型就存成了.pb格式的文件
问题就来了,这样存下来的.pb格式的文件是frozen model
如果通过TensorFlow serving 启用模型的话,会报错:
E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`
因为TensorFlow serving 希望读取的是saved model
于是需要将frozen model 转化为 saved model 格式,解决方案如下:
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model.pb'
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.gfile.GFile(graph_pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()
inp = g.get_tensor_by_name(net_model.input.name)
out = g.get_tensor_by_name(net_model.output.name)
sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
tf.saved_model.signature_def_utils.predict_signature_def(
{"in": inp}, {"out": out})
builder.add_meta_graph_and_variables(sess,
[tag_constants.SERVING],
signature_def_map=sigs)
builder.save()
于是保存下来的saved model 文件夹下就有两个文件:
saved_model.pb variables
其中variables 可以为空
于是将.pb 模型导入serving再读取,成功!
来源:https://blog.csdn.net/liuguangsuiyue/article/details/78298602
猜你喜欢
- Python应用编程需要用到的针对不同数据库引擎的数据库接口:http://wiki.python.org/moin/DatabaseInt
- 数组:【重点1】implode(分隔,arr) 把数组值数据按指定字符连接起来例如:$arr=array('1','
- 列表UL或是OL中都有一个预设标记,这个标记可能是实点圆点,也可能是数字。在实际的应用中,我们需要去掉这个预设标记,但我们不清楚这个预设标记
- 这是份总结,有不恰达的地方欢迎一同讨论联系方式 : 龙藏 longzang@taobao.com点击这里全幅围观或者点下面大图去 slide
- DesktopNexus 是我最喜爱的一个壁纸下载网站,上面有许多高质量的壁纸,几乎每天必上, 每月也必会坚持分享我这个月来收集的壁纸但是
- 先想创意,再画草图,接着鼠绘,最后做成flas * 。这是我的习惯流程。 这是想到中秋时,我第一时间内能浮想出的图像:大意是嫦娥奔
- LoadRunner监控MySQLhttp://www.docin.com/p-92272846.htmlAdvanced MySQL Pe
- (一)问题遗传算法求解正方形拼图游戏(二)代码#!/usr/bin/env python# -*- coding: utf-8 -*-fro
- 锁定数据库的一个表 SELECT * FROM table WITH (HOLDLOCK) 注意: 锁定数据库的一个表的区别 SELECT
- Google中秋的logo出来了,酷似一美男站在月亮上,结果被网友弄出一撒尿版来。中国网民好智慧啊~原logo: 撒尿版logo:
- 本文实例讲述了Go语言使用sort包对任意类型元素的集合进行排序的方法。分享给大家供大家参考。具体如下:使用sort包的函数进行排序时,集合
- 微软今天发布了SQL Server 2005 SP3的正式版,而这也将是该软件的最后一次升级服务,不过暂时只有英文版本,需要简体中文版的用户
- 什么是Inception ResnetV2Inception ResnetV2是Inception ResnetV1的一个加强版,两者的结构
- 很多人在群里问,这个下拉框定位不到、那个弹出框定位不到…各种定位不到,其实大多数情况下就是两种问题:1 有frame,2 没有加等待。殊不知
- 一、HACK以下两种方法几乎能解决现今所有HACK。1, !important 随着IE7对!important的支持, !imp
- 1.包: package PaintBrush; /** * * @author lucifer */ public class Paint
- 显示有限的接口到外部当发布python第三方package时,并不希望代码中所有的函数或者class可以被外部import,在__init_
- 本文实例讲述了JS简单实现DIV相对于浏览器固定位置不变的方法。分享给大家供大家参考,具体如下:<!DOCTYPE HTML PUBL
- 在Vista IIS 7 中用 vs2005 调试 Web 项目核心是要解决以下几个问题:1、Vista 自身在安全性方面的User Acc
- 在本文中我们将展示一种新的使用仿CSS选择器的语法来快速开发HTML和CSS的方法。它由Sergey Chikuyonok开发。你在写HTM