通过freeze_graph.py将TensorFlow模型转换为FrozenGraph并进行预测

        在训练过程中,tf.train.Saver()通常不会将权重数据保存在计算图中,反而是分开保存在checkpoint检查点文件里,当模型初始化时,通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,在产品使用时为了恢复模型需要额外计算,所以便有了freeze_graph.py脚本文件用来将计算图和权重合二为一,减少对边缘设备的性能要求。

  • checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
  • model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被tf.train.import_meta_graph 加载到当前默认的图来使用。
  • ckpt.data:保存模型中每个变量的取值
  • pb文件能够保存TensorFlow计算图中的操作节点以及对应的各张量,方便我们日后直接调用之前已经训练好的计算图

        ckpt检查点文件与pb文件的生成请看TensorFlow将已训练模型保存并预测

        网上的教程大多使用了Bazel,但是我在尝试的时候遇到了一些困难,后来发现可以使用python,由于每个人TensorFlow安装的位置也不太一样,就可以直接在Pycharm中调用,代码如下:

from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(input_graph='graph.pb',  # 要转换的计算图
                          input_saver="",  # 为空即可
                          input_binary=True,  # 传递pb文件,则为True
                          input_checkpoint='model4.ckpt-1486',  # 输入的权重,不需要最后的后缀
                          output_graph='freeze_graph.pb',  # 输出的冻结图
                          output_node_names="MatMul_2,output_pred",  # 计算图的输出节点,详见下文
                          restore_op_name='',  # 为空即可
                          filename_tensor_name='',  # 为空即可
                          clear_devices=True,  # 为平台设备的兼容性,清除默认设备
                          initializer_nodes=''  # 为空即可)
  • input_graph:模型文件,可以是二进制的pb文件,或文本的pbtxt文件,用input_binary来指定区分
  • input_saver:TensorFlow Saver file
  • input_binary:配合input_graph用,为true时,input_graph格式为.pb,为false时,input_graph格式为.pbtxt
  • input_checkpoint:检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。
  • output_node_names:输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。
  • restore_op_name:已弃用
  • filename_tensor_name:已弃用
  • output_graph:用来保存整合后的模型输出文件。
  • clear_devices:指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)
  • initializer_nodes:权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
  1. 若计算图的输出节点训练时已经在placeholder中添加name参数,则易于确定输出节点的名字
  2. 若未添加name参数,则可以通过TensorBoard确定输出节点,TensorBoard可通过tensorboard --logdir=C:\Users\logs运行,其中,logs文件夹内包含日志文件(日志文件由wirter = tf.summary.FileWriter('logs/',ss.graph)语句产生)

       在生成Frozen_Graph后,便可以通过pb文件直接加载模型,无需额外计算

 

 

def predict():
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        # 定义输入的张量名称,对应网络结构的输入张量
        # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
        input_data_tensor = sess.graph.get_tensor_by_name("IteratorGetNext:0")
        input_keep_prob_tensor = sess.graph.get_tensor_by_name("input_keep_prob:0")
        input_is_training_tensor = sess.graph.get_tensor_by_name("input_Y:0")

        # 定义输出的张量名称
        output_ylogits_tensor = sess.graph.get_tensor_by_name("MatMul_2:0")
        output_pred_tensor = sess.graph.get_tensor_by_name("output_pred:0")

        # 定义输出的存储名称
        y_pred = np.empty((testdata.shape[0], FLAGS.num_classes), dtype=np.float32)
        y_test = np.empty((testdata.shape[0], FLAGS.num_classes), dtype=np.float32)

        array = []  # 转为整型列表
        for i in range(testdata.shape[0]):
            y_pred[i, :], y_test[i, :] = sess.run([output_pred_tensor, output_ylogits_tensor],
                                                  feed_dict={input_data_tensor: testdata[i][np.newaxis, :],
                                                             input_keep_prob_tensor: 1,
                                                             input_is_training_tensor: False})
            array.append(int(y_pred[i, 0]))
        print(y_pred)

 

 

 

发表评论

电子邮件地址不会被公开。 必填项已用*标注