暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

AI大模型应用开发入门-LangChain实现RAG检索增强生成

chester技术分享 2025-06-16
135

检索增强生成(RAG)是一种结合“向量检索”与“大语言模型”的技术路线,能在问答、摘要、文档分析等场景中大幅提升准确性与上下文利用率。

本文将基于 LangChain 构建一个完整的 RAG 流程,结合 PGVector
 作为向量数据库,并用 LangGraph
 构建状态图控制流程。


大语言模型初始化(llm_env.py)

我们首先使用 LangChain 提供的模型初始化器加载 gpt-4o-mini
 模型,供后续问答使用。

    # llm_env.py
    from langchain.chat_models import init_chat_model
    llm = init_chat_model("gpt-4o-mini", model_provider="openai"

    RAG 主体流程(rag.py)

    以下是整个 RAG 系统的主流程代码,主要包括:文档加载与切分、向量存储、状态图建模(analyze→retrieve→generate)、交互式问答。

      # rag.py
      import os
      import sys
      import time
      sys.path.append(os.getcwd())
      from llm_set import llm_env
      from langchain_openai import OpenAIEmbeddings
      from langchain_postgres import PGVector
      from langchain_community.document_loaders import WebBaseLoader
      from langchain_core.documents import Document
      from langchain_text_splitters import RecursiveCharacterTextSplitter
      from langgraph.graph import START, StateGraph
      from typing_extensions import List, TypedDict, Annotated
      from typing import Literal
      from langgraph.checkpoint.postgres import PostgresSaver
      from langgraph.graph.message import add_messages
      from langchain_core.messages import HumanMessage, BaseMessage
      from langchain_core.prompts import ChatPromptTemplate
      # 初始化 LLM
      llm = llm_env.llm
      # 嵌入模型
      embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
      # 向量数据库初始化
      vector_store = PGVector(
          embeddings=embeddings,
          collection_name="my_rag_docs",
          connection="postgresql+psycopg2://postgres:123456@localhost:5433/langchainvector",
      )
      # 加载网页内容
      url = "https://python.langchain.com/docs/tutorials/qa_chat_history/"
      loader = WebBaseLoader(web_paths=(url,))
      docs = loader.load()
      for doc in docs:
          doc.metadata["source"] = url
      # 文本分割
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)
      all_splits = text_splitter.split_documents(docs)
      # 添加 section 元数据
      total_documents = len(all_splits)
      third = total_documents / 3
      for i, document in enumerate(all_splits):
          if i < third:
              document.metadata["section"] = "beginning"
          elif i < 2 * third:
              document.metadata["section"] = "middle"
          else:
              document.metadata["section"] = "end"
      # 检查是否已存在向量
      existing = vector_store.similarity_search(url, k=1filter={"source": url})
      if not existing:
          _ = vector_store.add_documents(documents=all_splits)
          print("文档向量化完成"

      分析、检索与生成模块

      接下来,我们定义三个函数构成 LangGraph
       的流程:analyze → retrieve → generate。

        class Search(TypedDict):
            query: Annotated[str"The question to be answered"]
            section: Annotated[
                Literal["beginning""middle""end"],
                ...,
                "Section to query.",
            ]
        class State(TypedDict):
            messages: Annotated[list[BaseMessage], add_messages]
            query: Search
            context: List[Document]
            answer: set
        # 分析意图 → 获取 query 与 section
        def analyze(state: State):
            structtured_llm = llm.with_structured_output(Search)
            query = structtured_llm.invoke(state["messages"])
            return {"query": query}
        # 相似度检索
        def retrieve(state: State):
            query = state["query"]
            if hasattr(query, 'section'):
                filter = {"section": query["section"]}
            else:
                filter = None
            retrieved_docs = vector_store.similarity_search(query["query"], filter=filter)
            return {"context": retrieved_docs}
        生成模块基于 
        ChatPromptTemplate
         和当前上下文生成回答:
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system""尽你所能按照上下文:{context},回答问题:{question}。"),
            ]
        )
        def generate(state: State):
            docs_content = "\n\n".join(doc.page_content for doc in state["context"])
            messages = prompt_template.invoke({
                "question": state["query"]["query"],
                "context": docs_content,
            })
            response = llm.invoke(messages)
            return {"answer": response.content, "messages": [response]} 

        构建 LangGraph 流程图

        定义好状态结构后,我们构建 LangGraph

          graph_builder = StateGraph(State).add_sequence([analyze, retrieve, generate])
          graph_builder.add_edge(START, "analyze"

          PG 数据库中保存中间状态(Checkpoint)

          我们通过 PostgresSaver
           记录每次对话的中间状态:

            DB_URI = "postgresql://postgres:123456@localhost:5433/langchaindemo?sslmode=disable"
            with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
                checkpointer.setup()
                graph = graph_builder.compile(checkpointer=checkpointer)
                input_thread_id = input("输入thread_id:")
                time_str = time.strftime("%Y%m%d", time.localtime())
                config = {"configurable": {"thread_id"f"rag-{time_str}-demo-{input_thread_id}"}}
                print("输入问题,输入 exit 退出。")
                while True:
                    query = input("你: ")
                    if query.strip().lower() == "exit":
                        break
                    input_messages = [HumanMessage(query)]
                    response = graph.invoke({"messages": input_messages}, config=config)
                    print(response["answer"]) 

            效果


            总结

            本文通过 LangChain 的模块式能力,结合 PGVector 向量库与 LangGraph 有状态控制系统,实现了一个可交互、可持久化、支持多文档结构的 RAG 系统。其优势包括:

            • 支持结构化提问理解(分区查询)

            • 自动化分段与元数据标记

            • 状态流追踪与恢复

            • 可拓展支持文档上传、缓存优化、多用户配置

            关注获取技术分享

            文章转载自chester技术分享,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

            评论