Integrating DL4J with Java to Recognize Express Waybill Information from Scanned Images
Author: 慧香一格
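This article uses DL4J (DeepLearning4J) to build a simple image-recognition model and integrates it into a Spring Boot back end. We train a small convolutional neural network (CNN) on the MNIST handwritten-digit dataset and then deploy it in a Spring Boot application. Note that the resulting model classifies a single 28x28 digit image (digits 0 to 9); it demonstrates the DL4J and Spring Boot plumbing rather than a complete waybill reader.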
1. Set up the Spring Boot project
First, create a new Spring Boot project. You can generate the project skeleton quickly with Spring Initializr (https://start.spring.io/). Select the following dependencies:
Spring Web
Spring Boot DevTools
Lombok (optional, reduces boilerplate code)
2. Add the DL4J dependencies
Add DL4J and the related dependencies to your pom.xml:
```xml
<dependencies>
    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>

    <!-- DL4J -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nn</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>

    <!-- File Upload -->
    <dependency>
        <groupId>commons-fileupload</groupId>
        <artifactId>commons-fileupload</artifactId>
        <version>1.4</version>
    </dependency>
    <dependency>
        <groupId>commons-io</groupId>
        <artifactId>commons-io</artifactId>
        <version>2.11.0</version>
    </dependency>

    <!-- Lombok (optional) -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>
```
3. Train the DL4J model
We train a simple CNN on the MNIST dataset. Create a new Java class, MnistModelTrainer.java, to train the model:
```java
package com.example.scanapp;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;

public class MnistModelTrainer {

    public static void main(String[] args) throws IOException {
        int numEpochs = 10;
        int batchSize = 64;
        int numLabels = 10;
        int numRows = 28;
        int numColumns = 28;
        int numChannels = 1;

        // Load MNIST data. MnistDataSetIterator already scales pixels to [0, 1] and
        // returns each image as a flat row of 784 values, so no extra
        // ImagePreProcessingScaler is applied (it would divide by 255 a second time).
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        // Define the network architecture (LeNet-style CNN)
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .updater(new Adam(0.001))
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(numChannels)
                        .nOut(20)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .nOut(50)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder()
                        .activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(numLabels)
                        .activation(Activation.SOFTMAX)
                        .build())
                // Reshapes the flat 784-value rows into 28x28x1 images and lets DL4J
                // infer nIn for layers 2, 4 and 5 (without this, init() fails)
                .setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels))
                .build();

        // Initialize the network
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));

        // Train the network, rewinding the iterator before each epoch
        for (int i = 0; i < numEpochs; i++) {
            mnistTrain.reset();
            model.fit(mnistTrain);
        }

        // Evaluate on the held-out test set
        Evaluation eval = model.evaluate(mnistTest);
        System.out.println(eval.stats());

        // Save the model; saveUpdater = true also stores the Adam state so training can resume
        File locationToSave = new File("mnist-model.zip");
        boolean saveUpdater = true;
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
    }
}
```
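With InputType.convolutionalFlat set, DL4J works out the input size of every later layer automatically. To see where those numbers come from, the sketch below re-derives them by hand. It assumes DL4J's default ConvolutionMode.Truncate with no padding, where each convolution or pooling layer produces (in - kernel) / stride + 1 pixels per side; the class and method names are hypothetical.

```java
/**
 * Hypothetical helper: traces the spatial size of a square input through the
 * conv/pool stack in MnistModelTrainer, assuming no padding (truncate mode):
 *   out = (in - kernel) / stride + 1
 */
public class LeNetShapeCheck {

    /** Output width/height of a convolution or pooling layer without padding. */
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Number of activations flowing into the DenseLayer (layer 4). */
    static int denseInputSize(int imageSize, int channelsAfterLastConv) {
        int s = outSize(imageSize, 5, 1); // layer 0: 5x5 conv, stride 1 (28 -> 24 for MNIST)
        s = outSize(s, 2, 2);             // layer 1: 2x2 max pool, stride 2 (24 -> 12)
        s = outSize(s, 5, 1);             // layer 2: 5x5 conv, stride 1 (12 -> 8)
        s = outSize(s, 2, 2);             // layer 3: 2x2 max pool, stride 2 (8 -> 4)
        return s * s * channelsAfterLastConv;
    }

    public static void main(String[] args) {
        // 50 feature maps of 4x4 after layer 3
        System.out.println("DenseLayer nIn = " + denseInputSize(28, 50));
    }
}
```

For MNIST this gives 4 x 4 x 50 = 800 inputs to the 500-unit dense layer, which is the value DL4J infers when the input type is declared.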
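Run the MnistModelTrainer class to train the model and save it to mnist-model.zip. On the first run, MnistDataSetIterator downloads the MNIST dataset, so network access is required.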
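Before wiring the model into Spring Boot, it can help to confirm that training produced a usable file. ModelSerializer writes an ordinary zip archive, so the JDK's java.util.zip can inspect it. The sketch below (class and method names are hypothetical) lists the entries and checks for configuration.json and coefficients.bin, the entry names ModelSerializer uses for the network configuration and parameters; if your DL4J version names them differently, adjust the check.

```java
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

public class ModelFileCheck {

    /** Lists the entry names inside a zip file such as mnist-model.zip. */
    static List<String> listEntries(String path) throws IOException {
        try (ZipFile zip = new ZipFile(path)) {
            // collect() runs before the zip file is closed by try-with-resources
            return zip.stream().map(ZipEntry::getName).collect(Collectors.toList());
        }
    }

    /** Heuristic: a saved DL4J network stores its configuration and parameters as separate entries. */
    static boolean looksLikeDl4jModel(String path) throws IOException {
        List<String> names = listEntries(path);
        return names.contains("configuration.json") && names.contains("coefficients.bin");
    }

    public static void main(String[] args) throws IOException {
        String path = args.length > 0 ? args[0] : "mnist-model.zip";
        System.out.println("Entries: " + listEntries(path));
        System.out.println("Looks like a DL4J model: " + looksLikeDl4jModel(path));
    }
}
```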
4. Create the Spring Boot controller
Create a new controller to handle image uploads and recognition:
```java
package com.example.scanapp.controller;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

@RestController
@CrossOrigin // lets a front end on another port (e.g. a Vue dev server) call this API; restrict origins in production
public class ImageController {

    private static final String MODEL_PATH = "mnist-model.zip"; // replace with your model path (relative to the working directory)
    private static final int SIZE = 28;

    private final MultiLayerNetwork model;

    public ImageController() throws IOException {
        this.model = ModelSerializer.restoreMultiLayerNetwork(new File(MODEL_PATH));
    }

    @PostMapping("/recognize")
    public String recognize(@RequestParam("image") MultipartFile file) {
        try {
            BufferedImage image = ImageIO.read(file.getInputStream());
            if (image == null) {
                return "Unsupported image format";
            }
            if (image.getWidth() != SIZE || image.getHeight() != SIZE) {
                return "Expected a 28x28 image";
            }

            // Same layout as training: one flat row of 784 values in [0, 1]
            INDArray input = Nd4j.zeros(1, SIZE * SIZE);
            for (int y = 0; y < SIZE; y++) {
                for (int x = 0; x < SIZE; x++) {
                    int rgb = image.getRGB(x, y);
                    int r = (rgb >> 16) & 0xFF;
                    int g = (rgb >> 8) & 0xFF;
                    int b = rgb & 0xFF;
                    double gray = (r + g + b) / 3.0;                // average the channels into grayscale
                    input.putScalar(0, y * SIZE + x, gray / 255.0); // scale once, as MnistDataSetIterator does
                }
            }

            INDArray output;
            synchronized (model) { // MultiLayerNetwork is not thread-safe; serialize concurrent requests
                output = model.output(input);
            }
            int predictedClass = output.argMax(1).getInt(0);
            return "Predicted class: " + predictedClass;
        } catch (IOException e) {
            e.printStackTrace();
            return "Error processing image";
        }
    }
}
```
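The controller reads exactly 28x28 pixels, while scans and phone photos come in arbitrary sizes. One option is to downscale before inference. The JDK-only sketch below (class and method names are hypothetical) resizes any BufferedImage to 28x28 grayscale, scales values to [0, 1], and can optionally invert the image, because MNIST digits are light strokes on a dark background whereas most scans are dark ink on white paper. The result is a flat 784-element array in row-major order, the same layout MnistDataSetIterator uses for training.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

public class MnistImagePreprocessor {

    static final int SIZE = 28;

    /**
     * Scales any image to 28x28 grayscale and returns 784 values in [0, 1], row by row.
     * If invert is true, dark-on-light input is flipped to MNIST's light-on-dark style.
     */
    static float[] toMnistVector(BufferedImage src, boolean invert) {
        // Draw the source into a 28x28 single-channel image; Java2D handles scaling and color conversion
        BufferedImage gray = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, SIZE, SIZE, null);
        g.dispose();

        float[] pixels = new float[SIZE * SIZE];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                int v = gray.getRaster().getSample(x, y, 0); // unsigned sample, 0..255
                float scaled = v / 255f;
                pixels[y * SIZE + x] = invert ? 1f - scaled : scaled;
            }
        }
        return pixels;
    }

    public static void main(String[] args) {
        BufferedImage blank = new BufferedImage(120, 80, BufferedImage.TYPE_INT_RGB); // all black
        float[] v = toMnistVector(blank, true);
        System.out.println(v.length + " values, first = " + v[0]);
    }
}
```

In the controller, the array could then be wrapped as a 1x784 input, for example with Nd4j.create(pixels, new int[]{1, 784}), before calling model.output (check the overloads in your ND4J version). Whether to invert depends on your images; a bright background fed in without inversion will look nothing like the training data.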
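5. Test the API
You can test the API with Postman or a similar tool. Send a POST request to the /recognize endpoint with the image attached as a multipart form field named image. The image should be in MNIST format: 28x28 pixels, grayscale, with a light digit on a dark background.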
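If you prefer a scripted test to Postman, the JDK's HttpClient (Java 11+) can send the upload, although it has no built-in multipart encoder. The sketch below builds a minimal multipart/form-data body by hand; the class name, the default file name digit.png, and the application/octet-stream part type are assumptions for illustration. The field name image matches @RequestParam("image") in the controller.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

public class RecognizeClient {

    /** Builds a multipart/form-data body containing a single file part. */
    static byte[] multipartBody(String boundary, String field, String filename, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String head = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"" + field + "\"; filename=\"" + filename + "\"\r\n"
                + "Content-Type: application/octet-stream\r\n\r\n";
        out.writeBytes(head.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8)); // closing delimiter
        return out.toByteArray();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        Path image = Path.of(args.length > 0 ? args[0] : "digit.png"); // hypothetical 28x28 test image
        String boundary = "----boundary" + UUID.randomUUID();
        byte[] body = multipartBody(boundary, "image", image.getFileName().toString(), Files.readAllBytes(image));

        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8080/recognize"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        HttpResponse<String> response =
                HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```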
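6. Run the Spring Boot application
Make sure the application starts normally and that mnist-model.zip is reachable from the working directory. You can start the application from the project root with:
```shell
mvn spring-boot:run
```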
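7. Front-end integration (optional)
If you have a front-end application (for example, Vue.js), you can create a simple form that uploads an image and calls the back-end API. Here is a simple Vue.js component: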
```vue
<template>
  <div>
    <h1>Image Recognition</h1>
    <input type="file" @change="onFileChange" accept="image/*" />
    <button @click="uploadImage" :disabled="!file">Upload</button>
    <p v-if="result">{{ result }}</p>
  </div>
</template>

<script>
export default {
  data() {
    return {
      file: null,
      result: ''
    };
  },
  methods: {
    onFileChange(e) {
      this.file = e.target.files[0];
    },
    async uploadImage() {
      if (!this.file) return;
      const formData = new FormData();
      // The field name must match @RequestParam("image") on the back end
      formData.append('image', this.file);
      try {
        const response = await fetch('http://localhost:8080/recognize', {
          method: 'POST',
          body: formData
        });
        if (!response.ok) {
          throw new Error(`HTTP ${response.status}`);
        }
        this.result = await response.text();
      } catch (error) {
        console.error('Error uploading image:', error);
        this.result = 'Upload failed';
      }
    }
  }
};
</script>
```
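Add the Vue.js component to your Vue project, then run the front end to test the whole flow. Because the front end usually runs on a different port from the back end, the back end has to allow cross-origin requests (Spring's @CrossOrigin annotation on the controller is one way to do this).
With these steps, you should have a Spring Boot back-end service that serves a DL4J model, plus a front end for submitting images for recognition.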
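That covers integrating DL4J with Java to recognize express waybill information from scanned images. For more on image recognition with DL4J in Java, see the other related articles on 脚本之家.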