暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

golang源码分析:langchaingo(5)

        前面介绍的都是无状态的单词请求,如果希望连续聊天,并且AI能根据历史的聊天记录给出相关联的回答,怎么做呢?看下面的例子:
        memoryBuffer := memory.NewConversationWindowBuffer(10)
        chatChain := chains.NewConversation(llm, memoryBuffer)
        messages := []string{
            "你好,我叫PBR",
            "你知道我叫什么吗?",
            "你可以解决什么问题?",
        }
        for _, message := range messages {
            completion, err := chains.Run(ctx, chatChain, message)
            for {
                if err == nil {
                    break
                }
                time.Sleep(30 * time.Second)
                completion, err = chains.Run(ctx, chatChain, message)
            }
            chatMessages, _ := memoryBuffer.ChatHistory.Messages(ctx)
            fmt.Printf("上下文对话历史:%v\n", chatMessages)
            fmt.Printf("输入:%v\n输出:%v\n", message, completion)
    可以看到,我们定义了NewConversationWindowBuffer:带聊天记录条数窗口的buffer,当然也可以根据实际情况定义更多类型的buffer:简单buffer或者指定token长度的buffer
      memoryBuffer := memory.NewConversationBuffer()
      memoryBuffer := memory.NewConversationWindowBuffer(10)
      memoryBuffer := memory.NewConversationTokenBuffer(llm, 1024)  
              github.com/tmc/langchaingo@v0.1.13/memory/window_buffer.go
        func NewConversationWindowBuffer(
            conversationWindowSize int,
            options ...ConversationBufferOption,
        ) *ConversationWindowBuffer {
            if conversationWindowSize <= 0 {
                conversationWindowSize = defaultConversationWindowSize
            }
            tb := &ConversationWindowBuffer{
                ConversationWindowSize: conversationWindowSize,
                ConversationBuffer:     *applyBufferOptions(options...),
            }
            return tb
        }
          type ConversationWindowBuffer struct {
              ConversationBuffer
              ConversationWindowSize int
          }
          继承自github.com/tmc/langchaingo@v0.1.13/memory/buffer.go
            type ConversationBuffer struct {
                ChatHistory schema.ChatMessageHistory
                ReturnMessages bool
                InputKey       string
                OutputKey      string
                HumanPrefix    string
                AIPrefix       string
                MemoryKey      string
            }
                    其中聊天历史的定义是一个接口,支持添加普通消息、用户消息还有AI响应消息:
            github.com/tmc/langchaingo@v0.1.13/schema/chat_message_history.go
              // ChatMessageHistory is the interface for chat history in memory/store.
              type ChatMessageHistory interface {
                  // AddMessage adds a message to the store.
                  AddMessage(ctx context.Context, message llms.ChatMessage) error
                  // AddUserMessage is a convenience method for adding a human message string
                  // to the store.
                  AddUserMessage(ctx context.Context, message stringerror
                  // AddAIMessage is a convenience method for adding an AI message string to
                  // the store.
                  AddAIMessage(ctx context.Context, message stringerror
                  // Clear removes all messages from the store.
                  Clear(ctx context.Context) error
                  // Messages retrieves all messages from the store
                  Messages(ctx context.Context) ([]llms.ChatMessage, error)
                  // SetMessages replaces existing messages in the store
                  SetMessages(ctx context.Context, messages []llms.ChatMessage) error
              }
                      另外两个的定义如下:
                func applyBufferOptions(opts ...ConversationBufferOption) *ConversationBuffer {
                    m := &ConversationBuffer{
                        ReturnMessages: false,
                        InputKey:       "",
                        OutputKey:      "",
                        HumanPrefix:    "Human",
                        AIPrefix:       "AI",
                        MemoryKey:      "history",
                    }
                    for _, opt := range opts {
                        opt(m)
                    }
                    if m.ChatHistory == nil {
                        m.ChatHistory = NewChatMessageHistory()
                    }
                    return m
                }
                  func NewConversationTokenBuffer(
                      llm llms.Model,
                      maxTokenLimit int,
                      options ...ConversationBufferOption,
                  ) *ConversationTokenBuffer {
                      tb := &ConversationTokenBuffer{
                          LLM:                llm,
                          MaxTokenLimit:      maxTokenLimit,
                          ConversationBuffer: *applyBufferOptions(options...),
                      }
                      return tb
                  }
                  区别是后者添加了token数限制。
                          定义完记忆缓冲区后,接着就是初始化聊天会话
                    func NewConversation(llm llms.Model, memory schema.Memory) LLMChain {
                        return LLMChain{
                            Prompt: prompts.NewPromptTemplate(
                                _conversationTemplate,
                                []string{"history""input"},
                            ),
                            LLM:          llm,
                            Memory:       memory,
                            OutputParser: outputparser.NewSimple(),
                            OutputKey:    _llmChainDefaultOutputKey,
                        }
                    }
                      const _llmChainDefaultOutputKey = "text"
                              可以看到,它定义了一个提示词模板,模板里面有history和input两个参数,具体模板内容如下:
                        const _conversationTemplate = `The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
                        Current conversation:
                        {{.history}}
                        Human: {{.input}}
                        AI:`
                        而LLMChain的定义如下:
                          type LLMChain struct {
                              Prompt           prompts.FormatPrompter
                              LLM              llms.Model
                              Memory           schema.Memory
                              CallbacksHandler callbacks.Handler
                              OutputParser     schema.OutputParser[any]
                              OutputKey string
                          }
                          最后通过llmchain的Run方法获取返回结果:
                            func Run(ctx context.Context, c Chain, input any, options ...ChainCallOption) (stringerror) {
                                inputKeys := c.GetInputKeys()
                                memoryKeys := c.GetMemory().MemoryVariables(ctx)
                                neededKeys := make([]string0len(inputKeys))
                                // Remove keys gotten from the memory.
                                for _, inputKey := range inputKeys {
                                    isInMemory := false
                                    for _, memoryKey := range memoryKeys {
                                        if inputKey == memoryKey {
                                            isInMemory = true
                                            continue
                                        }
                                    }
                                    if isInMemory {
                                        continue
                                    }
                                    neededKeys = append(neededKeys, inputKey)
                                }
                                if len(neededKeys) != 1 {
                                    return "", ErrMultipleInputsInRun
                                }
                                outputKeys := c.GetOutputKeys()
                                if len(outputKeys) != 1 {
                                    return "", ErrMultipleOutputsInRun
                                }
                                inputValues := map[string]any{neededKeys[0]: input}
                                outputValues, err := Call(ctx, c, inputValues, options...)
                                if err != nil {
                                    return "", err
                                }
                                outputValue, ok := outputValues[outputKeys[0]].(string)
                                if !ok {
                                    return "", ErrWrongOutputTypeInRun
                                }
                                return outputValue, nil
                            }
                            在Run内部调用了Call方法,它先从memory里面取历史记录,然后根据选项筛选历史记录,接着调用LLM获取返回结果,并将结果记录到memory里面给下次使用
                              // Call is the standard function used for executing chains.
                              func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
                                  fullValues := make(map[string]any, 0)
                                  for key, value := range inputValues {
                                      fullValues[key] = value
                                  }
                                  newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues)
                                  if err != nil {
                                      return nil, err
                                  }
                                  for key, value := range newValues {
                                      fullValues[key] = value
                                  }
                                  callbacksHandler := getChainCallbackHandler(c)
                                  if callbacksHandler != nil {
                                      callbacksHandler.HandleChainStart(ctx, inputValues)
                                  }
                                  outputValues, err := callChain(ctx, c, fullValues, options...)
                                  if err != nil {
                                      if callbacksHandler != nil {
                                          callbacksHandler.HandleChainError(ctx, err)
                                      }
                                      return outputValues, err
                                  }
                                  if callbacksHandler != nil {
                                      callbacksHandler.HandleChainEnd(ctx, outputValues)
                                  }
                                  if err = c.GetMemory().SaveContext(ctx, inputValues, outputValues); err != nil {
                                      return outputValues, err
                                  }
                                  return outputValues, nil
                              }
                              其中的callChain就是调用chain中的每个元素的call方法
                                func callChain(
                                    ctx context.Context,
                                    c Chain,
                                    fullValues map[string]any,
                                    options ...ChainCallOption,
                                ) (map[string]any, error) {
                                    if err := validateInputs(c, fullValues); err != nil {
                                        return nil, err
                                    }
                                    outputValues, err := c.Call(ctx, fullValues, options...)
                                    if err != nil {
                                        return outputValues, err
                                    }
                                    if err := validateOutputs(c, outputValues); err != nil {
                                        return outputValues, err
                                    }
                                    return outputValues, nil
                                }
                                        至此历史记录的相关源码介绍完毕。

                                文章转载自golang算法架构leetcode技术php,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

                                评论