java

关注公众号 jb51net

关闭
首页 > 软件编程 > java > Spring Boot LangChain Rag应用

Spring Boot集成LangChain来实现Rag应用的问题小结

作者:HBLOGA

检索增强生成(RAG)是一种优化大型语言模型(LLM)输出的技术,通过引用权威知识库以增强模型的准确性和相关性,RAG允许LLM在不重新训练的情况下访问特定领域的知识,提高了其在各种应用中的实用性和信任度,感兴趣的朋友跟随小编一起看看吧

1.什么是rag?

检索增强生成(RAG)是指对大型语言模型输出进行优化,使其能够在生成响应之前引用训练数据来源之外的权威知识库。大型语言模型(LLM)用海量数据进行训练,使用数十亿个参数为回答问题、翻译语言和完成句子等任务生成原始输出。在 LLM 本就强大的功能基础上,RAG 将其扩展为能访问特定领域或组织的内部知识库,所有这些都无需重新训练模型。这是一种经济高效地改进 LLM 输出的方法,让它在各种情境下都能保持相关性、准确性和实用性。

为什么检索增强生成很重要?

LLM 是一项关键的人工智能(AI)技术,为智能聊天机器人和其他自然语言处理(NLP)应用程序提供支持。目标是通过交叉引用权威知识来源,创建能够在各种环境中回答用户问题的机器人。不幸的是,LLM 技术的本质在 LLM 响应中引入了不可预测性。此外,LLM 训练数据是静态的,并引入了其所掌握知识的截止日期。 LLM 面临的已知挑战包括:

您可以将大型语言模型看作是一个过于热情的新员工,他拒绝随时了解时事,但总是会绝对自信地回答每一个问题。不幸的是,这种态度会对用户的信任产生负面影响,这是您不希望聊天机器人效仿的! RAG 是解决其中一些挑战的一种方法。它会重定向 LLM,从权威的、预先确定的知识来源中检索相关信息。组织可以更好地控制生成的文本输出,并且用户可以深入了解 LLM 如何生成响应。

检索增强生成的工作原理是什么?

如果没有 RAG,LLM 会接受用户输入,并根据它所接受训练的信息或它已经知道的信息创建响应。RAG 引入了一个信息检索组件,该组件利用用户输入首先从新数据源提取信息。用户查询和相关信息都提供给 LLM。LLM 使用新知识及其训练数据来创建更好的响应。以下各部分概述了该过程。

创建外部数据

LLM 原始训练数据集之外的新数据称为外部数据。它可以来自多个数据来源,例如 API、数据库或文档存储库。数据可能以各种格式存在,例如文件、数据库记录或长篇文本。另一种称为嵌入语言模型的 AI 技术将数据转换为数字表示形式并将其存储在向量数据库中。这个过程会创建一个生成式人工智能模型可以理解的知识库。

检索相关信息

下一步是执行相关性搜索。用户查询将转换为向量表示形式,并与向量数据库匹配。例如,考虑一个可以回答组织的人力资源问题的智能聊天机器人。如果员工搜索:“我有多少年假?”,系统将检索年假政策文件以及员工个人过去的休假记录。这些特定文件将被退回,因为它们与员工输入的内容高度相关。相关性是使用数学向量计算和表示法计算和建立的。

增强 LLM 提示

接下来,RAG 模型通过在上下文中添加检索到的相关数据来增强用户输入(或提示)。此步骤使用提示工程技术与 LLM 进行有效沟通。增强提示允许大型语言模型为用户查询生成准确的答案。

更新外部数据

下一个问题可能是——如果外部数据过时了怎么办? 要维护当前信息以供检索,请异步更新文档并更新文档的嵌入表示形式。您可以通过自动化实时流程或定期批处理来执行此操作。这是数据分析中常见的挑战——可以使用不同的数据科学方法进行变更管理。 下图显示了将 RAG 与 LLM 配合使用的概念流程。

2.什么是LangChain?

LangChain 是一个用于开发由语言模型驱动的应用程序的框架。他主要拥有 2 个能力:

LLM 模型:Large Language Model,大型语言模型

3.代码工程

实验目的

利用LangChain实现rag应用

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <modelVersion>4.0.0</modelVersion>
    <artifactId>rag</artifactId>
    <properties>
        <java.version>17</java.version>
        <langchain4j.version>0.23.0</langchain4j.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-thymeleaf</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-devtools</artifactId>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-open-ai</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

controller

package com.et.rag.controller;
import com.et.rag.service.SBotService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
@Controller
@RequiredArgsConstructor
public class SBotController {
    private final SBotService sBotService;
    @GetMapping
    public String home() {
        return "index";
    }
    @PostMapping("/ask")
    public ResponseEntity<String> ask(@RequestBody String question) {
        try {
            return ResponseEntity.ok(sBotService.askQuestion(question));
        } catch (Exception e) {
            return ResponseEntity.badRequest().body("Sorry, I can't process your question right now.");
        }
    }
}

service

package com.et.rag.service;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor
@Slf4j
public class SBotService {
    private final ConversationalRetrievalChain chain;
    public String askQuestion(String question) {
        log.debug("======================================================");
        log.debug("Question: " + question);
        String answer = chain.execute(question);
        log.debug("Answer: " + answer);
        log.debug("======================================================");
        return answer;
    }
}

EmbeddingStoreLoggingRetriever

package com.et.rag.retriever;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.retriever.Retriever;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
/**
 * EmbeddingStoreLoggingRetriever is a logging-enhanced for an EmbeddingStoreRetriever.
 * <p>
 * This class logs the relevant TextSegments discovered by the supplied
 * EmbeddingStoreRetriever for improved transparency and debugging.
 * <p>
 * Logging happens at INFO level, printing each relevant TextSegment found
 * for a given input text once the findRelevant method is called.
 */
@RequiredArgsConstructor
@Slf4j
public class EmbeddingStoreLoggingRetriever implements Retriever<TextSegment> {
    private final EmbeddingStoreRetriever retriever;
    @Override
    public List<TextSegment> findRelevant(String text) {
        List<TextSegment> relevant = retriever.findRelevant(text);
        relevant.forEach(segment -> {
            log.debug("=======================================================");
            log.debug("Found relevant text segment: {}", segment);
        });
        return relevant;
    }
}

components

初始化documents

package com.et.rag.configuration;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.UrlDocumentLoader;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
import static com.et.rag.constant.Constants.SPRING_BOOT_RESOURCES_LIST;
@Configuration
public class DocumentConfiguration {
    @Bean
    public List<Document> documents() {
        return SPRING_BOOT_RESOURCES_LIST.stream()
                .map(url -> {
                    try {
                        return UrlDocumentLoader.load(url);
                    } catch (Exception e) {
                        throw new RuntimeException("Failed to load document from " + url, e);
                    }
                })
                .toList();
    }
}

初始化langchain

package com.et.rag.configuration;
import com.et.rag.retriever.EmbeddingStoreLoggingRetriever;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.time.Duration;
import java.util.List;
import static com.et.rag.constant.Constants.PROMPT_TEMPLATE_2;
@Configuration
@RequiredArgsConstructor
@Slf4j
public class LangChainConfiguration {
    @Value("${langchain.api.key}")
    private String apiKey;
    @Value("${langchain.timeout}")
    private Long timeout;
    private final List<Document> documents;
    @Bean
    public ConversationalRetrievalChain chain() {
        EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .documentSplitter(DocumentSplitters.recursive(500, 0))
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .build();
        log.info("Ingesting Spring Boot Resources ...");
        ingestor.ingest(documents);
        log.info("Ingested {} documents", documents.size());
        EmbeddingStoreRetriever retriever = EmbeddingStoreRetriever.from(embeddingStore, embeddingModel);
        EmbeddingStoreLoggingRetriever loggingRetriever = new EmbeddingStoreLoggingRetriever(retriever);
        /*MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
                .maxMessages(10)
                .build();*/
        log.info("Building ConversationalRetrievalChain ...");
        ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
                .chatLanguageModel(OpenAiChatModel.builder()
                        .apiKey(apiKey)
                        .timeout(Duration.ofSeconds(timeout))
                        .build()
                )
                .promptTemplate(PromptTemplate.from(PROMPT_TEMPLATE_2))
                //.chatMemory(chatMemory)
                .retriever(loggingRetriever)
                .build();
        log.info("Spring Boot knowledge base is ready!");
        return chain;
    }
}

application.yaml

langchain:
  api:
    # "demo" is a free API key for testing purposes only. Please replace it with your own API key.
    key: demo
    # key: OPEN_API_KEY
  # API call to complete before it is timed out.
  timeout: 30

index.html

<!DOCTYPE html>
<html lang="en"
      xmlns="http://www.w3.org/1999/xhtml">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Spring Boot Doc Bot</title>
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="external nofollow"  rel="stylesheet">
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css" rel="external nofollow" >
</head>
<body>
<nav class="bg-dark text-white py-3">
    <div class="text-center d-flex justify-content-center align-items-center">
        <img src="/logo.png" alt="Logo" style="width:60px; margin-right: 10px;">
        <h2 style="margin: 0;">Welcome to Spring Boot Documentation Bot</h2>
    </div>
</nav>
<div class="container mt-5">
    <div class="row">
        <div class="col-md-8 offset-2">
            <h3 class="text-center mb-3">Ask your Spring related queries here!</h3>
            <form>
                <div class="mb-3">
                    <label for="questionInput" class="form-label">Question</label>
                    <input type="text" class="form-control" id="questionInput" name="question" placeholder="Enter your question" required>
                </div>
                <div class="mb-3 text-center">
                    <button id="submitBtn" type="button" class="btn btn-primary">Ask!</button>
                    <button id="clearBtn" type="button" class="btn btn-secondary">Clear</button>
                </div>
            </form>
        </div>
    </div>
    <div class="row my-5">
        <div class="col-md-8 offset-md-2">
            <label for="answerBox" class="form-label"><h5>Answer</h5></label>
            <div class="position-relative my-3">
                <textarea class="form-control" rows="10" id="answerBox" disabled></textarea>
                <a href="#" rel="external nofollow"  class="position-absolute top-0 end-0 m-2" id="copyBtn">
                    <i class="far fa-copy"></i>
                </a>
            </div>
        </div>
    </div>
</div>
<script src="https://code.jquery.com/jquery-3.7.1.min.js"></script>
<script>
    $(document).ready(function () {
        $("#submitBtn").click(function () {
            let questionValue = $("#questionInput").val();
            if (!questionValue) {
                alert('Please enter your question');
                return;
            }
            $("#answerBox").val('Please wait... fetching answer...');
            $.ajax({
                type: "POST",
                url: "/ask",
                data: JSON.stringify({ question: $("#questionInput").val() }),
                //contentType: "application/json; charset=utf-8",
                dataType: "text",
                success: function (data) {
                    //console.log(typeof data);
                    //console.log(data);
                    $("#answerBox").val(data);
                },
                error: function (errMsg) {
                    alert(errMsg);
                }
            });
        });
        $("#clearBtn").click(function () {
            $("#questionInput").val('');
            $("#answerBox").val('');
        });
        document.getElementById("copyBtn").addEventListener("click", function() {
            var copyText = document.getElementById("answerBox");
            copyText.select();
            copyText.setSelectionRange(0, 99999);
            document.execCommand("copy");
            alert("Copied: " + copyText.value);
        });
    });
</script>
</body>
</html>

只是一些关键代码,所有代码请参见下面代码仓库

代码仓库

GitHub - Harries/springboot-demo: a simple springboot demo with some components for example: redis,solr,rockmq and so on.(dag)

4.测试

启动Spring Boot应用,访问http://127.0.0.1:8080/

5.引用

GitHub - miliariadnane/spring-boot-doc-rag-bot: This repository contains a documentation bot powered by an LLM using @langchain4j to swiftly find answers to your Spring Boot questions. It provides easy browsing of Spring documentation and leverages the RAG technique to retrieve relevant details on demand.

什么是 RAG?— 检索增强生成 AI 详解 — AWS

GitHub - liaokongVFX/LangChain-Chinese-Getting-Started-Guide: LangChain 的中文入门教程

Spring Boot集成LangChain来实现Rag应用 | Harries Blog™

到此这篇关于Spring Boot集成LangChain来实现Rag应用的文章就介绍到这了,更多相关Spring Boot集成LangChain Rag应用内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

阅读全文