LLM

langchain入门-使用langchain构建一个拥有RAG能力的代码问答应用(DeepSeek API&本地bge-m3&命令行版本)

详细介绍如何使用 Langchain 构建拥有 RAG 能力的代码问答应用,包括 DeepSeek API 集成、本地 bge-m3 嵌入模型配置、以及命令行版本实现。

说在前面

  • 操作系统:windows
  • python版本:3.9
  • langchain版本:0.3.20
  • pycharm版本:2023.1.2 (Community Edition)

embedding模型

  • 什么是embedding模型?以下回答由AI提供

    什么是Embedding?‌

    • Embedding(嵌入)是将数据映射到低维向量空间的过程。例如,一个单词、句子或图片可以被表示为稠密向量(如300维),而非原始的稀疏高维形式(如one-hot编码)。向量空间中,语义相近的对象(如“猫”和“狗”)距离更近,而无关的对象(如“猫”和“飞机”)距离更远。
    • 关键特点‌: ‌低维稠密‌:相比one-hot编码,维度更低且每个维度包含语义信息。 ‌保留关系‌:捕捉数据中的语义、语法或特征关联(如“国王-男人+女人≈女王”)。 ‌可计算性‌:向量支持数学运算(如余弦相似度),便于下游任务处理。
  • 我们将输入数据转换成向量之后,就可以更方便的查询这些数据之间的关系。
  • 在魔塔以及huggingface上,选择句子相似度(Sentence Similarity) 可以看到bge-m3目前的排名还挺高,这里我们就选择它
  • 部署方式使用ollama(使用llama.cpp部署后似乎调不通),直接执行:
    ollama pull bge-m3
    
  • langchain中调用
    from langchain_ollama import OllamaEmbeddings
    
    embeddings = OllamaEmbeddings(model="bge-m3:latest")
    

解析代码

  • 准备工作
    pip install unstructured
    
    下载averaged_perceptron_tagger_eng
    import nltk
    nltk.download('averaged_perceptron_tagger_eng')
    
  • 加载文件
    from langchain_community.document_loaders import DirectoryLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
    
    # 加载所有go文件
    loader = DirectoryLoader(path="../detour-go-main", glob="*.go", recursive=True)
    docs = loader.load()
    
  • 分割代码
    go_splitter = RecursiveCharacterTextSplitter.from_language(language=Language.GO, chunk_overlap=0)
    all_splits = go_splitter.split_documents(documents=docs)
    
  • 定义向量存储,在本文中,使用用于测试的InMemoryVectorStore
    # 向量存储
    from langchain_core.vectorstores import InMemoryVectorStore
    
    vector_store = InMemoryVectorStore(embeddings)
    
  • 将数据通过embedding模型转换成向量,并存储
    # 将代码转换成向量
    _ = vector_store.add_documents(documents=all_splits)
    

创建检索工具

  • 为了让模型能够有能力访问我们的数据,我们需要定义一个工具(或者说方法)让模型去调用。
    from langchain_core.tools import tool
    
    @tool(response_format="content_and_artifact", description="rag")
    def retrieve(query: str):
        # 根据相似度查询数据
        retrieve_docs = vector_store.similarity_search(query, k=2)
        serialized = "\n\n".join(
            (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
            for doc in retrieve_docs
        )
        return serialized, retrieve_docs
    
    # 可以先测试下效果
    print(retrieve.invoke("DtObstacleAvoidanceQuery.sampleVelocityAdaptive的作用是什么?"))
    

构建graph

  • 和上一篇一样,创建一个StateGraph
    from langgraph.graph import MessagesState, StateGraph
    
    graph_builder = StateGraph(state_schema=MessagesState)
    
  • 定义query_or_response节点 这个节点的作用是,让AI根据历史消息决定是否需要调用工具
    from langchain_core.messages import SystemMessage
    
    def query_or_response(state: MessagesState):
        llm_with_tools = llm.bind_tools([retrieve])
        response = llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}
    
    以下是一个模型调用返回:
    AIMessage(content='', 
    additional_kwargs={
        'tool_calls': [
            {
                'id': 'xxx', 
                'function': {
                    'arguments': '{"query":"sampleVelocityAdaptive的作用"}', 
                    'name': 'retrieve'
                }, 
                'type': 'function', 
                'index': 0
            }], 
        'refusal': None}, 
    response_metadata={xxx}, 
    id='xxx',
    tool_calls=[{
        'name': 'retrieve', 
        'args': {'query': 'sampleVelocityAdaptive的作用'}, 
        'id': 'xxx', 
        'type': 'tool_call'}],
    usage_metadata={xxx})
    
    可以看到在返回数据中additional_kwargs带上了tool_calls的信息,通过这些信息,我们的程序就知道应该使用什么参数,调用哪个方法。
  • 定义tool节点
    from langgraph.prebuilt import ToolNode
    tools = ToolNode([retrieve])
    
  • 定义generate节点,这个节点就是正常的合并历史消息,调用模型;不过在处理ToolMessage时,将其分离出来并放在了prompt
    def generate(state: MessagesState):
        recent_tool_messages = []
        for message in reversed(state["messages"]):
            if message.type == "tool":
                recent_tool_messages.append(message)
            else:
                break
        tools_messages = recent_tool_messages[::-1]
        docs_content = "\n\n".join(doc.content for doc in tools_messages)
        sys_msg_content = (
            "你是一个go语言的代码助手."
            "你可以使用下面的检索数据来回答用户的问题."
            "如果你不知道如何回答,请回答不知道."
            "\n\n"
            f"{docs_content}"
        )
        conversation_msgs = [
            msg
            for msg in state["messages"]
            if msg.type in ("human", "system")
               or (msg.type == "ai" and not msg.tool_calls)
        ]
        prompt = [SystemMessage(sys_msg_content)] + conversation_msgs
    
        resp = llm.invoke(prompt)
        return {"messages": [resp]}
    
  • 连接节点
    from langgraph.graph import END
    from langgraph.prebuilt import tools_condition
    
    graph_builder.add_node(query_or_response)
    graph_builder.add_node(tools)
    graph_builder.add_node(generate)
    
    graph_builder.set_entry_point("query_or_response")
    graph_builder.add_conditional_edges(
        "query_or_response",
        tools_condition,
        {END: END, "tools": "tools"}
    )
    graph_builder.add_edge("tools", "generate")
    graph_builder.add_edge("generate", END)
    
    compiled_graph = graph_builder.compile()
    
    在该graph中,调用query_or_response后,如果是非tool_call的返回,即正常的对话(比如用户说了句"你好"),那么就会跳转到END节点;否则的话,会调用tools,并再次调用模型
    mermaid
    graph TB
      	A([Start]) --> B[query_or_response]
        B --> E[END]
      	B --> C[tools]
      	C --> D[generate]
      	D --> E[END]

测试

  • 依旧是命令行的方式
    compiled_graph = graph_builder.compile()
    
    from langchain_core.messages import HumanMessage, AIMessage
    
    while True:
        input_msg = input("> ")
        ai_msg = compiled_graph.invoke({"messages": [HumanMessage(content=input_msg)]})
        print(ai_msg)
    
  • 但是测试结果不太行,因为根据相似度检索出来的数据是错误的🥲
    > DtObstacleAvoidanceQuery.sampleVelocityAdaptive的作用是什么?
    根据提供的检索数据,没有找关于 `DtObstacleAvoidanceQuery.sampleVelocityAdaptive` 的具体信息。
    因此,我无法回答这个函数的作用。\n\n如果你有更多关于这个函数的上下文或其他相关代码片段,可以提供给我,我会尽力帮助你解答。
    
  • 整个流程是通的,如果使用一些文档作为数据源可能效果会好点
  • 对于代码,后续还会研究更好的方式来处理

完整代码

from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="bge-m3:latest")

from langchain_deepseek import ChatDeepSeek

llm = ChatDeepSeek(model="deepseek-chat", api_key="xxxx")

from langchain_community.document_loaders import DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language

# 加载所有go文件
loader = DirectoryLoader(path="../detour-go-main", glob="*.go", recursive=True)
docs = loader.load()

# 分割代码
go_splitter = RecursiveCharacterTextSplitter.from_language(language=Language.GO, chunk_overlap=0)
all_splits = go_splitter.split_documents(documents=docs)

print(all_splits[:3])

# 向量存储
from langchain_core.vectorstores import InMemoryVectorStore

vector_store = InMemoryVectorStore(embeddings)

# 将代码转换成向量
_ = vector_store.add_documents(documents=all_splits)

# 创建一个查询工具
from langchain_core.tools import tool


@tool(response_format="content_and_artifact", description="rag")
def retrieve(query: str):
    # 根据相似度查询数据
    retrieve_docs = vector_store.similarity_search(query, k=2)
    serialized = "\n\n".join(
        (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
        for doc in retrieve_docs
    )
    return serialized, retrieve_docs

print(retrieve.invoke("DtObstacleAvoidanceQuery.sampleVelocityAdaptive的作用是什么?"))

# 开始构建graph
from langgraph.graph import MessagesState, StateGraph

graph_builder = StateGraph(state_schema=MessagesState)

from langchain_core.messages import SystemMessage

def query_or_response(state: MessagesState):
    llm_with_tools = llm.bind_tools([retrieve])
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}


from langgraph.prebuilt import ToolNode

tools = ToolNode([retrieve])


def generate(state: MessagesState):
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tools_messages = recent_tool_messages[::-1]
    docs_content = "\n\n".join(doc.content for doc in tools_messages)
    sys_msg_content = (
        "你是一个go语言的代码助手."
        "你可以使用下面的检索数据来回答用户的问题."
        "如果你不知道如何回答,请回答不知道."
        "\n\n"
        f"{docs_content}"
    )
    conversation_msgs = [
        msg
        for msg in state["messages"]
        if msg.type in ("human", "system")
           or (msg.type == "ai" and not msg.tool_calls)
    ]
    prompt = [SystemMessage(sys_msg_content)] + conversation_msgs

    resp = llm.invoke(prompt)
    return {"messages": [resp]}


from langgraph.graph import END
from langgraph.prebuilt import tools_condition

graph_builder.add_node(query_or_response)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

graph_builder.set_entry_point("query_or_response")
graph_builder.add_conditional_edges(
    "query_or_response",
    tools_condition,
    {END: END, "tools": "tools"}
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

compiled_graph = graph_builder.compile()

from langchain_core.messages import HumanMessage, AIMessage

while True:
    input_msg = input("> ")
    ai_msg = compiled_graph.invoke({"messages": [HumanMessage(content=input_msg)]})
    print(ai_msg)