TensorFlow将已训练模型保存并预测

        第一次接触机器学习,训练和测试网上都有大把的教程,但是当我需要使用模型进行预测的时候,却遇到了一些麻烦,在这里就把自己的一部分总结写下来。注意这里是将训练好的神经网络拿出来进行预测(使用),而不是测试,测试可与训练在一个文件中进行,当然也可像这样载入网络和训练数据来测试。

        一般来说,TensorFlow对数据集的训练和测试都会同时进行,但是在实际中使用的时候,受限于实际使用算法的机器性能,我们不会重复进行训练,而是将TensorFlow在高性能平台训练后的模型保存下来,只在边缘设备上加载模型并进行预测,大大减少了对机器性能的需求。

        在完成培训之后,我们希望将所有的变量和网络结构保存到一个文件中,以便将来使用。因此,在Tensorflow中,我们希望保存所有参数的图和值,我们将创建一个tf.train.Saver()类的实例。

tf.flags.DEFINE_integer('num_epoch', 1500, 'The number of epoches for training.')
tf.flags.DEFINE_string('checkpoint', './checkpoint/', 'the checkpoint dir')
tf.flags.DEFINE_string('model_name', 'model4.ckpt', 'model name')
FLAGS = tf.flags.FLAGS

def train():
    ······
    saver = tf.train.Saver(tf.global_variables())  # 实例化 saver 对象
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        ······
        for i in range(FLAGS.num_epoch):
            ······
            # 对模型的每一次迭代,都保存到ckpt文件中
            saver.save(sess, os.path.join(FLAGS.checkpoint, FLAGS.model_name), global_step=i + 1)
            # 保存计算图,如果需要无须再次训练
            tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.checkpoint, 'graph.pb', as_text=False)

        在使用模型进行预测时,将模型恢复,通过feed_dict={······}传递输入参数,参数列表为训练时添加的tf.placeholder等。执行完毕后,通过pred = tf.argmax(Ylogits, 1)函数得到最终分类结果。

注:训练时在输入输出参数时的placeholder()最好增加name参数,为节点命名

def predict():
    ······
    saver = tf.train.Saver(tf.global_variables())  # 实例化 saver 对象
    # 预测, Ylogits为模型计算图函数
    pred = tf.argmax(Ylogits, 1)

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint)
        # 加载指定路径下的ckpt, 若模型存在, 则加载模型到当前对话
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        ······
        y_pred = np.empty((testdata.shape[0], FLAGS.num_classes), dtype=np.float32)
        for i in range(testdata.shape[0]):
            y_pred[i, :] = sess.run(pred, feed_dict={······})
        print(y_pred)

 

发表评论

电子邮件地址不会被公开。 必填项已用*标注