
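Integrating TensorFlow with Spring Boot for Image Detection

Author: HBLOG

The name TensorFlow comes from tensors flowing through a computational graph, and the library is built on graph-based automatic differentiation. This article shows how to integrate TensorFlow with Spring Boot to implement image detection, for readers who need a reference.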

1. What is TensorFlow?

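The name TensorFlow describes tensors (Tensor) flowing (Flow) through a computational graph (Computational Graph). It is built on graph-based automatic differentiation: besides computing gradients for you automatically, it provides the common operations (ops, i.e. the nodes of the computational graph), common loss functions, and optimization algorithms.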
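To make "automatic differentiation over a computational graph" concrete, here is a toy reverse-mode sketch in plain Java. It is not TensorFlow's implementation or API; TinyAutodiff, Node, add, mul and backward are hypothetical names. Each op node records its local partial derivatives during the forward pass, and backward() applies the chain rule from the output back to the inputs, which is the idea TensorFlow automates for its own ops.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class TinyAutodiff {

    // A node in the computational graph: forward value plus accumulated gradient.
    static final class Node {
        final double value;
        double grad = 0.0;
        final List<Node> inputs = new ArrayList<>();
        final List<Double> localGrads = new ArrayList<>(); // d(this)/d(input), one per input

        Node(double value) {
            this.value = value;
        }
    }

    static Node leaf(double v) {
        return new Node(v);
    }

    // "Add" op: d(a+b)/da = 1, d(a+b)/db = 1
    static Node add(Node a, Node b) {
        Node out = new Node(a.value + b.value);
        out.inputs.add(a);
        out.localGrads.add(1.0);
        out.inputs.add(b);
        out.localGrads.add(1.0);
        return out;
    }

    // "Mul" op: d(a*b)/da = b, d(a*b)/db = a
    static Node mul(Node a, Node b) {
        Node out = new Node(a.value * b.value);
        out.inputs.add(a);
        out.localGrads.add(b.value);
        out.inputs.add(b);
        out.localGrads.add(a.value);
        return out;
    }

    // Post-order DFS: every node is appended after all of its inputs.
    private static void topoSort(Node n, Set<Node> seen, List<Node> order) {
        if (!seen.add(n)) {
            return;
        }
        for (Node in : n.inputs) {
            topoSort(in, seen, order);
        }
        order.add(n);
    }

    // Reverse-mode pass: walk the graph from the output back to the leaves, applying the chain rule.
    static void backward(Node output) {
        List<Node> order = new ArrayList<>();
        topoSort(output, new HashSet<>(), order);
        output.grad = 1.0;
        for (int i = order.size() - 1; i >= 0; i--) {
            Node n = order.get(i);
            for (int k = 0; k < n.inputs.size(); k++) {
                n.inputs.get(k).grad += n.grad * n.localGrads.get(k);
            }
        }
    }

    // f(x, y) = (x + y) * y; returns {f, df/dx, df/dy}
    public static double[] evaluate(double xv, double yv) {
        Node x = leaf(xv);
        Node y = leaf(yv);
        Node f = mul(add(x, y), y);
        backward(f);
        return new double[] {f.value, x.grad, y.grad};
    }

    public static void main(String[] args) {
        double[] r = evaluate(2.0, 3.0);
        System.out.printf("f=%.1f df/dx=%.1f df/dy=%.1f%n", r[0], r[1], r[2]);
    }
}
```

At (x, y) = (2, 3) this gives f = 15, df/dx = y = 3 and df/dy = x + 2y = 8. Note that y receives gradient along two paths (through the add node and directly through the mul node), and the += accumulation is what sums them.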

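In machine learning, numerical values usually come in four types:
(1) Scalar: a single number, the smallest unit of computation, such as "1" or "3.2".
(2) Vector: a one-dimensional array of scalars, such as [1, 3.2, 4.6].
(3) Matrix: a two-dimensional array of scalars.
(4) Tensor: a collection of data stored in (usually) multi-dimensional arrays; it can be thought of as a higher-dimensional matrix. Scalars, vectors, and matrices are tensors of rank 0, 1, and 2 respectively.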
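The four types differ only in rank, i.e. the number of dimensions. As a rough illustration, the sketch below uses plain nested Java arrays (not TensorFlow's Tensor or NdArray types) and a hypothetical shapeOf helper to print the shape of each. The last example has the shape [1, 299, 299, 3], one batched 299x299 RGB image, which matches the image batch the preprocessing code in this article feeds to the Inception v3 graph.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TensorRank {

    // Returns the shape of a rectangular nested Java array by following the first element
    // of each level. A non-array value (a scalar) has the empty shape [], i.e. rank 0.
    public static int[] shapeOf(Object value) {
        List<Integer> dims = new ArrayList<>();
        Object current = value;
        while (current != null && current.getClass().isArray()) {
            int length = Array.getLength(current);
            dims.add(length);
            if (length == 0) {
                break;
            }
            current = Array.get(current, 0);
        }
        return dims.stream().mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        float scalar = 3.2f;
        float[] vector = {1f, 3.2f, 4.6f};
        float[][] matrix = {{1f, 2f}, {3f, 4f}};
        float[][][][] imageBatch = new float[1][299][299][3]; // [batch, height, width, channels]

        System.out.println("scalar " + Arrays.toString(shapeOf(scalar)));     // rank 0
        System.out.println("vector " + Arrays.toString(shapeOf(vector)));     // rank 1
        System.out.println("matrix " + Arrays.toString(shapeOf(matrix)));     // rank 2
        System.out.println("tensor " + Arrays.toString(shapeOf(imageBatch))); // rank 4
    }
}
```

The rank is simply the length of the returned shape array.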

Basic concepts of TensorFlow

TensorFlow coding workflow

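2. Environment Setup

Integration steps

There are a few key points to watch during integration. First, firewall settings can interfere with any network communication TensorFlow needs during training; make sure your firewall allows TensorFlow to reach the network resources it requires so that training is not interrupted. This mainly matters if you also train models; the example below only runs inference with a pre-trained model. Second, pay attention to version compatibility. Spring Boot and TensorFlow each have their own release cycles, and choosing versions that work together avoids a lot of unnecessary trouble. This project, for example, uses TensorFlow Java (tensorflow-core-platform) 0.5.0 and compiles for Java 11, as declared in the pom.xml below.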

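3. Code Project

Objective

Implement image detection: classify an uploaded image with a pre-trained Inception v3 model.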

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>springboot-demo</artifactId>
        <groupId>com.et</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>Tensorflow</artifactId>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
    </properties>
    <dependencies>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.5.0</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>jmimemagic</groupId>
            <artifactId>jmimemagic</artifactId>
            <version>0.1.2</version>
        </dependency>
        <dependency>
            <groupId>jakarta.platform</groupId>
            <artifactId>jakarta.jakartaee-api</artifactId>
            <version>9.0.0</version>
        </dependency>
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.16.1</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.restdocs</groupId>
            <artifactId>spring-restdocs-mockmvc</artifactId>
            <scope>test</scope>
        </dependency>

    </dependencies>
</project>

controller

package com.et.tf.api;

import java.io.IOException;

import com.et.tf.service.ClassifyImageService;
import net.sf.jmimemagic.Magic;
import net.sf.jmimemagic.MagicMatch;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@RequestMapping("/api")
public class AppController {
    @Autowired
    ClassifyImageService classifyImageService;


    @PostMapping(value = "/classify")
    @CrossOrigin(origins = "*")
    public ClassifyImageService.LabelWithProbability classifyImage(@RequestParam MultipartFile file) throws IOException {
        checkImageContents(file);
        return classifyImageService.classifyImage(file.getBytes());
    }

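    // Note: in a @RestController this returns the literal text "index" (served at /api/ because of
    // the class-level @RequestMapping), not a rendered template. The upload page at
    // http://127.0.0.1:8080/ is presumably a static welcome page such as static/index.html.
    @RequestMapping(value = "/")
    public String index() {
        return "index";
    }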

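    // Detects the MIME type from the file's content (magic bytes), not from its name or the
    // client-supplied Content-Type, and rejects anything that is not an image.
    private void checkImageContents(MultipartFile file) {
        MagicMatch match;
        try {
            match = Magic.getMagicMatch(file.getBytes());
        } catch (Exception e) {
            // Unreadable upload or unrecognized format
            throw new RuntimeException(e);
        }
        String mimeType = match.getMimeType();
        if (!mimeType.startsWith("image")) {
            throw new IllegalArgumentException("Not an image type: " + mimeType);
        }
    }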

}
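If you would rather not depend on jmimemagic, the same content check can be done by comparing the first bytes of the upload against known file signatures. The sketch below assumes only JPEG and PNG uploads should be accepted; ImageSniffer, sniff and requireImage are hypothetical names, not part of the project.

```java
public class ImageSniffer {

    // Returns "image/jpeg" or "image/png" based on the file signature, or null if neither matches.
    public static String sniff(byte[] data) {
        if (startsWith(data, 0xFF, 0xD8, 0xFF)) {
            return "image/jpeg";
        }
        if (startsWith(data, 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A)) {
            return "image/png";
        }
        return null;
    }

    // Throws IllegalArgumentException for anything that is not a JPEG or PNG.
    public static void requireImage(byte[] data) {
        if (sniff(data) == null) {
            throw new IllegalArgumentException("Not a JPEG or PNG image");
        }
    }

    private static boolean startsWith(byte[] data, int... signature) {
        if (data == null || data.length < signature.length) {
            return false;
        }
        for (int i = 0; i < signature.length; i++) {
            // Mask to 0..255 because Java bytes are signed
            if ((data[i] & 0xFF) != signature[i]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        byte[] jpegHeader = {(byte) 0xFF, (byte) 0xD8, (byte) 0xFF, (byte) 0xE0};
        System.out.println(sniff(jpegHeader)); // prints image/jpeg
    }
}
```

In the controller, a call such as ImageSniffer.requireImage(file.getBytes()) would take the place of checkImageContents(file).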

service

package com.et.tf.service;

import jakarta.annotation.PreDestroy;
import java.util.Arrays;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

//Inspired by https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@Service
@Slf4j
public class ClassifyImageService {

    private final Session session;
    private final List<String> labels;
    private final String outputLayer;

    private final int W;
    private final int H;
    private final float mean;
    private final float scale;

    public ClassifyImageService(
        Graph inceptionGraph, List<String> labels, @Value("${tf.outputLayer}") String outputLayer,
        @Value("${tf.image.width}") int imageW, @Value("${tf.image.height}") int imageH,
        @Value("${tf.image.mean}") float mean, @Value("${tf.image.scale}") float scale
    ) {
        this.labels = labels;
        this.outputLayer = outputLayer;
        this.H = imageH;
        this.W = imageW;
        this.mean = mean;
        this.scale = scale;
        this.session = new Session(inceptionGraph);
    }

    public LabelWithProbability classifyImage(byte[] imageBytes) {
        long start = System.currentTimeMillis();
        try (Tensor image = normalizedImageToTensor(imageBytes)) {
            float[] labelProbabilities = classifyImageProbabilities(image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            LabelWithProbability labelWithProbability =
                new LabelWithProbability(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f, System.currentTimeMillis() - start);
            log.debug(String.format(
                    "Image classification [%s %.2f%%] took %d ms",
                    labelWithProbability.getLabel(),
                    labelWithProbability.getProbability(),
                    labelWithProbability.getElapsed()
                )
            );
            return labelWithProbability;
        }
    }

    private float[] classifyImageProbabilities(Tensor image) {
        try (Tensor result = session.runner().feed("input", image).fetch(outputLayer).run().get(0)) {
            final Shape resultShape = result.shape();
            final long[] rShape = resultShape.asArray();
            if (resultShape.numDimensions() != 2 || rShape[0] != 1) {
                throw new RuntimeException(
                    String.format(
                        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                        Arrays.toString(rShape)
                    ));
            }
            int nlabels = (int) rShape[1];
            FloatDataBuffer resultFloatBuffer = result.asRawTensor().data().asFloats();
            float[] dst = new float[nlabels];
            resultFloatBuffer.read(dst);
            return dst;
        }
    }

    private int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    private Tensor normalizedImageToTensor(byte[] imageBytes) {
        try (Graph g = new Graph();
             TInt32 batchTensor = TInt32.scalarOf(0);
             TInt32 sizeTensor = TInt32.vectorOf(H, W);
             TFloat32 meanTensor = TFloat32.scalarOf(mean);
             TFloat32 scaleTensor = TFloat32.scalarOf(scale);
             // The encoded image bytes live in a native tensor too; declaring it here closes it
             // once the graph has been run instead of leaking native memory on every request.
             TString imageTensor = TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes))
        ) {
            GraphBuilder b = new GraphBuilder(g);
            // Python tutorial: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image
            // Some constants specific to the pre-trained model at:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
            //
            // - The model was trained with images scaled to 299x299 pixels.
            // - The colors, represented as R, G, B in 1 byte each, were converted to
            //   float using (value - Mean)/Scale.

            // Since the graph is being constructed once per execution here, we can use a constant for the
            // input image. If the graph were to be re-used for multiple input images, a placeholder would
            // have been more appropriate.
            final Output input = b.constant("input", imageTensor);
            final Output output =
                b.div(
                    b.sub(
                        b.resizeBilinear(
                            b.expandDims(
                                b.cast(b.decodeJpeg(input, 3), DataType.DT_FLOAT),
                                b.constant("make_batch", batchTensor)
                            ),
                            b.constant("size", sizeTensor)
                        ),
                        b.constant("mean", meanTensor)
                    ),
                    b.constant("scale", scaleTensor)
                );
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0);
            }
        }
    }

    static class GraphBuilder {
        final Scope scope;

        GraphBuilder(Graph g) {
            this.g = g;
            this.scope = new OpScope(g);
        }

        Output div(Output x, Output y) {
            return binaryOp("Div", x, y);
        }

        Output sub(Output x, Output y) {
            return binaryOp("Sub", x, y);
        }

        Output resizeBilinear(Output images, Output size) {
            return binaryOp("ResizeBilinear", images, size);
        }

        Output expandDims(Output input, Output dim) {
            return binaryOp("ExpandDims", input, dim);
        }

        Output cast(Output value, DataType dtype) {
            return g.opBuilder("Cast", "Cast", scope).addInput(value).setAttr("DstT", dtype).build().output(0);
        }

        Output decodeJpeg(Output contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg", scope)
                .addInput(contents)
                .setAttr("channels", channels)
                .build()
                .output(0);
        }

        Output<? extends TType> constant(String name, Tensor t) {
            return g.opBuilder("Const", name, scope)
                .setAttr("dtype", t.dataType())
                .setAttr("value", t)
                .build()
                .output(0);
        }

        private Output binaryOp(String type, Output in1, Output in2) {
            return g.opBuilder(type, type, scope).addInput(in1).addInput(in2).build().output(0);
        }

        private final Graph g;
    }

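    // Releases the native session when the Spring context shuts down. Whether
    // jakarta.annotation.PreDestroy is honored depends on the Spring version from the parent POM:
    // Spring Framework 6 (Spring Boot 3) uses the jakarta namespace, while Spring Framework 5
    // (Spring Boot 2) expects javax.annotation.PreDestroy. Implementing DisposableBean works in both.
    @PreDestroy
    public void close() {
        session.close();
    }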

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class LabelWithProbability {
        private String label;
        private float probability;
        private long elapsed;
    }
}
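The GraphBuilder above calls ResizeBilinear without setting align_corners or half_pixel_centers, so the op's defaults (both false) apply. To make that step concrete, here is a pure-Java sketch of the same interpolation on a single channel. It is an illustration under that assumption, not TensorFlow code, and ResizeBilinearSketch is a hypothetical name; for exact parity with the model's preprocessing, compare its output against the op itself.

```java
public class ResizeBilinearSketch {

    // Resizes one channel with bilinear interpolation using the legacy coordinate mapping
    // (align_corners = false, half_pixel_centers = false):
    // source coordinate = destination index * (inputSize / outputSize).
    public static float[][] resize(float[][] src, int outH, int outW) {
        int inH = src.length;
        int inW = src[0].length;
        float hScale = (float) inH / outH;
        float wScale = (float) inW / outW;
        float[][] out = new float[outH][outW];
        for (int y = 0; y < outH; y++) {
            float inY = y * hScale;
            int y0 = (int) Math.floor(inY);
            int y1 = Math.min(y0 + 1, inH - 1); // clamp at the bottom edge
            float dy = inY - y0;
            for (int x = 0; x < outW; x++) {
                float inX = x * wScale;
                int x0 = (int) Math.floor(inX);
                int x1 = Math.min(x0 + 1, inW - 1); // clamp at the right edge
                float dx = inX - x0;
                // Interpolate horizontally on the two neighboring rows, then vertically
                float top = src[y0][x0] + (src[y0][x1] - src[y0][x0]) * dx;
                float bottom = src[y1][x0] + (src[y1][x1] - src[y1][x0]) * dx;
                out[y][x] = top + (bottom - top) * dy;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] src = {{0f, 10f}, {20f, 30f}};
        for (float[] row : resize(src, 4, 4)) {
            System.out.println(java.util.Arrays.toString(row));
        }
    }
}
```

For an RGB image the same function would be applied to each channel separately, which is what the op does on the [1, H, W, 3] batch.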

application.yaml

tf:
    frozenModelPath: inception-v3/inception_v3_2016_08_28_frozen.pb
    labelsPath: inception-v3/imagenet_slim_labels.txt
    outputLayer: InceptionV3/Predictions/Reshape_1
    image:
        width: 299
        height: 299
        mean: 0
        scale: 255

logging.level.net.sf.jmimemagic: WARN
spring:
  servlet:
    multipart:
      max-file-size: 5MB

Application.java

package com.et.tf;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.tensorflow.Graph;
import org.tensorflow.proto.framework.GraphDef;

@SpringBootApplication
@Slf4j
public class Application {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    @Bean
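    public Graph tfModelGraph(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) throws IOException {
        Resource graphResource = getResource(tfFrozenModelPath);

        Graph graph = new Graph();
        // Close the model stream once the GraphDef has been parsed
        try (java.io.InputStream in = graphResource.getInputStream()) {
            graph.importGraphDef(GraphDef.parseFrom(in));
        } catch (IOException | RuntimeException e) {
            graph.close(); // release native memory if the model cannot be loaded
            throw e;
        }
        log.info("Loaded Tensorflow model");
        return graph;
    }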

    // Shared by the model and label beans: prefer a file relative to the working directory,
    // then fall back to the classpath.
    private Resource getResource(String path) {
        Resource resource = new FileSystemResource(path);
        if (!resource.exists()) {
            resource = new ClassPathResource(path);
        }
        if (!resource.exists()) {
            throw new IllegalArgumentException(String.format("File %s does not exist", path));
        }
        return resource;
    }

    @Bean
    public List<String> tfModelLabels(@Value("${tf.labelsPath}") String labelsPath) throws IOException {
        Resource labelsRes = getResource(labelsPath);
        try (java.io.InputStream in = labelsRes.getInputStream()) {
            // Strip an optional "index:" prefix from each line
            List<String> labels = IOUtils.readLines(in, StandardCharsets.UTF_8).stream()
                .map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0))
                .collect(Collectors.toList());
            log.info("Loaded {} model labels", labels.size());
            return labels;
        }
    }
}
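The label loader keeps only the text after the first ':' on each line and leaves lines without a colon unchanged. Because classifyImage maps output position i to labels.get(i), the line order of the labels file must match the model's output order. The sketch below reproduces that parsing rule in isolation; LabelParsing is a hypothetical name and the sample lines are made up, not taken from imagenet_slim_labels.txt.

```java
import java.util.List;
import java.util.stream.Collectors;

public class LabelParsing {

    // Same rule as the lambda in tfModelLabels: if a line contains ':', keep only the text
    // after the first ':'; otherwise keep the whole line. Whitespace is not trimmed.
    static String stripIndexPrefix(String line) {
        int idx = line.indexOf(':');
        return idx >= 0 ? line.substring(idx + 1) : line;
    }

    public static List<String> parse(List<String> lines) {
        return lines.stream().map(LabelParsing::stripIndexPrefix).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Hypothetical sample lines
        System.out.println(parse(List.of("background", "1:tench", "2: goldfish")));
    }
}
```

Note the leading space kept in " goldfish": if a labels file uses "index: label" with a space, adding .trim() would be needed.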

The above is only the key code; see the code repository below for the complete project.

Code repository

https://github.com/Harries/springboot-demo

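4. Testing

Start the Spring Boot application. The model and label files configured in application.yaml (tf.frozenModelPath and tf.labelsPath) must be available: the getResource helper in Application.java looks for them first as files relative to the working directory and then on the classpath.

Test image classification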

Visit http://127.0.0.1:8080/, upload an image, and click Classify.
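The endpoint can also be called without the browser page. Below is a minimal sketch using the JDK 11 java.net.http.HttpClient, assuming the app runs on 127.0.0.1:8080 and the image is a JPEG. Since AppController is mapped under /api, the upload goes to /api/classify with a multipart part named file, matching @RequestParam MultipartFile file. The JDK has no built-in multipart body publisher, so the body is assembled by hand; ClassifyClient and multipartBody are hypothetical names.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

public class ClassifyClient {

    // Builds a multipart/form-data body with a single file part named "file".
    public static byte[] multipartBody(String boundary, String fileName, String contentType, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String head = "--" + boundary + "\r\n"
            + "Content-Disposition: form-data; name=\"file\"; filename=\"" + fileName + "\"\r\n"
            + "Content-Type: " + contentType + "\r\n\r\n";
        out.writeBytes(head.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        if (args.length != 1) {
            System.out.println("usage: java ClassifyClient.java <path-to-image.jpg>");
            return;
        }
        Path image = Path.of(args[0]);
        String boundary = "----boundary" + UUID.randomUUID();
        byte[] body = multipartBody(boundary, image.getFileName().toString(), "image/jpeg",
            Files.readAllBytes(image));

        HttpRequest request = HttpRequest.newBuilder(URI.create("http://127.0.0.1:8080/api/classify"))
            .header("Content-Type", "multipart/form-data; boundary=" + boundary)
            .POST(HttpRequest.BodyPublishers.ofByteArray(body))
            .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
            .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```

Run it with java ClassifyClient.java cat.jpg (single-file source launch, Java 11+). The response should be the JSON form of LabelWithProbability: label, probability (a percentage), and elapsed (milliseconds).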

5. Summary

That covers the details of integrating TensorFlow with Spring Boot to implement image detection.
