java

关注公众号 jb51net

关闭
首页 > 软件编程 > java > SpringBoot Milvus和deeplearning4j图搜图

SpringBoot集成Milvus和deeplearning4j实现图搜图功能

作者:HBLOG

Milvus 是一种高性能、高扩展性的向量数据库,可在从笔记本电脑到大型分布式系统等各种环境中高效运行;Deeplearning4j(DL4J)是一个专门为Java和Scala开发的开源深度学习框架。本文给大家介绍如何在SpringBoot中集成Milvus和deeplearning4j实现图搜图功能。

1.什么是Milvus?

Milvus 是一种高性能、高扩展性的向量数据库,可在从笔记本电脑到大型分布式系统等各种环境中高效运行。它既可以以开源软件的形式提供,也可以以云服务的形式提供。 Milvus 是 LF AI & Data Foundation 下的一个开源项目,以 Apache 2.0 许可发布。大多数贡献者都是高性能计算(HPC)领域的专家,擅长构建大型系统和优化硬件感知代码。核心贡献者包括来自 Zilliz、ARM、NVIDIA、AMD、英特尔、Meta、IBM、Salesforce、阿里巴巴和微软的专业人士。

2.什么是deeplearning4j?

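Deeplearning4j(DL4J)是一个开源的深度学习框架,专门为Java和Scala开发。它支持分布式计算,适合在大数据环境中运行,比如与Hadoop或Spark集成。DL4J的特点包括:提供原生的Java/Scala API;基于ND4J进行多维数组(张量)运算;支持在Spark上进行分布式训练;可以导入Keras模型(deeplearning4j-modelimport);通过模型库(deeplearning4j-zoo)直接使用ResNet50等预训练模型。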

Deeplearning4j是企业和开发者进行深度学习开发和研究的强大工具,特别适合于需要与Java生态系统兼容的场景。

3.环境搭建

4.代码工程

实验目标

利用Milvus和deeplearning4j实现图搜图功能:先用deeplearning4j的预训练ResNet50模型提取每张图片的特征向量并存入Milvus,再提取查询图片的特征向量,在Milvus中按L2距离检索最相似的图片。

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<!--
  Build for the image-search demo.
  - Spring Boot parent: dependency and plugin management.
  - Milvus Java SDK: vector storage and similarity search.
  - DeepLearning4j / ND4J / DataVec: pretrained-model loading, tensor maths and image decoding.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>Milvus</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
        <!-- DL4J, ND4J and DataVec are released together; keep these versions in lockstep. -->
        <deeplearning4j.version>1.0.0-M2.1</deeplearning4j.version>
        <nd4j.version>1.0.0-M2.1</nd4j.version>
    </properties>
    <dependencies>


        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <!-- Client for Milvus / Zilliz Cloud. -->
        <dependency>
            <groupId>io.milvus</groupId>
            <artifactId>milvus-sdk-java</artifactId>
            <version>2.4.0</version>
        </dependency>
        <!-- Model zoo: provides the pretrained ResNet50 used for feature extraction. -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <!-- CPU backend for ND4J; the -platform artifact bundles native binaries for all platforms. -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
        <!-- NativeImageLoader: decodes and resizes images into INDArrays. -->
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-modelimport</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>

    </dependencies>

    <build>
        <pluginManagement>
            <plugins>
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-compiler-plugin</artifactId>
                    <version>3.8.1</version>
                    <configuration>
                        <fork>true</fork>
                        <!--
                          failOnError is deliberately left at its default (true). Setting it to
                          false lets the build "succeed" with compilation errors, which only
                          surface later as missing classes at runtime.
                        -->
                    </configuration>
                </plugin>

                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-surefire-plugin</artifactId>
                    <version>2.22.2</version>
                    <configuration>
                        <!-- Run tests in the Maven JVM instead of forking. -->
                        <forkCount>0</forkCount>
                        <failIfNoTests>false</failIfNoTests>
                    </configuration>
                </plugin>
            </plugins>
        </pluginManagement>
    </build>
</project>

特征抽取(使用DL4J模型库中预训练的ResNet50,把网络的输出作为图片的特征向量;该向量的长度必须与Milvus集合中embedding字段声明的维度1000一致)

package com.et.imagesearch;

import org.deeplearning4j.zoo.model.ResNet50;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.datavec.image.loader.NativeImageLoader;

import java.io.File;
import java.io.IOException;

/**
 * Converts image files into feature vectors using a ResNet50 network with pretrained weights
 * obtained from the DL4J model zoo.
 *
 * <p>The feature vector is the network's single output layer. The Milvus collection stores it
 * in an "embedding" field declared with dimension 1000. NOTE(review): this presumably is the
 * 1000-class ImageNet prediction. The penultimate pooling layer usually gives better
 * similarity embeddings; switching to it would also require changing that dimension.
 */
public class FeatureExtractor {
    /** Height and width, in pixels, that every image is resized to before inference. */
    private static final int IMAGE_SIZE = 224;
    /** Number of colour channels loaded per image. */
    private static final int CHANNELS = 3;

    /** Pretrained network; assigned exactly once in the constructor. */
    private final ComputationGraph model;

    /**
     * Builds the extractor and initialises ResNet50 with its pretrained weights.
     *
     * @throws IOException if the pretrained model cannot be initialised; the original
     *     exception is preserved as the cause
     */
    public FeatureExtractor() throws IOException {
        try {
            ZooModel<ComputationGraph> zooModel = ResNet50.builder().build();
            model = (ComputationGraph) zooModel.initPretrained();
        } catch (Exception e) {
            throw new IOException("Failed to initialize the pre-trained model: " + e.getMessage(), e);
        }
    }

    /**
     * Computes the feature vector for a single image.
     *
     * <p>The image is loaded and resized to 224x224 with 3 channels. Its pixel values are scaled
     * into [0, 1], and the result is passed through the network. NOTE(review): Keras-derived
     * ResNet50 weights are commonly trained with mean-subtraction preprocessing rather than
     * [0, 1] scaling — confirm which preprocessing matches these weights.
     *
     * @param imageFile image to encode; must be an existing regular file
     * @return the network's output for the image
     * @throws NullPointerException if {@code imageFile} is null
     * @throws java.io.FileNotFoundException if {@code imageFile} is not an existing regular file
     *     (for example a directory)
     * @throws IOException if the image cannot be read or decoded
     */
    public INDArray extractFeatures(File imageFile) throws IOException {
        java.util.Objects.requireNonNull(imageFile, "imageFile");
        // Fail early with a clear message instead of an opaque decoder error.
        if (!imageFile.isFile()) {
            throw new java.io.FileNotFoundException("Not a readable image file: " + imageFile);
        }

        NativeImageLoader loader = new NativeImageLoader(IMAGE_SIZE, IMAGE_SIZE, CHANNELS);
        INDArray image = loader.asMatrix(imageFile);
        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
        scaler.transform(image);

        return model.outputSingle(image);
    }
}

Milvus数据库操作

package com.et.imagesearch;

import io.milvus.client.*;
import io.milvus.param.*;
import io.milvus.param.collection.*;
import io.milvus.param.dml.*;
import io.milvus.grpc.*;
import io.milvus.param.index.CreateIndexParam;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

public class MilvusManager {
    private  MilvusServiceClient milvusClient;

    public MilvusManager() {
      milvusClient = new MilvusServiceClient(
            ConnectParam.newBuilder()
                  .withUri("https://xxx.gcp-us-west1.cloud.zilliz.com")
                  .withToken("xxx")
                  .build());
    }

    public void createCollection() {
        FieldType idField = FieldType.newBuilder()
                .withName("id")
                .withDataType(DataType.Int64)
                .withPrimaryKey(true)
                .build();

        FieldType vectorField = FieldType.newBuilder()
                .withName("embedding")
                .withDataType(DataType.FloatVector)
                .withDimension(1000)
                .build();

        CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
                .withCollectionName("image_collection")
                .withDescription("Image collection")
                .withShardsNum(2)
                .addFieldType(idField)
                .addFieldType(vectorField)
                .build();

        milvusClient.createCollection(createCollectionParam);
    }

    public void insertData(long id, INDArray features) {
        List<Long> ids = Collections.singletonList(id);
        float[] floatArray = features.toFloatVector();

        List<Float> floatList = new ArrayList<>();
        for (float f : floatArray) {
            floatList.add(f); 
        }

        List<List<Float>> vectors = Collections.singletonList(floatList);

        List<InsertParam.Field> fields = new ArrayList<>();
        fields.add(new InsertParam.Field("id",ids));
        fields.add(new InsertParam.Field("embedding", vectors));
        InsertParam insertParam = InsertParam.newBuilder()
                .withCollectionName("image_collection")
                .withFields(fields)
                .build();

        milvusClient.insert(insertParam);

    }
   public void flush() {
      milvusClient.flush(FlushParam.newBuilder()
            .withCollectionNames(Collections.singletonList("image_collection"))
            .withSyncFlush(true)
            .withSyncFlushWaitingInterval(50L)
            .withSyncFlushWaitingTimeout(30L)
            .build());
   }

   public void buildindex() {
      // build index
      System.out.println("Building AutoIndex...");
      final IndexType INDEX_TYPE = IndexType.AUTOINDEX;   // IndexType
      long startIndexTime = System.currentTimeMillis();
      R<RpcStatus> indexR = milvusClient.createIndex(
            CreateIndexParam.newBuilder()
                  .withCollectionName("image_collection")
                  .withFieldName("embedding")
                  .withIndexType(INDEX_TYPE)
                  .withMetricType(MetricType.L2)
                  .withSyncMode(Boolean.TRUE)
                  .withSyncWaitingInterval(500L)
                  .withSyncWaitingTimeout(30L)
                  .build());
      long endIndexTime = System.currentTimeMillis();
      System.out.println("Succeed in " + (endIndexTime - startIndexTime) / 1000.00 + " seconds!");
   }
}

图片搜索功能

package com.et.imagesearch;

import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.SearchResults;
import io.milvus.param.ConnectParam;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.dml.SearchParam;
import io.milvus.response.SearchResultsWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

public class ImageSearcher {
   private  MilvusServiceClient milvusClient;

   public ImageSearcher() {
      milvusClient = new MilvusServiceClient(
            ConnectParam.newBuilder()
                  .withUri("https://ixxxxx.gcp-us-west1.cloud.zilliz.com")
                  .withToken("xxx")
                  .build());
   }

   public void search(INDArray queryFeatures) {
      float[] floatArray = queryFeatures.toFloatVector();
      List<Float> floatList = new ArrayList<>();
      for (float f : floatArray) {
         floatList.add(f);
      }
      List<List<Float>> vectors = Collections.singletonList(floatList);


      SearchParam searchParam = SearchParam.newBuilder()
            .withCollectionName("image_collection")
            .withMetricType(MetricType.L2)
            .withTopK(5)
            .withVectors(vectors)
            .withVectorFieldName("embedding")
            .build();

      R<SearchResults> searchResults = milvusClient.search(searchParam);


      System.out.println("Searching vector: " + queryFeatures.toFloatVector());
      System.out.println("Result: " + searchResults.getData().getResults().getFieldsDataList());
   }
}

Main主类

package com.et.imagesearch;

import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.File;
import java.io.IOException;

/**
 * End-to-end demo. It indexes every image in a directory into Milvus, then prints the stored
 * images most similar to a query image.
 *
 * <p>Optional arguments: {@code args[0]} overrides the image directory and {@code args[1]}
 * the query image. Without them, the original hard-coded paths are used.
 */
public class Main {
    private static final String DEFAULT_IMAGE_DIR =
            "/Users/liuhaihua/ai/ut-zap50k-images-square/Boots/Ankle/Columbia";
    private static final String DEFAULT_QUERY_IMAGE =
            "/Users/liuhaihua/ai/ut-zap50k-images-square/Boots/Ankle/Columbia/7247580.16952.jpg";

    public static void main(String[] args) throws IOException {
        String imageDir = args.length > 0 ? args[0] : DEFAULT_IMAGE_DIR;
        String queryPath = args.length > 1 ? args[1] : DEFAULT_QUERY_IMAGE;

        FeatureExtractor extractor = new FeatureExtractor();
        MilvusManager milvusManager = new MilvusManager();
        ImageSearcher searcher = new ImageSearcher();

        milvusManager.createCollection();

        // images extract: keep only regular, non-hidden files. Hidden entries (e.g. macOS
        // .DS_Store) and sub-directories are not images and would make extraction fail.
        File[] imageFiles = new File(imageDir).listFiles(file -> file.isFile() && !file.isHidden());
        if (imageFiles != null) {
            // listFiles() guarantees no ordering; sort so each image gets the same id every run.
            java.util.Arrays.sort(imageFiles);
            for (int i = 0; i < imageFiles.length; i++) {
                INDArray features = extractor.extractFeatures(imageFiles[i]);
                milvusManager.insertData(i, features);
            }
        } else {
            System.err.println("Cannot list image directory: " + imageDir);
        }
        milvusManager.flush();
        milvusManager.buildindex();

        // query
        File queryImage = new File(queryPath);
        INDArray queryFeatures = extractor.extractFeatures(queryImage);
        searcher.search(queryFeatures);
    }
}

以上只是一些关键代码,所有代码请参见下面代码仓库

代码仓库

5.测试

以上就是SpringBoot集成Milvus和deeplearning4j实现图搜图功能的详细内容,更多关于SpringBoot Milvus和deeplearning4j图搜图的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文