JavaWeb+TensorFlow图像预测分类(教程+填坑)

JavaWeb+TensorFlow图像预测分类(教程+填坑)

JavaWeb+TensorFlow图像预测分类

    首先还是先向我学习路上的前辈致谢:

https://blog.csdn.net/Light_Dream/article/details/88813227

https://github.com/amir-abdi/keras_to_tensorflow

  大家都知道,现在进行深度学习和图像识别最流行的平台就是TensorFlow了,基本上也都是Python+TensorFlow进行项目搭建,但现在大多数web端仍然还是Java编写的,将深度学习模型移植到Java平台成了一个需要解决的问题。当前国内介绍这个的移植工作还是比较少的,现存的教程也有一堆坑,踩了一堆后总算是摸索出些东西,分享给大家。

平台:Java 8+Tomcat+TensorFlow

第一步:固化模型为pb文件

  将需要移植的模型保存为二进制pb文件,主要的点是output_node_names数组,该数据的名称表示需要保存的tensorflow tensor名。既是在python中定义模型时指定的计算操作的name。填写什么就保存到什么节点。在cnn模型中,通常是分类输出的名称。

  由于我图像识别模型是用keras做的,所以还需要固化为Pb模型,这里强烈推荐一个github上的转化脚本:

https://github.com/amir-abdi/keras_to_tensorflow

第二步:在Java上搭建tensorflow环境

  这个地方非常重要,需要强调的是,目前国内教程上基本上都是说利用Java和Maven才能搭建,但实际上并不是这样,Maven只是个项目管理工具,对程序本身其实没多大影响,建立好依赖库即可。

  步骤如下:

1.下载tensorflow支持的移植jar包:libtensorflow.jar和移植所需的JNI:libtensorflow_jni.dll (windows cpu version)

下载地址:https://www.tensorflow.org/install/lang_java?hl=zh_cn

2.在Java项目中添加jar包,如图所示选择Add External Jars..添加libtensorflow.jar

3.在Run as configuration修改vm arguments,添加jni路径,如下图所示

4.利用tensorflow测试程序试试吧。

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
 
public class HelloTensorFlow {
  public static void main(String[] args) throws Exception {
    try (Graph g = new Graph()) {
      final String value = "Hello from " + TensorFlow.version();
 
      // Construct the computation graph with a single operation, a constant
      // named "MyConst" with a value "value".
      try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
        // The Java API doesn't yet include convenience functions for adding operations.
        g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
      }
 
      // Execute the "MyConst" operation in a Session.
      try (Session s = new Session(g);
          // Generally, there may be multiple output tensors,
          // all of them must be closed to prevent resource leaks.
          Tensor output = s.runner().fetch("MyConst").run().get(0)) {
        System.out.println(new String(output.bytesValue(), "UTF-8"));
      }
    }
  }
}

5.测试没问题的话,就可以开始图像分类啦

官方提供了java端的调用代码:LabelImage

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java

到这一步为止,整个过程就完成了,接下来盘几个坑吧。

第三步:盘坑

坑1:图像多分类结果混乱问题。

java上预测结果和python上不一样了,明明python上那么准怎么java上不好使了?

这个坑很有可能就是你在java上进行图像规范化的格式问题:

private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      // Some constants specific to the pre-trained model at:
      // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
      //
      // - The model was trained with images scaled to 224x224 pixels.
      // - The colors, represented as R, G, B in 1-byte each were converted to
      //   float using (value - Mean)/Scale.
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;
 
      // Since the graph is being constructed once per execution here, we can use a constant for the
      // input image. If the graph were to be re-used for multiple input images, a placeholder would
      // have been more appropriate.
      final Output<String> input = b.constant("input", imageBytes);
      final Output<Float> output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), Float.class),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
      }
    }

原本的mean和scale的设置是117f、1f。但实际RGB图像中这两个参数应改为128f,修改后预测结果恢复正常。

坑2:关于模型和label路径问题

  此条针对JavaEE项目。在JavaEE项目中,由于是通过tomcat启动,你的识别模型和label都应放到webroot中,可以新建一个模型文件夹存储,并在配置文件中添加相对位置,这样就避免了程序报错路径问题,很多都是提示failed to read ...诸如此类。如果你不是JavaEE只是JavaSE,直接run as Java Application的话,那就忽略此条,只要对应到你的本地路径即可。

坑3:无法找到Graph类的问题(java.lang.noclassdeffounderror: org/tensorflow/graph)

该问题还是针对JavaEE项目。

先说解决办法:

将“libtensorflow-1.12.0.jar”复制一份到WebRoot\WEB-INF\lib下,刷新项目即可。

这个问题主要还是tomcat在执行的时候加载不到tensorflow的移植包,之前的jar包只是java编译器上能调用,你要是JavaSE项目就都好使,但是web程序一旦运行,tomcat就找不到你的jar包。

暂时就盘点这么多坑吧,要是你们还出现了啥子奇葩的坑,欢迎评论留言。