暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

gomock 源码分析

在分析完gostub的源码实现后https://mp.weixin.qq.com/s/I6urCBHbcfZCNaWw1iZmnA,我们分析下gomock的源码实现。两者原理相似,但是也有不同。

1,gomock的用法

gomock 有两种模式,反射模式和源码模式,都是通过接口生成对应的mock代码,测试的时候直接调用mock对象,相对于gostub不会进行地址替换更安全,且是并发安全的。

反射模式

通过构建一个程序用反射理解接口生成一个mock类文件,它通过两个非标志参数生效:导入路径和用逗号分隔的符号列表(多个interface)。

mockgen -destination mock_sql_driver.go database/sql/driver Conn,Driver


源码模式

通过一个包含interface定义的文件生成mock类文件,它通过 -source 标识生效,-imports 和 -aux_files 标识在这种模式下也是有用的。


mockgen -source=exp1/foo.go -destination=exp1/mock/mock_foo.go


mock控制器

mock控制器通过NewController接口生成,是mock生态系统的顶层控制,它定义了mock对象的作用域和生命周期,以及它们的期望。多个协程同时调用控制器的方法是安全的。

当用例结束后,控制器会检查所有剩余期望的调用是否满足条件。


mock对象的行为注入

对于mock对象的行为注入,控制器是通过map来维护的,一个方法对应map的一项。因为一个方法在一个用例中可能调用多次,所以map的值类型是数组切片。当mock对象进行行为注入时,控制器会将行为Add。当该方法被调用时,控制器会将该行为Remove。



行为调用的保序

默认情况下,行为调用顺序可以和mock对象行为注入顺序不一致,即不保序。如果要保序,有两种方法:


通过After关键字来实现保序

通过InOrder关键字来实现保序



关键字InOrder是After的语法糖,源码如下:

// InOrder declares that the given calls should occur in order.

func InOrder(calls ...*Call) {

for i := 1; i < len(calls); i++ {

calls[i].After(calls[i-1])

}

}


具体使用一个例子

    package foo


    import (
    "testing"
    "time"


    "github.com/golang/mock/gomock"


    mock_foo "gomock_study/exp1/mock"
    )


    func TestSUT(t *testing.T) {
    ctl := gomock.NewController(t)
    defer ctl.Finish()


    m := mock_foo.NewMockFoo(ctl)
    bar := m.
    EXPECT().
    Bar(gomock.Eq(99)).
    DoAndReturn(func(_ int) int {
    time.Sleep(1 * time.Second)
    return 101
    }).
    AnyTimes()
    m.EXPECT().Bar1(gomock.Any()).After(bar)


    // Does not make any assertions. Returns 103 when Bar is invoked with 101.
    m.
    EXPECT().
    Bar(gomock.Eq(101)).
    Return(103).
    AnyTimes()


    SUT(m)
    type args struct {
    f Foo
    }
    tests := []struct {
    name string
    args args
    }{
    // TODO: Add test cases.
    {
    name: "case1",
    args: args{
    f: m,
    },
    },
    }
    for _, tt := range tests {
    t.Run(tt.name, func(t *testing.T) {
    SUT(tt.args.f)
    })
    }
    }


    2gomock源码分析

    比如上面的例子,接口定义如下

      package foo


      import "fmt"


      type Foo interface {
      Bar(x int) int
      Bar1(x int) int
      }


      func SUT(f Foo) {
      // ...
      if 99 == f.Bar(88) {
      fmt.Print("ok")
      }
      }



      生成的代码如下

        // Code generated by MockGen. DO NOT EDIT.
        // Source: exp1/foo.go


        // Package mock_foo is a generated GoMock package.
        package mock_foo


        import (
        reflect "reflect"


        gomock "github.com/golang/mock/gomock"
        )


        // MockFoo is a mock of Foo interface.
        type MockFoo struct {
        ctrl *gomock.Controller
        recorder *MockFooMockRecorder
        }


        // MockFooMockRecorder is the mock recorder for MockFoo.
        type MockFooMockRecorder struct {
        mock *MockFoo
        }


        // NewMockFoo creates a new mock instance.
        func NewMockFoo(ctrl *gomock.Controller) *MockFoo {
        mock := &MockFoo{ctrl: ctrl}
        mock.recorder = &MockFooMockRecorder{mock}
        return mock
        }


        // EXPECT returns an object that allows the caller to indicate expected use.
        func (m *MockFoo) EXPECT() *MockFooMockRecorder {
        return m.recorder
        }


        // Bar mocks base method.
        func (m *MockFoo) Bar(x int) int {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "Bar", x)
        ret0, _ := ret[0].(int)
        return ret0
        }


        // Bar indicates an expected call of Bar.
        func (mr *MockFooMockRecorder) Bar(x interface{}) *gomock.Call {
        mr.mock.ctrl.T.Helper()
        return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bar", reflect.TypeOf((*MockFoo)(nil).Bar), x)
        }


        // Bar1 mocks base method.
        func (m *MockFoo) Bar1(x int) int {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "Bar1", x)
        ret0, _ := ret[0].(int)
        return ret0
        }


        // Bar1 indicates an expected call of Bar1.
        func (mr *MockFooMockRecorder) Bar1(x interface{}) *gomock.Call {
        mr.mock.ctrl.T.Helper()
        return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bar1", reflect.TypeOf((*MockFoo)(nil).Bar1), x)
        }


        我们可以看到,对于每个接口都会生成一个对应的结构体,并且实现了接口的方法

        同时还定义了一个结构体的recoder,recorder实现了同名方法。为什么要定义一个recorder呢?


        我们可以看到生成的结构体有两个属性,一个就是全局数据管理的controller,另一个就是recorder

          type MockFoo struct {
          ctrl *gomock.Controller
          recorder *MockFooMockRecorder
          }

          recorder里面定义了这个结构体的指针

            type MockFooMockRecorder struct {
            mock *MockFoo
            }

            gomock的测试分两步

            1,写打桩方法:

                bar := m.
              EXPECT().
              Bar(gomock.Eq(99)).
              DoAndReturn(func(_ int) int {
              time.Sleep(1 * time.Second)
              return 101
              }).
              AnyTimes()

              2,进行函数调用:

                  SUT(tt.args.f)

                里面调用了接口的一个函数

                  if 99 == f.Bar(88


                  我们写打桩方法的时候为啥要加上EXPECT()呢?

                  我们先看下EXPECT()内部的实现

                    func (m *MockFoo) EXPECT() *MockFooMockRecorder {
                    return m.recorder
                    }

                    其实是返回了生成结构体的recorder对象。


                    那么就不难理解了,打桩的时候是调用的recorder的方法,函数调用的时候调用的是生成结构体的方法


                    下面我们看下函数内部的具体实现

                    1,打桩方法

                      // Bar indicates an expected call of Bar.
                      func (mr *MockFooMockRecorder) Bar(x interface{}) *gomock.Call {
                      mr.mock.ctrl.T.Helper()
                      return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bar", reflect.TypeOf((*MockFoo)(nil).Bar), x)
                      }

                      打桩的时候只是将对象和对应的方法注入存储起来。

                      2,结构体方法

                        // Bar mocks base method.
                        func (m *MockFoo) Bar(x int) int {
                        m.ctrl.T.Helper()
                        ret := m.ctrl.Call(m, "Bar", x)
                        ret0, _ := ret[0].(int)
                        return ret0
                        }

                        真正调用的时候是用来controller的call方法进行revoke


                        下面看下方法注入的具体实现

                          func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
                          ctrl.T.Helper()


                          call := newCall(ctrl.T, receiver, method, methodType, args...)


                          ctrl.mu.Lock()
                          defer ctrl.mu.Unlock()
                          ctrl.expectedCalls.Add(call)


                          return call
                          }


                          先构建了一个Call对象,然后将它存储在controller的expectedCalls里面

                          我们先看下Controller的定义

                            type Controller struct {
                            // T should only be called within a generated mock. It is not intended to
                            // be used in user code and may be changed in future versions. T is the
                            // TestReporter passed in when creating the Controller via NewController.
                            // If the TestReporter does not implement a TestHelper it will be wrapped
                            // with a nopTestHelper.
                            T TestHelper
                            mu sync.Mutex
                            expectedCalls *callSet
                            finished bool
                            }

                            Controller里最重要的属性就是expectedCalls,里面是两个map,map的key是对象和对象对应的方法,值就是Call的指针一个slice,第一个map是期望的调用,第二个map是超过期望调用次数的调用的一个存储

                              type callSet struct {
                              Calls that are still expected.
                              expected map[callSetKey][]*Call
                              Calls that have been exhausted.
                              exhausted map[callSetKey][]*Call
                              }

                              具体看下callSetKey的定义,里面存储了接口和函数名

                                type callSetKey struct {
                                receiver interface{}
                                fname string
                                }

                                然后我们看看call的定义,里面比较复杂,主要存储了函数调用相关的信息

                                  type Call struct {
                                  t TestHelper for triggering test failures on invalid call setup


                                  receiver interface{} the receiver of the method call
                                  method string the name of the method
                                  methodType reflect.Type the type of the method
                                  args []Matcher the args
                                  origin string file and line number of call setup


                                  preReqs []*Call // prerequisite calls


                                  // Expectations
                                  minCalls, maxCalls int


                                  numCalls int // actual number made


                                  // actions are called when this Call is called. Each action gets the args and
                                  // can set the return values by returning a non-nil slice. Actions run in the
                                  // order they are created.
                                  actions []func([]interface{}) []interface{}
                                  }

                                  我们回来看打桩的过程,在构建了Call对象以后呢,会将Call对象存入到expectedCalls,用到的add方法其实就是一个append操作

                                    // Add adds a new expected call.
                                    func (cs callSet) Add(call *Call) {
                                    key := callSetKey{call.receiver, call.method}
                                    m := cs.expected
                                    if call.exhausted() {
                                    m = cs.exhausted
                                    }
                                    m[key] = append(m[key], call)
                                    }

                                    在完成打桩依赖注入后,我们看看具体函数调用的过程,它调了Call方法

                                      func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
                                      ctrl.T.Helper()


                                      // Nest this code so we can use defer to make sure the lock is released.
                                      actions := func() []func([]interface{}) []interface{} {
                                      ctrl.T.Helper()
                                      ctrl.mu.Lock()
                                      defer ctrl.mu.Unlock()


                                      expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
                                      if err != nil {
                                      // callerInfo's skip should be updated if the number of calls between the user's test
                                      // and this line changes, i.e. this code is wrapped in another anonymous function.
                                      // 0 is us, 1 is controller.Call(), 2 is the generated mock, and 3 is the user's test.
                                      origin := callerInfo(3)
                                      ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
                                      }


                                      // Two things happen here:
                                      // * the matching call no longer needs to check prerequite calls,
                                      // * and the prerequite calls are no longer expected, so remove them.
                                      preReqCalls := expected.dropPrereqs()
                                      for _, preReqCall := range preReqCalls {
                                      ctrl.expectedCalls.Remove(preReqCall)
                                      }


                                      actions := expected.call()
                                      if expected.exhausted() {
                                      ctrl.expectedCalls.Remove(expected)
                                      }
                                      return actions
                                      }()


                                      var rets []interface{}
                                      for _, action := range actions {
                                      if r := action(args); r != nil {
                                      rets = r
                                      }
                                      }


                                      return rets
                                      }

                                      1,通过

                                         ctrl.expectedCalls.FindMatch(receiver, method, args)

                                        2,找到我们打桩的时候注入的方法,然后执行方法,获取返回值

                                          if r := action(args); r != nil 

                                          可以看下FindMatch的具体实现

                                            func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {
                                            key := callSetKey{receiver, method}


                                            // Search through the expected calls.
                                            expected := cs.expected[key]
                                            var callsErrors bytes.Buffer
                                            for _, call := range expected {
                                            err := call.matches(args)
                                            if err != nil {
                                            _, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
                                            } else {
                                            return call, nil
                                            }
                                            }


                                            // If we haven't found a match then search through the exhausted calls so we
                                            // get useful error messages.
                                            exhausted := cs.exhausted[key]
                                            for _, call := range exhausted {
                                            if err := call.matches(args); err != nil {
                                            _, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
                                            continue
                                            }
                                            _, _ = fmt.Fprintf(
                                            &callsErrors, "all expected calls for method %q have been exhausted", method,
                                            )
                                            }


                                            if len(expected)+len(exhausted) == 0 {
                                            _, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)
                                            }


                                            return nil, fmt.Errorf(callsErrors.String())
                                            }

                                            其实就是通过key和参数对比,找到注入的实现。

                                            文章转载自golang算法架构leetcode技术php,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

                                            评论