暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

golang源码分析:langchaingo(6)

        前面介绍langchaingo都是简单应用没有聊到它的核心处理流程,链式处理,这里还是结合例子详细分析下它的源码:
    // 将输入翻译为特定语言
        chain1 := chains.NewLLMChain(llm,
            prompts.NewPromptTemplate(
                "请将输入的原始文本:{{.originText}}翻译为{{.language}},直接输出翻译文本",
                []string{"originText""language"}))
        chain1.OutputKey = "transText"
        // 总结翻译后的文本概要
        chain2 := chains.NewLLMChain(llm, prompts.NewPromptTemplate(
            "请将输入的原始文本:<{{.transText}}>总结50字以内概要文本。严格使用JSON序列化输出结果,不要带有```json序列化标识。其中originText为原始文本,summaryText为概要文本",
            []string{"transText"}))
        chain2.OutputKey = "summary_json"
        chain, err := chains.NewSequentialChain([]chains.Chain{chain1, chain2}, []string{"originText""language"}, []string{"summary_json"})
        if err != nil {
            log.Fatal(err)
        }
        resp, err := chain.Call(ctx, map[string]any{
            "originText""langchain is a good llm frameworks",
            "language":   "中文",
        })
            可以看到,先定义了两个chain,然后用NewSequentialChain将它俩组合起来,最后调用Call方法,可以看到,虽然模仿了langchain的链式方案,但是用起来没有python的|符号重载直接简单。这里还是依次介绍下源码。
            github.com/tmc/langchaingo@v0.1.13/chains/llm.go里定义了chain中的每个节点
      // NewLLMChain creates a new LLMChain with an LLM and a prompt.
      func NewLLMChain(llm llms.Model, prompt prompts.FormatPrompter, opts ...ChainCallOption) *LLMChain {
          opt := &chainCallOption{}
          for _, o := range opts {
              o(opt)
          }
          chain := &LLMChain{
              Prompt:           prompt,
              LLM:              llm,
              OutputParser:     outputparser.NewSimple(),
              Memory:           memory.NewSimple(),
              OutputKey:        _llmChainDefaultOutputKey,
              CallbacksHandler: opt.CallbackHandler,
          }
          return chain
      }
              它包括了提示词、llm、缓存、输出解析器、输出key等内容
        type LLMChain struct {
            Prompt           prompts.FormatPrompter
            LLM              llms.Model
            Memory           schema.Memory
            CallbacksHandler callbacks.Handler
            OutputParser     schema.OutputParser[any]
            OutputKey string
        }
                可以看到,它只有一个输出key,意味着只能有一个输出值。参数里有提示词模板和参数列表
          func NewPromptTemplate(template string, inputVars []string) PromptTemplate {
              return PromptTemplate{
                  Template:       template,
                  InputVariables: inputVars,
                  TemplateFormat: TemplateFormatGoTemplate,
              }
          }
            // PromptTemplate contains common fields for all prompt templates.
            type PromptTemplate struct {
                // Template is the prompt template.
                Template string
                // A list of variable names the prompt template expects.
                InputVariables []string
                // TemplateFormat is the format of the prompt template.
                TemplateFormat TemplateFormat
                // OutputParser is a function that parses the output of the prompt template.
                OutputParser schema.OutputParser[any]
                // PartialVariables represents a map of variable names to values or functions
                // that return values. If the value is a function, it will be called when the
                // prompt template is rendered.
                PartialVariables map[string]any
            }
                    可以看到,第二个chain的输入使用了第一个chain的输出,整个链就是这么串起来的。接着看下串起整个链的逻辑
              func NewSequentialChain(chains []Chain, inputKeys []string, outputKeys []string, opts ...SequentialChainOption) (*SequentialChain, error) { //nolint:lll
                  s := &SequentialChain{
                      chains:     chains,
                      inputKeys:  inputKeys,
                      outputKeys: outputKeys,
                      memory:     memory.NewSimple(),
                  }
                  for _, opt := range opts {
                      opt(s)
                  }
                  if err := s.validateSeqChain(); err != nil {
                      return nil, err
                  }
                  return s, nil
              }
                      入参传入了所有chain的列表,所有输入参数列表和最终输出参数名称。函数内部先把参数存到结构体里,后面校验了下参数和整个链的完整合法性
                type SequentialChain struct {
                    chains     []Chain
                    inputKeys  []string
                    outputKeys []string
                    memory     schema.Memory
                }
                        首先输入的参数和内存中存在的参数不能有交集,即后面链中加入的参数不能和前面节点的参数重名
                  func (c *SequentialChain) validateSeqChain() error {
                      knownKeys := setutil.ToSet(c.inputKeys)
                      // Make sure memory keys don't collide with input keys
                      memoryKeys := c.memory.MemoryVariables(context.Background())
                      overlappingKeys := setutil.Intersection(memoryKeys, knownKeys)
                          接着校验链中输入的参数必须是前面链中的输出
                        for i, c := range c.chains {
                            // Check that chain has input keys that are in knownKeys
                            missingKeys := setutil.Difference(c.GetInputKeys(), knownKeys)
                            if len(missingKeys) > 0 {
                            
                        overlappingKeys := setutil.Intersection(c.GetOutputKeys(), knownKeys)
                            if len(overlappingKeys) > 0 {
                            最后校验了输出参数
                          // Check that outputKeys are in knownKeys
                          for _, key := range c.outputKeys {
                              if _, ok := knownKeys[key]; !ok {
                                准备工作完成后就到了Call函数调用的阶段
                        func (c *SequentialChain) Call(ctx context.Context, inputs map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint:lll
                            var outputs map[string]any
                            var err error
                            for _, chain := range c.chains {
                                outputs, err = Call(ctx, chain, inputs, options...)
                                if err != nil {
                                    return nil, err
                                }
                                // Set the input for the next chain to the output of the current chain
                                inputs = outputs
                            }
                            return outputs, nil
                        }
                                    可以看到,其实就是一个for循环调用每个Chain,将前一个输出作为后一个输入而已,里面调用的Call函数就是前面介绍的单次请求里的call函数。
                          func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
                              fullValues := make(map[string]any, 0)
                              for key, value := range inputValues {
                                  fullValues[key] = value
                              }
                              newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues)
                              if err != nil {
                                  return nil, err
                              }
                              for key, value := range newValues {
                                  fullValues[key] = value
                              }
                              callbacksHandler := getChainCallbackHandler(c)
                              if callbacksHandler != nil {
                                  callbacksHandler.HandleChainStart(ctx, inputValues)
                              }
                              outputValues, err := callChain(ctx, c, fullValues, options...)
                              if err != nil {
                                  if callbacksHandler != nil {
                                      callbacksHandler.HandleChainError(ctx, err)
                                  }
                                  return outputValues, err
                              }
                              if callbacksHandler != nil {
                                  callbacksHandler.HandleChainEnd(ctx, outputValues)
                              }
                              if err = c.GetMemory().SaveContext(ctx, inputValues, outputValues); err != nil {
                                  return outputValues, err
                              }
                              return outputValues, nil
                          }

                          文章转载自golang算法架构leetcode技术php,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

                          评论