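Integrating DJL with Spring Boot for Image Classification
Author: HBLOG
DJL is an open-source deep learning toolkit, licensed under Apache-2.0, that uses a Java API to simplify training, testing and deploying deep learning models and running inference with them. This article shows how to integrate DJL into a Spring Boot application to implement image classification.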
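1. What is DJL?
DJL (Deep Java Library) is a fairly new project: it was officially announced at AWS re:Invent in early December 2019. In short, DJL is an open-source deep learning toolkit, licensed under Apache-2.0, that provides a Java API to simplify training, testing and deploying deep learning models and running inference with them. It lets Java developers build and use native machine learning and deep learning models directly in Java, which lowers the barrier to deep learning development. With DJL's intuitive, high-level API, Java developers can train their own models, or run inference with models that data scientists have pre-trained in Python. If you are a Java developer interested in learning deep learning, DJL is a good place to start building deep learning applications.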
2. Data Preparation
Download the training set (the UT Zappos50K shoe image dataset):
wget https://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images-square.zip
Unzip it for use in the training step later:
unzip ut-zap50k-images-square.zip
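DJL's ImageFolder dataset, used by the training code later, treats each first-level sub-folder of the dataset root as one class (here Boots, Sandals, Shoes and Slippers), and the optMaxDepth(10) setting lets it look into nested sub-folders as well. Before starting a long training run it can help to check how many images each class contains. The class below is a hypothetical plain-Java helper, not part of the project or of DJL; which file extensions count as images and how the depth is measured are assumptions of this sketch.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper: counts images per first-level folder (= class label) of a dataset root
public class DatasetLabelCounter {

    // Assumption: only .jpg/.jpeg/.png files are treated as images
    public static boolean isImage(Path p) {
        String name = p.getFileName().toString().toLowerCase(Locale.ROOT);
        return name.endsWith(".jpg") || name.endsWith(".jpeg") || name.endsWith(".png");
    }

    // maxDepth is measured from each label folder (an assumption, not ImageFolder's exact rule)
    public static Map<String, Long> countPerLabel(Path root, int maxDepth) {
        Map<String, Long> counts = new TreeMap<>();
        try (Stream<Path> children = Files.list(root)) {
            // plain files directly under the root carry no label and are skipped
            List<Path> labelDirs = children.filter(Files::isDirectory).collect(Collectors.toList());
            for (Path labelDir : labelDirs) {
                try (Stream<Path> files = Files.walk(labelDir, maxDepth)) {
                    long n = files.filter(Files::isRegularFile)
                            .filter(DatasetLabelCounter::isImage)
                            .count();
                    counts.put(labelDir.getFileName().toString(), n);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return counts;
    }

    public static void main(String[] args) {
        Path root = Paths.get(args.length > 0 ? args[0] : "ut-zap50k-images-square");
        countPerLabel(root, 10).forEach((label, n) -> System.out.println(label + ": " + n));
    }
}
```

Run it with the unzipped directory as the argument; a strongly unbalanced count is worth knowing about before interpreting the accuracy figures reported during training.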
3. Code Project
Experiment goal
Implement image classification with DJL.
pom.xml
```xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>
    <artifactId>djl</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <!-- DJL -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- pytorch-engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <scope>runtime</scope>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>windows</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <dependencies>
                <!-- Windows CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>win-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>centos7</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For Pre-CXX11 build (CentOS7) -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <version>2.0.1</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>linux</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- Linux CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>aarch64</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For aarch64 build -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-aarch64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.23.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
</project>
```
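The pom activates the windows profile by default; on another platform you select a profile explicitly, e.g. mvn -P linux package, and Maven then deactivates the activeByDefault profile. The class below is a hypothetical helper that maps the JVM's os.name and os.arch properties to one of the four profile ids. The mapping is an assumption of this sketch: centos7 (the pre-CXX11 build) cannot be distinguished from an ordinary Linux this way and has to be picked by hand, and on macOS none of the four profiles matches (the training log in the testing section shows DJL ignoring the Windows native jar on a Mac).

```java
import java.util.Locale;

// Hypothetical helper: suggests which profile of the pom to pass to "mvn -P <id>"
public class ProfileHint {

    // Returns "windows", "linux" or "aarch64", or null when no profile in the pom fits
    public static String suggest(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        if (os.startsWith("windows")) {
            return "windows";
        }
        if (os.startsWith("linux")) {
            // the "centos7" profile cannot be detected from os.name; choose it manually on old glibc systems
            return (arch.equals("aarch64") || arch.equals("arm64")) ? "aarch64" : "linux";
        }
        // e.g. macOS: none of the four profiles matches
        return null;
    }

    public static void main(String[] args) {
        String id = suggest(System.getProperty("os.name"), System.getProperty("os.arch"));
        System.out.println(id == null ? "no matching profile in this pom" : "mvn -P " + id + " package");
    }
}
```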
Controller
```java
package com.et.controller;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.et.service.ImageClassificationService;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;

@RestController
@RequiredArgsConstructor
public class ImageClassificationController {

    private final ImageClassificationService imageClassificationService;

    // Classify the uploaded image with the model stored in modePath
    @PostMapping(path = "/analyze")
    public String predict(@RequestPart("image") MultipartFile image,
                          @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, MalformedModelException, IOException {
        return imageClassificationService.predict(image, modePath);
    }

    // Train a model on the image folder datasetRoot and save it to modePath
    @PostMapping(path = "/training")
    public String training(@RequestParam(defaultValue = "/home/djl-test/images-test") String datasetRoot,
                           @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, IOException {
        return imageClassificationService.training(datasetRoot, modePath);
    }

    // Return a random image from directoryPath, handy as test input for /analyze
    @GetMapping("/download")
    public ResponseEntity<Resource> downloadFile(
            @RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
        List<String> imgPathList = new ArrayList<>();
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            // Filter only regular files (excluding directories)
            paths.filter(Files::isRegularFile)
                 .forEach(c -> imgPathList.add(c.toString()));
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
        // Random.nextInt(0) would throw, so answer 404 when the directory holds no files
        if (imgPathList.isEmpty()) {
            return ResponseEntity.notFound().build();
        }
        Random random = new Random();
        String filePath = imgPathList.get(random.nextInt(imgPathList.size()));
        Path file = Paths.get(filePath);
        Resource resource = new FileSystemResource(file.toFile());
        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }
        HttpHeaders headers = new HttpHeaders();
        headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=" + file.getFileName().toString());
        headers.add(HttpHeaders.CONTENT_TYPE, MediaType.IMAGE_JPEG_VALUE);
        try {
            return ResponseEntity.ok()
                    .headers(headers)
                    .contentLength(resource.contentLength())
                    .body(resource);
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
    }
}
```
Service
Interface
```java
package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

public interface ImageClassificationService {

    String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException;

    String training(String datasetRoot, String modePath)
            throws TranslateException, IOException;
}
```
Implementation class
```java
package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.et.Models;
import lombok.Cleanup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;

@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {

    // number of training samples processed before the model is updated
    private static final int BATCH_SIZE = 32;

    // number of passes over the complete dataset
    private static final int EPOCHS = 2;

    // number of classification labels: boots, sandals, shoes, slippers
    @Value("${djl.num-of-output:4}")
    private int numOfOutput;

    @Override
    public String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException {
        @Cleanup
        InputStream is = image.getInputStream();
        Path modelDir = Paths.get(modePath);
        BufferedImage bi = ImageIO.read(is);
        // ImageIO.read returns null when no reader supports the uploaded format
        if (bi == null) {
            throw new IOException("Unsupported or unreadable image: " + image.getOriginalFilename());
        }
        Image img = ImageFactory.getInstance().fromImage(bi);

        // empty model instance
        try (Model model = Models.getModel(numOfOutput)) {
            // load the trained parameters
            model.load(modelDir, Models.MODEL_NAME);

            // translator for pre- and post-processing: resize, convert to a CHW float tensor,
            // and apply softmax so the output scores are probabilities
            Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                    .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                    .addTransform(new ToTensor())
                    .optApplySoftmax(true)
                    .build();

            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                log.info("result={}", predictResult.toJson());
                return predictResult.toJson();
            }
        }
    }

    @Override
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException {
        log.info("Image dataset training started...Image dataset address path:{}", datasetRoot);
        // the location to save the model
        Path modelDir = Paths.get(modePath);

        // create ImageFolder dataset from directory
        ImageFolder dataset = initDataset(datasetRoot);
        // split the dataset 8:2 into a training set and a validation set
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // the loss function compares the model's predictions with the correct answers during training;
        // higher values mean more errors, and training tries to minimize it
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // training parameters (i.e. hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(numOfOutput); // empty model instance to hold patterns
             Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            // NCHW input shape: batch, channels, height, width
            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);

            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // set model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty("Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the model for inference later;
            // the parameter file is named <MODEL_NAME>-NNNN.params, NNNN being the epoch number
            model.save(modelDir, Models.MODEL_NAME);

            // save labels into model directory
            Models.saveSynset(modelDir, dataset.getSynset());
            log.info("Image dataset training completed......");
            return String.join("\n", dataset.getSynset());
        }
    }

    private ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
        ImageFolder dataset = ImageFolder.builder()
                // retrieve the data
                .setRepositoryPath(Paths.get(datasetRoot))
                .optMaxDepth(10)
                .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                // random sampling; don't process the data in order
                .setSampling(BATCH_SIZE, true)
                .build();
        dataset.prepare();
        return dataset;
    }

    private TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }
}
```
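Both the dataset and the predictor run every image through Resize followed by ToTensor, and the trainer is initialized with a 4-D NCHW input shape: batch, channels (3 for RGB), height, width. The plain-Java sketch below only illustrates what that preprocessing does to a single image; it is not DJL code. Assumptions: bilinear interpolation for the resize (DJL's own default may differ), and ToTensor behaving as DJL documents it, turning an HWC image with values 0 to 255 into a CHW float array with values 0 to 1.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

// Illustration only: what Resize + ToTensor conceptually do to one RGB image
public class PreprocessSketch {

    // Resize to width x height; bilinear interpolation is an assumption of this sketch
    public static BufferedImage resize(BufferedImage src, int width, int height) {
        BufferedImage dst = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = dst.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, width, height, null);
        g.dispose();
        return dst;
    }

    // Flatten to CHW order (all R values, then all G, then all B), each scaled from 0..255 to 0..1
    public static float[] toChwTensor(BufferedImage img) {
        int h = img.getHeight();
        int w = img.getWidth();
        float[] data = new float[3 * h * w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = img.getRGB(x, y);
                int[] channels = {(rgb >> 16) & 0xFF, (rgb >> 8) & 0xFF, rgb & 0xFF};
                for (int c = 0; c < 3; c++) {
                    data[c * h * w + y * w + x] = channels[c] / 255f;
                }
            }
        }
        return data;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(4, 2, BufferedImage.TYPE_INT_RGB);
        img.setRGB(0, 0, 0xFF0000); // one red pixel at (0, 0)
        float[] t = toChwTensor(img);
        // shape (3, 2, 4); adding a batch dimension gives (1, 3, 2, 4), i.e. NCHW
        System.out.println(t.length + " values, R(0,0) = " + t[0]);
    }
}
```

Because training and prediction share the same two transforms, an image is presented to the network in exactly the same format in both paths; changing one side without the other would silently degrade predictions.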
application.yaml
```yaml
server:
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher
```
Main class
```java
package com.et;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class DemoApplication {

    public static void main(String[] args) {
        SpringApplication.run(DemoApplication.class, args);
    }
}
```
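The above is only the key code; see the code repository below for the complete project. In particular, the helper class com.et.Models used by the service (MODEL_NAME, IMAGE_WIDTH, IMAGE_HEIGHT, getModel(...) and saveSynset(...)) is not listed here.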
Code repository
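4. Testing
Start the Spring Boot application.
Train the model
Call POST /training with datasetRoot set to the dataset directory unzipped earlier; modePath sets where the trained model is saved, and both parameters fall back to the defaults in the controller when omitted.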
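The training endpoint splits the dataset with randomSplit(8, 2): about 80% of the samples are used for training and 20% for validation, which is where the "Validate" figures in the log come from. The sketch below is a plain-Java illustration of a ratio-based random split, not DJL's implementation; how DJL rounds the boundary between the parts is an assumption here (the sketch floors it and gives the remainder to the last part).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Illustration of a ratio-based random split such as randomSplit(8, 2)
public class RandomSplitSketch {

    public static List<List<Integer>> split(int size, long seed, int... ratios) {
        // shuffle sample indices so each part is a random subset
        List<Integer> indices = IntStream.range(0, size).boxed()
                .collect(Collectors.toCollection(ArrayList::new));
        Collections.shuffle(indices, new Random(seed));
        int total = IntStream.of(ratios).sum();
        List<List<Integer>> parts = new ArrayList<>();
        int start = 0;
        int cumulative = 0;
        for (int i = 0; i < ratios.length; i++) {
            cumulative += ratios[i];
            // the last part takes everything left, so no sample is lost to rounding
            int end = (i == ratios.length - 1) ? size : (int) ((long) size * cumulative / total);
            parts.add(new ArrayList<>(indices.subList(start, end)));
            start = end;
        }
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(1000, 42L, 8, 2);
        System.out.println("train=" + parts.get(0).size() + ", validate=" + parts.get(1).size());
    }
}
```

The two parts are disjoint, so the validation accuracy is measured on images the model never trained on in that run.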
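Console log output. Without a GPU, training is slow, so expect to wait: in the run below, the first epoch alone ran on the CPU from 21:00:09 to 22:42:48, roughly 1 hour 43 minutes.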
```text
2024-10-11T21:00:05.407+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] c.e.s.ImageClassificationServiceImpl : Image dataset training started...Image dataset address path:/Users/liuhaihua/ai/ut-zap50k-images-square
2024-10-11T21:00:08.455+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.util.Platform : Ignore mismatching platform from: jar:file:/Users/liuhaihua/.m2/repository/ai/djl/pytorch/pytorch-native-cpu/2.0.1/pytorch-native-cpu-2.0.1-win-x86_64.jar!/native/lib/pytorch.properties
2024-10-11T21:00:09.240+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of inter-op threads is 4
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of intra-op threads is 4
2024-10-11T21:00:09.287+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Training on: cpu().
2024-10-11T21:00:09.290+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Load PyTorch Engine Version 1.13.1 in 0.044 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
Validating: 100% |████████████████████████████████████████|
2024-10-11T22:42:48.142+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Epoch 1 finished.
2024-10-11T22:42:48.187+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
2024-10-11T22:42:48.189+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Validate: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.24
Training: 5% |███ | Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
```
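The SoftmaxCrossEntropyLoss values in this log are the cross-entropy between the softmax of the network's raw outputs and the true label, approximately averaged over samples; for a single sample the loss is -log p(true class). As a rough reading, the validation loss of 0.24 corresponds to a geometric-mean probability of e^(-0.24) ≈ 0.79 assigned to the correct class. The same softmax is what optApplySoftmax(true) applies at prediction time. The plain-Java sketch below illustrates these two formulas; it is not DJL's implementation, and the logits in main are made-up values.

```java
// Illustration of softmax and per-sample softmax cross-entropy
public class SoftmaxSketch {

    // softmax with max-subtraction so large logits do not overflow Math.exp
    public static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // cross-entropy of one sample: -log of the probability given to the true label
    public static double crossEntropy(double[] logits, int label) {
        return -Math.log(softmax(logits)[label]);
    }

    public static void main(String[] args) {
        String[] labels = {"Boots", "Sandals", "Shoes", "Slippers"};
        double[] logits = {0.5, -1.0, 2.0, 0.0}; // hypothetical raw network output
        double[] p = softmax(logits);
        for (int i = 0; i < labels.length; i++) {
            System.out.printf("%s: %.4f%n", labels[i], p[i]);
        }
        System.out.printf("loss if the true label is Shoes: %.4f%n", crossEntropy(logits, 2));
    }
}
```

A model that guesses uniformly among the four classes would score log 4 ≈ 1.39, so losses of 0.38 (train) and 0.24 (validate) after one epoch indicate the model has already learned a good deal.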
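Predict an image's category
Use the model trained in the previous step to make a prediction: upload an image to POST /analyze as the image part. GET /download returns a random image from the dataset directory, which can serve as test input.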
The returned result shows that Shoes has the highest probability, so the image is classified as Shoes.
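To call the prediction endpoint from code rather than from an HTTP tool, the following JDK 11+ sketch builds the multipart/form-data request by hand. Assumptions: the application listens on localhost:8888 as configured in application.yaml, the file part must be named image to match @RequestPart("image"), and the upload is a JPEG; the class and method names are hypothetical. modePath is not sent, so the controller's default /home/djl-test/models applies; append ?modePath=... to the URL to use another model directory.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.UUID;

// Hypothetical client for POST /analyze using only the JDK HTTP client
public class AnalyzeClient {

    // Build a multipart/form-data body with a single file part named "image"
    public static byte[] multipartBody(String boundary, String fileName, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String head = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"image\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: image/jpeg\r\n\r\n";
        out.writeBytes(head.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static HttpRequest analyzeRequest(String baseUrl, Path image) throws IOException {
        String boundary = "----djl-" + UUID.randomUUID();
        byte[] body = multipartBody(boundary, image.getFileName().toString(), Files.readAllBytes(image));
        return HttpRequest.newBuilder(URI.create(baseUrl + "/analyze"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
    }

    public static void main(String[] args) throws Exception {
        Path image = Paths.get(args[0]); // e.g. an image saved from GET /download
        HttpResponse<String> resp = HttpClient.newHttpClient()
                .send(analyzeRequest("http://localhost:8888", image), HttpResponse.BodyHandlers.ofString());
        System.out.println(resp.statusCode() + " " + resp.body());
    }
}
```

Run it with the path of an image file; on success the response body is the JSON string produced by Classifications.toJson() in the service.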