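Calling PyTorch from Java to Implement Search by Image
Author: 老李笔记
This article walks through how to call a PyTorch model from Java to build a search-by-image feature, with sample code for each step: exporting the model, indexing image vectors in Elasticsearch, the Spring Boot backend, and a simple web frontend.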
Tech stack
1. An Elasticsearch environment;
2. A Python environment (only needed to create the PyTorch model with a Python script if you don't already have one);
1. Running result
2. Create the model (skip this step if you already have one)
1. Create the script: vi script.py
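import torch
import torch.nn as nn
import torchvision.models as models


class ImageFeatureExtractor(nn.Module):
    def __init__(self):
        super(ImageFeatureExtractor, self).__init__()
        # On torchvision >= 0.13, prefer weights=models.ResNet50_Weights.DEFAULT over pretrained=True
        self.resnet = models.resnet50(pretrained=True)
        # The final output is a 1024-dim vector, so the Elasticsearch mapping below must set dims to 1024.
        # This new fc layer is randomly initialized and never trained, so it acts as a random
        # projection of the 2048-dim pretrained ResNet-50 features.
        self.resnet.fc = nn.Linear(2048, 1024)

    def forward(self, x):
        x = self.resnet(x)
        return x


if __name__ == '__main__':
    model = ImageFeatureExtractor()
    model.eval()
    # Create a random dummy input with the shape the model expects: (batch, channels, height, width)
    example_input = torch.rand([1, 3, 224, 224])
    # Save it this way (TorchScript tracing) so DJL can load model.pt from Java
    with torch.no_grad():
        script = torch.jit.trace(model, example_input)
    script.save("model.pt")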
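The replacement fc layer is a plain affine map that turns the 2048-dim ResNet-50 features into the 1024-dim vector stored in Elasticsearch. The sketch below shows in plain Java what nn.Linear(in, out) computes, y = W x + b. It uses toy sizes and made-up weights. The class LinearLayerDemo and its method are hypothetical illustrations, not part of the project. As in PyTorch, the weight matrix is laid out as [out_features][in_features].

```java
import java.util.Arrays;

// Hypothetical illustration of nn.Linear(in, out): y[j] = sum_i W[j][i] * x[i] + b[j]
public class LinearLayerDemo {

    /** weight has shape [out][in] (PyTorch layout), bias has length out. */
    static float[] linear(float[] x, float[][] weight, float[] bias) {
        int out = weight.length;
        if (bias.length != out) {
            throw new IllegalArgumentException("bias length must equal number of weight rows");
        }
        float[] y = new float[out];
        for (int j = 0; j < out; j++) {
            if (weight[j].length != x.length) {
                throw new IllegalArgumentException("weight row " + j + " does not match input size");
            }
            float sum = bias[j];
            for (int i = 0; i < x.length; i++) {
                sum += weight[j][i] * x[i];
            }
            y[j] = sum;
        }
        return y;
    }

    public static void main(String[] args) {
        // Toy sizes (in = 3, out = 2) instead of 2048 -> 1024
        float[] x = {1f, 2f, 3f};
        float[][] w = {{1f, 0f, 0f}, {0f, 1f, 1f}};
        float[] b = {0.5f, -1f};
        System.out.println(Arrays.toString(linear(x, w, b))); // [1.5, 4.0]
    }
}
```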
3. pom.xml of the Java project
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <scope>provided</scope>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.19.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-cpu</artifactId>
        <version>1.10.0</version>
        <scope>runtime</scope>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-jni</artifactId>
        <version>1.10.0-0.19.0</version>
    </dependency>
    <dependency>
        <groupId>org.elasticsearch.client</groupId>
        <artifactId>elasticsearch-rest-high-level-client</artifactId>
    </dependency>
</dependencies>
4. Create the index in Elasticsearch
PUT /isi
{
  "mappings": {
    "properties": {
      "vector": {
        "type": "dense_vector",
        "dims": 1024
      },
      "url": {
        "type": "keyword"
      },
      "user_id": {
        "type": "keyword"
      }
    }
  }
}
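At search time, Elasticsearch's Painless function cosineSimilarity compares the query vector with each stored dense_vector. Cosine similarity lies in [-1, 1]. Elasticsearch does not accept negative script scores, so its documentation adds 1.0 to the result. The sketch below computes the same quantity in plain Java, for example to sanity-check scores offline. CosineSimilarityDemo is a hypothetical helper, and rejecting zero vectors is this sketch's own choice.

```java
// Hypothetical helper: the cosine similarity that cosineSimilarity(...) computes in Painless
public class CosineSimilarityDemo {

    static double cosine(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch: " + a.length + " vs " + b.length);
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            throw new IllegalArgumentException("cosine similarity is undefined for a zero vector");
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Shifted score in [0, 2], the form used with "cosineSimilarity(...) + 1.0". */
    static double shiftedScore(float[] query, float[] doc) {
        return cosine(query, doc) + 1.0;
    }

    public static void main(String[] args) {
        float[] q = {1f, 0f};
        System.out.println(shiftedScore(q, new float[]{2f, 0f}));  // same direction
        System.out.println(shiftedScore(q, new float[]{0f, 3f}));  // orthogonal
        System.out.println(shiftedScore(q, new float[]{-1f, 0f})); // opposite
    }
}
```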
5. Write the Java code that calls the model
ORCUtil.java
package com.topprismcloud.rtm;

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;

public class ORCUtil {
    private static final String INDEX = "isi";
    private static final int IMAGE_SIZE = 224;
    // Elasticsearch connection settings: adjust to your environment
    private static final String ES_HOST = "14.20.30.16";
    private static final int ES_PORT = 9200;
    private static final String ES_USER = "elastic";
    private static final String ES_PASSWORD = "123456";

    private static Model model;
    // predictor.predict(input) is the Java counterpart of model(input) in Python.
    // A DJL Predictor is not thread-safe, so every call goes through the synchronized embed().
    private static Predictor<Image, float[]> predictor;

    static {
        try {
            model = Model.newInstance("model");
            // model.pt is the TorchScript file saved by the Python script above (placed on the classpath)
            try (InputStream modelStream = ORCUtil.class.getClassLoader().getResourceAsStream("model.pt")) {
                model.load(Objects.requireNonNull(modelStream, "model.pt not found on the classpath"));
            }
            Transform resize = new Resize(IMAGE_SIZE);
            Transform toTensor = new ToTensor();
            Transform normalize = new Normalize(new float[] {0.485f, 0.456f, 0.406f},
                    new float[] {0.229f, 0.224f, 0.225f});
            // The Translator converts the input Image into a tensor and the output tensor into float[]
            Translator<Image, float[]> translator = new Translator<Image, float[]>() {
                @Override
                public NDList processInput(TranslatorContext ctx, Image input) {
                    NDManager manager = ctx.getNDManager();
                    // HWC uint8 -> 224x224 -> CHW float in [0, 1] -> ImageNet mean/std normalization
                    NDArray array = normalize.transform(
                            toTensor.transform(resize.transform(input.toNDArray(manager))));
                    return new NDList(array);
                }

                @Override
                public float[] processOutput(TranslatorContext ctx, NDList list) {
                    return list.get(0).toFloatArray();
                }
            };
            predictor = model.newPredictor(translator);
        } catch (Exception e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    // One client factory so that upload() and search() both send the credentials
    private static RestHighLevelClient createClient() {
        CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(ES_USER, ES_PASSWORD));
        RestClientBuilder builder = RestClient.builder(new HttpHost(ES_HOST, ES_PORT, HttpHost.DEFAULT_SCHEME_NAME))
                .setHttpClientConfigCallback(b -> b.setDefaultCredentialsProvider(credentialsProvider));
        return new RestHighLevelClient(builder);
    }

    // Turns an image stream into its 1024-dim feature vector
    private static synchronized float[] embed(InputStream in) throws IOException, TranslateException {
        return predictor.predict(ImageFactory.getInstance().fromInputStream(in));
    }

    // Embeds every image in the folder and bulk-indexes the vectors
    public static void upload() throws IOException, TranslateException {
        File dir = new File("D:\\001ENV\\nginx-1.24.0\\html\\resource\\new");
        File[] files = dir.listFiles();
        if (files == null) {
            throw new IOException("Not a readable directory: " + dir);
        }
        BulkRequest bulkRequest = new BulkRequest(INDEX);
        for (File imageFile : files) {
            if (!imageFile.isFile()) {
                continue;
            }
            float[] vector;
            try (InputStream in = new FileInputStream(imageFile)) {
                vector = embed(in);
            }
            // Build the document; url is the path under which nginx serves the image
            Map<String, Object> doc = new HashMap<>();
            doc.put("url", "/resource/" + imageFile.getName());
            doc.put("vector", vector);
            doc.put("user_id", "user123");
            bulkRequest.add(new IndexRequest(INDEX).source(doc, XContentType.JSON));
        }
        if (bulkRequest.numberOfActions() == 0) {
            return;
        }
        try (RestHighLevelClient client = createClient()) {
            BulkResponse response = client.bulk(bulkRequest, RequestOptions.DEFAULT);
            if (response.hasFailures()) {
                throw new IOException(response.buildFailureMessage());
            }
        }
    }

    // Takes the InputStream of the query image and returns the most similar indexed images
    public static List<SearchResult> search(InputStream input) throws IOException, TranslateException {
        float[] vector = embed(input);
        int k = 100; // number of results to return
        // Cosine similarity lies in [-1, 1]; Elasticsearch rejects negative scores, so shift by +1.0
        Script script = new Script(ScriptType.INLINE, "painless",
                "cosineSimilarity(params.queryVector, doc['vector']) + 1.0",
                Collections.singletonMap("queryVector", vector));
        FunctionScoreQueryBuilder query = QueryBuilders.functionScoreQuery(
                QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));
        SearchSourceBuilder source = new SearchSourceBuilder()
                .query(query)
                .fetchSource(null, "vector") // skip the vector field: large, unused and slow to transfer
                .size(k);
        SearchRequest request = new SearchRequest(INDEX).source(source);
        List<SearchResult> results = new ArrayList<>();
        try (RestHighLevelClient client = createClient()) {
            SearchResponse response = client.search(request, RequestOptions.DEFAULT);
            for (SearchHit hit : response.getHits()) {
                results.add(new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore()));
            }
        }
        return results;
    }

    public static void main(String[] args) throws Exception {
        upload();
        System.out.println("Upload finished");
    }
}
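The Translator above chains Resize, ToTensor and Normalize before the image reaches the model. ToTensor scales 0-255 pixel values to [0, 1] and reorders the layout from HWC to CHW. Normalize then subtracts the per-channel ImageNet mean and divides by the standard deviation. The following plain-Java sketch performs those last two steps on a raw RGB pixel array, assuming the image has already been resized. PreprocessDemo and its method are hypothetical and do not reproduce DJL's internals.

```java
// Hypothetical sketch of ToTensor + Normalize: HWC bytes (0-255) -> normalized CHW floats
public class PreprocessDemo {
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    /** pixels: RGB values in HWC order, length height * width * 3. */
    static float[] toNormalizedChw(int[] pixels, int height, int width) {
        if (pixels.length != height * width * 3) {
            throw new IllegalArgumentException("expected height * width * 3 values");
        }
        float[] out = new float[pixels.length];
        int plane = height * width; // size of one channel plane in CHW layout
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < 3; c++) {
                    float v = pixels[(y * width + x) * 3 + c] / 255f;         // ToTensor: scale to [0, 1]
                    out[c * plane + y * width + x] = (v - MEAN[c]) / STD[c]; // Normalize per channel
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // A 1x2 image: two RGB pixels
        float[] chw = toNormalizedChw(new int[]{10, 20, 30, 40, 50, 60}, 1, 2);
        System.out.println(java.util.Arrays.toString(chw));
    }
}
```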
SearchController.java
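package com.topprismcloud.rtm;

import java.io.InputStream;
import java.util.List;

import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@CrossOrigin
public class SearchController {

    // Receives the uploaded image (form field "file") and returns similar images with their scores
    @PostMapping("/search")
    public ResponseEntity<List<SearchResult>> search(@RequestParam("file") MultipartFile file) {
        if (file.isEmpty()) {
            return ResponseEntity.badRequest().build();
        }
        try (InputStream in = file.getInputStream()) {
            return ResponseEntity.ok(ORCUtil.search(in));
        } catch (Exception e) {
            e.printStackTrace(); // log instead of silently swallowing the error
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build();
        }
    }
}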
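The frontend only accepts .png, .jpg and .jpeg files, but that check runs in the browser and is easy to bypass. The controller could repeat it on MultipartFile#getOriginalFilename(). A minimal plain-Java sketch of such a check follows. ImageFileCheck is a hypothetical helper, and it inspects only the extension, not the actual file contents.

```java
import java.util.Locale;
import java.util.Set;

// Hypothetical server-side counterpart of the frontend's extension check
public class ImageFileCheck {
    private static final Set<String> ALLOWED = Set.of("png", "jpg", "jpeg");

    /** True if the file name ends in an allowed image extension (case-insensitive). */
    static boolean isAllowedImage(String filename) {
        if (filename == null) {
            return false;
        }
        int dot = filename.lastIndexOf('.');
        if (dot < 0 || dot == filename.length() - 1) {
            return false; // no extension, or a trailing dot
        }
        String ext = filename.substring(dot + 1).toLowerCase(Locale.ROOT);
        return ALLOWED.contains(ext);
    }

    public static void main(String[] args) {
        for (String name : new String[]{"cat.JPG", "photo.png", "anim.gif", "README"}) {
            System.out.println(name + " -> " + isAllowedImage(name));
        }
    }
}
```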
SearchResult.java
package com.topprismcloud.rtm;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class SearchResult {
    private String url;
    private Float score;
}
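The search query wraps match_all in a function_score with a script, so Elasticsearch scores every document in the index. It is an exact brute-force scan, not an approximate nearest-neighbour lookup. For a small image set, the same ranking can be reproduced in memory to sanity-check what ES returns. The sketch below does this with a size-k min-heap. BruteForceTopK and its Hit class (a Lombok-free stand-in for SearchResult) are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical in-memory equivalent of "score every document, keep the top k"
public class BruteForceTopK {

    static final class Hit {
        final String url;
        final double score;
        Hit(String url, double score) { this.url = url; this.score = score; }
        @Override public String toString() { return url + "=" + score; }
    }

    /** Returns the k entries most similar to the query, highest score first. */
    static List<Hit> topK(float[] query, Map<String, float[]> index, int k) {
        List<Hit> result = new ArrayList<>();
        if (k <= 0) {
            return result;
        }
        Comparator<Hit> byScore = Comparator.comparingDouble((Hit h) -> h.score);
        PriorityQueue<Hit> heap = new PriorityQueue<>(byScore); // min-heap: weakest hit on top
        for (Map.Entry<String, float[]> entry : index.entrySet()) {
            heap.offer(new Hit(entry.getKey(), cosine(query, entry.getValue())));
            if (heap.size() > k) {
                heap.poll(); // drop the current weakest hit
            }
        }
        result.addAll(heap);
        result.sort(byScore.reversed());
        return result;
    }

    static double cosine(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            na += (double) a[i] * a[i];
            nb += (double) b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    public static void main(String[] args) {
        Map<String, float[]> index = new LinkedHashMap<>();
        index.put("/resource/a.jpg", new float[]{1f, 0f});
        index.put("/resource/b.jpg", new float[]{0f, 1f});
        index.put("/resource/c.jpg", new float[]{1f, 1f});
        System.out.println(topK(new float[]{1f, 0f}, index, 2));
    }
}
```

A heap keeps memory at O(k) while scanning n vectors in O(n log k), which matters little at k = 100 but avoids sorting the whole collection.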
6. Frontend
index.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Search by Image</title>
    <style>
        body { background: url("/img/bg.jpg"); background-attachment: fixed; background-size: 100% 100%; }
        body>div { width: 1000px; margin: 50px auto; padding: 10px 20px; border: 1px solid lightgray; border-radius: 20px; box-sizing: border-box; background: rgba(255, 255, 255, 0.7); }
        .upload { display: inline-block; width: 300px; height: 280px; border: 1px dashed lightcoral; vertical-align: top; }
        .upload .cover { width: 200px; height: 200px; margin: 10px 50px; border: 1px solid black; box-sizing: border-box; text-align: center; line-height: 200px; position: relative; }
        .upload img { width: 198px; height: 198px; position: absolute; left: 0; top: 0; }
        .upload input { margin-left: 50px; }
        .upload button { width: 80px; height: 30px; margin-left: 110px; }
        .result-block { display: inline-block; margin-left: 40px; border: 1px solid lightgray; border-radius: 10px; min-height: 500px; width: 600px; }
        .result-block h1 { text-align: center; margin-top: 100px; }
        .result { padding: 10px; cursor: pointer; display: inline-block; }
        .result:hover { background: rgb(240, 240, 240); }
        .result p { width: 110px; overflow: hidden; white-space: nowrap; text-overflow: ellipsis; }
        .result img { width: 160px; height: 160px; }
        .result .prob { color: rgb(37, 147, 60) }
    </style>
    <script src="js/jquery-3.6.0.js"></script>
</head>
<body>
<div>
    <div class="upload">
        <div class="cover">
            Select an image
            <img id="image" src="" />
        </div>
        <input id="file" type="file">
    </div>
    <div class="result-block">
        <h1>Select an image</h1>
    </div>
</div>
<ul id="box">
</ul>
<script>
    // Backend search endpoint (the Spring Boot service)
    const SEARCH_URL = 'http://10.1.2.240:8081/search'
    const ALLOWED_EXTS = ['.png', '.jpg', '.jpeg']

    $('#file').change(function () {
        const f = this.files[0]
        if (!f) return // the file dialog was cancelled
        const index = f.name.lastIndexOf('.')
        const ext = index >= 0 ? f.name.substring(index).toLowerCase() : '' // file type
        if (!ALLOWED_EXTS.includes(ext)) {
            alert('Only JPG, PNG and JPEG images are supported. Please convert the image and upload it again.')
            return
        }
        $('.result-block').empty().append($('<h1>Searching...</h1>'))
        $('#image').attr('src', URL.createObjectURL(f))
        const formData = new FormData()
        formData.append('file', f)
        $.ajax({
            url: SEARCH_URL,
            method: 'post',
            data: formData,
            processData: false,
            contentType: false,
            success: res => {
                const block = $('.result-block').empty()
                if (!res || res.length === 0) {
                    block.append($('<h1>No results</h1>'))
                    return
                }
                for (const item of res) {
                    const html = `<div class="result">
                        <img src="${item.url}"/>
                        <div style="display: inline-block;vertical-align: top">
                            <p class="prob">Score: ${item.score.toFixed(4)}</p>
                        </div>
                    </div>`
                    block.append($(html))
                }
            },
            error: () => {
                $('.result-block').empty().append($('<h1>Search failed</h1>'))
            }
        })
    })
</script>
</body>
</html>
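To exercise the /search endpoint without the browser page, a client can send the same multipart/form-data request that the frontend's FormData builds, with the image in a field named "file". The following is a minimal sketch using only the JDK's java.net.http client (Java 11+). SearchClient is a hypothetical class, the boundary format is arbitrary, and the part's content type is assumed to be application/octet-stream.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

// Hypothetical command-line client for the /search endpoint
public class SearchClient {

    /** Builds a single-part multipart/form-data body carrying one file. */
    static byte[] multipartBody(String boundary, String field, String filename,
                                String contentType, byte[] data) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String head = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"" + field + "\"; filename=\"" + filename + "\"\r\n"
                + "Content-Type: " + contentType + "\r\n\r\n";
        out.writeBytes(head.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(data);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    /** POSTs the image and returns "status body" (the body is the JSON list of url/score). */
    static String search(String endpoint, Path image) throws IOException, InterruptedException {
        String boundary = "----search" + UUID.randomUUID().toString().replace("-", "");
        byte[] body = multipartBody(boundary, "file", image.getFileName().toString(),
                "application/octet-stream", Files.readAllBytes(image));
        HttpRequest request = HttpRequest.newBuilder(URI.create(endpoint))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        return response.statusCode() + " " + response.body();
    }

    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.out.println("usage: java SearchClient http://10.1.2.240:8081/search query.jpg");
            return;
        }
        System.out.println(search(args[0], Path.of(args[1])));
    }
}
```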
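Related article: Calling a PyTorch Model from Java for Image Recognition
That concludes this walkthrough of calling PyTorch from Java to implement search by image. For more on search by image in Java, see the other related articles on 脚本之家.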