暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

MCP源码分析:stdio

        我们实现一个基于stdio的MCP server,分为下面四步
1,创建一个MCP server
        s := server.NewMCPServer(
            "ip-mcp",
            "1.0.0",
        )
    2,创建一个工具描述信息
          tool := mcp.NewTool("ip_query",
              mcp.WithDescription("query geo location of an IP address"),
              mcp.WithString("ip",
                  mcp.Required(),
                  mcp.Description("IP address to query"),
              ),
          )
      3,将工具的实现和描述信息绑定到server上
        s.AddTool(tool, ipQueryHandler)
        4,启动server来接收请求
              if err := server.ServeStdio(s); err != nil {
                  我们先看看创建server的代码,它的实现位于:github.com/mark3labs/mcp-go@v0.27.0/server/server.go ,除了name和version两个字段外,其他字段都是可选字段:
            func NewMCPServer(
                name, version string,
                opts ...ServerOption,
            ) *MCPServer {
                s := &MCPServer{
                    resources:            make(map[string]resourceEntry),
                    resourceTemplates:    make(map[string]resourceTemplateEntry),
                    prompts:              make(map[string]mcp.Prompt),
                    promptHandlers:       make(map[string]PromptHandlerFunc),
                    tools:                make(map[string]ServerTool),
                    name:                 name,
                    version:              version,
                    notificationHandlers: make(map[string]NotificationHandlerFunc),
                    capabilities: serverCapabilities{
                        tools:     nil,
                        resources: nil,
                        prompts:   nil,
                        logging:   false,
                    },
                }
                for _, opt := range opts {
                    opt(s)
                }
                return s
            }
            s e rver上用map存储了resources,resourceTemplates,promptspromptHandlers和tools。具体结构体定义如下:
              type MCPServer struct {
                  // Separate mutexes for different resource types
                  resourcesMu            sync.RWMutex
                  promptsMu              sync.RWMutex
                  toolsMu                sync.RWMutex
                  middlewareMu           sync.RWMutex
                  notificationHandlersMu sync.RWMutex
                  capabilitiesMu         sync.RWMutex
                  toolFiltersMu          sync.RWMutex
                  name                   string
                  version                string
                  instructions           string
                  resources              map[string]resourceEntry
                  resourceTemplates      map[string]resourceTemplateEntry
                  prompts                map[string]mcp.Prompt
                  promptHandlers         map[string]PromptHandlerFunc
                  tools                  map[string]ServerTool
                  toolHandlerMiddlewares []ToolHandlerMiddleware
                  toolFilters            []ToolFilterFunc
                  notificationHandlers   map[string]NotificationHandlerFunc
                  capabilities           serverCapabilities
                  paginationLimit        *int
                  sessions               sync.Map
                  hooks                  *Hooks
              }
              为了方便并发读写,这些map都定义了锁,保证安全。
                      tool的描述信息定义也是类似的,其中name是必传参数
                func NewTool(name string, opts ...ToolOption) Tool {
                    tool := Tool{
                        Name: name,
                        InputSchema: ToolInputSchema{
                            Type:       "object",
                            Properties: make(map[string]any),
                            Required:   nil// Will be omitted from JSON if empty
                        },
                        Annotations: ToolAnnotation{
                            Title:           "",
                            ReadOnlyHint:    ToBoolPtr(false),
                            DestructiveHint: ToBoolPtr(true),
                            IdempotentHint:  ToBoolPtr(false),
                            OpenWorldHint:   ToBoolPtr(true),
                        },
                    }
                其结构体定义中除了名字和描述外还有InputSchema和Annotations
                  type Tool struct {
                      // The name of the tool.
                      Name string `json:"name"`
                      // A human-readable description of the tool.
                      Description string `json:"description,omitempty"`
                      // A JSON Schema object defining the expected parameters for the tool.
                      InputSchema ToolInputSchema `json:"inputSchema"`
                      // Alternative to InputSchema - allows arbitrary JSON Schema to be provided
                      RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling
                      // Optional properties describing tool behavior
                      Annotations ToolAnnotation `json:"annotations"`
                  }
                    type ToolInputSchema struct {
                        Type       string         `json:"type"`
                        Properties map[string]any `json:"properties,omitempty"`
                        Required   []string       `json:"required,omitempty"`
                    }
                    定义了参数的类型,每一个参数的属性值和必传参数列表,整体为参数的json-schema服务的。
                      type ToolAnnotation struct {
                          // Human-readable title for the tool
                          Title string `json:"title,omitempty"`
                          // If true, the tool does not modify its environment
                          ReadOnlyHint *bool `json:"readOnlyHint,omitempty"`
                          // If true, the tool may perform destructive updates
                          DestructiveHint *bool `json:"destructiveHint,omitempty"`
                          // If true, repeated calls with same args have no additional effect
                          IdempotentHint *bool `json:"idempotentHint,omitempty"`
                          // If true, tool interacts with external entities
                          OpenWorldHint *bool `json:"openWorldHint,omitempty"`
                      }
                      定义了一些约束标记,是bool类型的,我们可以通过option来控制。
                              工具绑定的逻辑如下:
                        func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) {
                            s.AddTools(ServerTool{Tool: tool, Handler: handler})
                        }
                        第二个参数类型是ToolHandlerFunc,也就是要求我们实现的工具具体处理函数的输入和输出:
                          type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
                          其中输入定义如下,我们从Arguments这个map中取参数即可
                            type CallToolRequest struct {
                                Request
                                Params struct {
                                    Name      string         `json:"name"`
                                    Arguments map[string]any `json:"arguments,omitempty"`
                                    Meta      *struct {
                                        // If specified, the caller is requesting out-of-band progress
                                        // notifications for this request (as represented by
                                        // notifications/progress). The value of this parameter is an
                                        // opaque token that will be attached to any subsequent
                                        // notifications. The receiver is not obligated to provide these
                                        // notifications.
                                        ProgressToken ProgressToken `json:"progressToken,omitempty"`
                                    } `json:"_meta,omitempty"`
                                } `json:"params"`
                            }
                              type Request struct {
                                  Method string `json:"method"`
                                  Params struct {
                                      Meta *struct {
                                          // If specified, the caller is requesting out-of-band progress
                                          // notifications for this request (as represented by
                                          // notifications/progress). The value of this parameter is an
                                          // opaque token that will be attached to any subsequent
                                          // notifications. The receiver is not obligated to provide these
                                          // notifications.
                                          ProgressToken ProgressToken `json:"progressToken,omitempty"`
                                      } `json:"_meta,omitempty"`
                                  } `json:"params,omitempty"`
                              }
                              输出定义如下,我能将结果Content,append到返回值中就行:
                                type CallToolResult struct {
                                    Result
                                    Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource
                                    // Whether the tool call ended in an error.
                                    //
                                    // If not set, this is assumed to be false (the call was successful).
                                    IsError bool `json:"isError,omitempty"`
                                }
                                  type Result struct {
                                      // This result property is reserved by the protocol to allow clients and
                                      // servers to attach additional metadata to their responses.
                                      Meta map[string]any `json:"_meta,omitempty"`
                                  }
                                          绑定工具,其核心就是把tool存储到map里面去,类似http服务器的路由注册,如果注册了监听事件,会触发事件通知。
                                    func (s *MCPServer) AddTools(tools ...ServerTool) {
                                        s.toolsMu.Lock()
                                        for _, entry := range tools {
                                            s.tools[entry.Tool.Name] = entry
                                        }
                                        s.toolsMu.Unlock()
                                    }
                                            最后是启动server函数,首先会创建一个stdio的server对象,最后监听输入和输出:
                                      func ServeStdio(server *MCPServer, opts ...StdioOption) error {
                                          s := NewStdioServer(server)
                                          return s.Listen(ctx, os.Stdin, os.Stdout)
                                              stdioServer是普通server的一个装饰,它的listen函数里实现了对标准输入的监听,并将结果写入到标准输出
                                        func NewStdioServer(server *MCPServer*StdioServer {
                                            return &StdioServer{
                                                server: server,
                                                errLogger: log.New(
                                                    os.Stderr,
                                                    "",
                                                    log.LstdFlags,
                                                ), // Default to discarding logs
                                            }
                                        }
                                          func (s *StdioServer) Listen(
                                              ctx context.Context,
                                              stdin io.Reader,
                                              stdout io.Writer,
                                          error {
                                              // Set a static client context since stdio only has one client
                                              if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
                                                  return fmt.Errorf("register session: %w", err)
                                              }
                                              defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
                                              reader := bufio.NewReader(stdin)
                                              go s.handleNotifications(ctx, stdout)
                                              return s.processInputStream(ctx, reader, stdout)
                                            var stdioSessionInstance = stdioSession{
                                                notifications: make(chan mcp.JSONRPCNotification, 100),
                                            }
                                              func (s *stdioSession) SessionID() string {
                                                  return "stdio"
                                              }
                                              注册一个ID为stdio的session,并将标准输入封装成reader,启动一个协程,将从session中读取返回值,写入到标准输出
                                                func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
                                                    for {
                                                        select {
                                                        case notification := <-stdioSessionInstance.notifications:
                                                            if err := s.writeResponse(notification, stdout); err != nil {
                                                输出的格式是json格式
                                                  func (s *StdioServer) writeResponse(
                                                      response mcp.JSONRPCMessage,
                                                      writer io.Writer,
                                                  error {
                                                      responseBytes, err := json.Marshal(response)
                                                      if err != nil {
                                                          return err
                                                      }
                                                      // Write response followed by newline
                                                      if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
                                                  最后启动一个协程处理输入,它是一个for循环,不断从标准输入中,一行行读取数据,然后以json格式解析,交给HandleMessage处理后,最后写入到response的channel
                                                    func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
                                                        for {
                                                            if err := ctx.Err(); err != nil {
                                                                return err
                                                            }
                                                            line, err := s.readNextLine(ctx, reader)
                                                            if err != nil {
                                                                if err == io.EOF {
                                                                    return nil
                                                                }
                                                                s.errLogger.Printf("Error reading input: %v", err)
                                                                return err
                                                            }
                                                            if err := s.processMessage(ctx, line, stdout); err != nil {
                                                      func (s *StdioServer) processMessage(
                                                          ctx context.Context,
                                                          line string,
                                                          writer io.Writer,
                                                      error {
                                                          // Parse the message as raw JSON
                                                          var rawMessage json.RawMessage
                                                          if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
                                                              response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
                                                              return s.writeResponse(response, writer)
                                                          }
                                                          // Handle the message using the wrapped server
                                                          response := s.server.HandleMessage(ctx, rawMessage)
                                                          // Only write response if there is one (not for notifications)
                                                          if response != nil {
                                                              if err := s.writeResponse(response, writer); err != nil {
                                                      最后我们重点看下HandleMessage的实现:
                                                        func (s *MCPServer) HandleMessage(
                                                            ctx context.Context,
                                                            message json.RawMessage,
                                                        ) mcp.JSONRPCMessage {
                                                            // Add server to context
                                                            ctx = context.WithValue(ctx, serverKey{}, s)
                                                            var err *requestError
                                                            var baseMessage struct {
                                                                JSONRPC string        `json:"jsonrpc"`
                                                                Method  mcp.MCPMethod `json:"method"`
                                                                ID      any           `json:"id,omitempty"`
                                                                Result  any           `json:"result,omitempty"`
                                                            }
                                                            if err := json.Unmarshal(message, &baseMessage); err != nil {
                                                                return createErrorResponse(
                                                                    nil,
                                                                    mcp.PARSE_ERROR,
                                                                    "Failed to parse message",
                                                                )
                                                            }
                                                            handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message)
                                                            switch baseMessage.Method {
                                                            case mcp.MethodInitialize:
                                                                    s.hooks.beforeInitialize(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handleInitialize(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result)
                                                                    return createResponse(baseMessage.ID, *result)
                                                             case mcp.MethodPing:
                                                                    s.hooks.beforePing(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handlePing(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterPing(ctx, baseMessage.ID, &request, result)
                                                                    return createResponse(baseMessage.ID, *result)
                                                             case mcp.MethodResourcesList:
                                                                    s.hooks.beforeListResources(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handleListResources(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterListResources(ctx, baseMessage.ID, &request, result)
                                                                    return createResponse(baseMessage.ID, *result)
                                                             case mcp.MethodResourcesTemplatesList:
                                                                    s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result)
                                                                    return createResponse(baseMessage.ID, *result)
                                                             case mcp.MethodResourcesRead:
                                                                   s.hooks.beforeReadResource(ctx, baseMessage.ID, &request)
                                                                   result, err = s.handleReadResource(ctx, baseMessage.ID, request)
                                                                   s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result)
                                                                   return createResponse(baseMessage.ID, *result)
                                                             case mcp.MethodPromptsList:
                                                                     s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handleListPrompts(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result)
                                                                   return createResponse(baseMessage.ID, *result)
                                                             case mcp.MethodPromptsGet:
                                                                    s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handleGetPrompt(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result)
                                                                    return createResponse(baseMessage.ID, *result)
                                                              case mcp.MethodToolsList:
                                                                    s.hooks.beforeListTools(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handleListTools(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterListTools(ctx, baseMessage.ID, &request, result)
                                                                    return createResponse(baseMessage.ID, *result)
                                                               case mcp.MethodToolsCall:
                                                                    s.hooks.beforeCallTool(ctx, baseMessage.ID, &request)
                                                                    result, err = s.handleToolCall(ctx, baseMessage.ID, request)
                                                                    s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result)
                                                                    return createResponse(baseMessage.ID, *result)
                                                                拿到请求后,先按照json-rpc格式解析请求,得到请求路径,然后分别对
                                                        initialize、ping、resources/list、resources/templates/list、resources/read、prompts/list,prompts/get、tools/list、tools/call
                                                        这几种路径进行处理。
                                                                tools/list的处理如下
                                                          func (s *MCPServer) handleListTools(
                                                              ctx context.Context,
                                                              id any,
                                                              request mcp.ListToolsRequest,
                                                          ) (*mcp.ListToolsResult, *requestError) {
                                                             for name := range s.tools {
                                                                  toolNames = append(toolNames, name)
                                                              }
                                                              // Sort the tool names for consistent ordering
                                                              sort.Strings(toolNames)
                                                              // Add tools in sorted order
                                                              for _, name := range toolNames {
                                                                  tools = append(tools, s.tools[name].Tool)
                                                              }
                                                              toolMap := make(map[string]mcp.Tool)
                                                               // Add global tools first
                                                                 for _, tool := range tools {
                                                                 toolMap[tool.Name] = tool
                                                                 }
                                                                // Then override with session-specific tools
                                                                for name, serverTool := range sessionTools {
                                                                 toolMap[name] = serverTool.Tool
                                                                   }
                                                                   // Convert back to slice
                                                                  tools = make([]mcp.Tool, 0len(toolMap))
                                                                 for _, tool := range toolMap {
                                                                  tools = append(tools, tool)
                                                                }
                                                                  tools/call的处理如下,先通过请求里的toolName到map里取出tool,如果session里没有取到,就到全局去取。获取tool的Handler函数,加上middleware后,调用函数处理获取结果。
                                                            func (s *MCPServer) handleToolCall(
                                                                ctx context.Context,
                                                                id any,
                                                                request mcp.CallToolRequest,
                                                            ) (*mcp.CallToolResult, *requestError) {
                                                                session := ClientSessionFromContext(ctx)
                                                                if session != nil {
                                                                    if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk {
                                                                        if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil {
                                                                            var sessionOk bool
                                                                            tool, sessionOk = sessionTools[request.Params.Name]
                                                                if !ok {
                                                                    s.toolsMu.RLock()
                                                                    tool, ok = s.tools[request.Params.Name]
                                                                    s.toolsMu.RUnlock()
                                                                }
                                                                


                                                                finalHandler := tool.Handler
                                                                s.middlewareMu.RLock()
                                                                mw := s.toolHandlerMiddlewares
                                                                s.middlewareMu.RUnlock()
                                                                // Apply middlewares in reverse order
                                                                for i := len(mw) - 1; i >= 0; i-- {
                                                                    finalHandler = mw[i](finalHandler)
                                                                }
                                                                result, err := finalHandler(ctx, request)
                                                            至此,完整的stdio格式的mcp请求处理流程源码分析完毕。

                                                            文章转载自golang算法架构leetcode技术php,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

                                                            评论