1、实验室环境下,直接saver和restore即可。
2、生产环境:
(1)部署在移动终端上的(例如ios、android),场景:图像识别等。用freeze_graph合成pb和ckpt文件,然后用optimize_for_inference、quantize_graph进行优化。再用TensorFlowInferenceInterface调用(这个,不知道ios和android是否相同)。
(2)部署在服务端提供服务使用的,场景:推荐系统等。使用tensorflow serving进行模型服务化。
下边是基于部署在服务端提供服务的方式,查阅资料时tensorflow和tensorflow serving都是1.3版本。
在读goole的paper的时候经常看到下边这张图。三个虚框已经把google的系统典型流程描述得很清楚。Data Generation这步,有非常多的学问这里木有经验,略过。我们来看Model Training和Model Serving两部分。也正是题主的问题的核心。
注:整个系统流程都为线上生产流程非实验室环境。
前面几位答友的知识点已经都提到了,这里也就总结整理了下,没有新知识:
1、Previous Models为训练好的模型,即Model Trainer的训练结果。通常在实验室环境中完成一个模型并验证其能发布到线上使用后,通过模型保存扔到生产环境的这里提供给线上系统使用。对应的代码实现:
# Export inference model.
output_path = os.path.join(
tf.compat.as_bytes(FLAGS.output_dir),
tf.compat.as_bytes(str(FLAGS.model_version)))
print 'Exporting trained model to', output_path
...
builder = tf.saved_model.builder.SavedModelBuilder(output_path)
...
builder.save()
目录里是类似这样的文件:(没什么神秘的,看save的手册即可)
2、Model Trainer,模型训练。只要训练集准备好,就可以对模型进行训练。通常需要有个触发的条件,例如晚上1点,或者数据集抽样完成等,只要能把你的模型运行起来就可以。那这里就涉及两点1)加载Previous Model,2)验证模型,如果满足你的要跟则保存模型。加载模型的代码实现:
# Restore variables from training checkpoint.
variable_averages = tf.train.ExponentialMovingAverage(inception_model.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
3、Model Verifier,不多说,每个模型都要实现的。即accuracy,通常只有accuracy达到我们预计的值为才执行。对应的代码类似:
train_accuracy = accuracy.eval(feed_dict={
x:batch[0], y_: batch[1], keep_prob: 1.0})
print "step %d, training accuracy %g"%(i, train_accuracy)
4、关键一步,Model verfierg到Model Servers。模型保存训练并达到我们的要求后,把它保存了下来。因为是生产环境,为了保障线上实时运行的稳定性,需要让训练中的模型和线上系统进行隔离,需要使用model_version+AB分流来解决这个问题。这里就开始用到Tensorflow Serving这个家伙了,即把你的模型给服务化,通过gRPC方式的HTTP提供实时调用。当然,移动端本地化的不需要这样,需要合成pb文件后直接本地调用。
模型服务化的命令:
下载完Tensorflow Serving,编译的命令,具体看官网。
bazel build -c opt //tensorflow_serving/model_servers:tensorflow_model_server
模型服务化,后边那个“/models/mnist_mode”为前边保存模型的目录
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=mnist --model_base_path=/models/mnist_model/
如果能顺利到这步,剩下的事情就是通过9000端口调用你的模型了。
5、使用方调用Model Servers的Clients端,做个gRPC或http发请求调用就可以了。
留言与评论(共有 0 条评论) “” |