在分析完gostub的源码实现后https://mp.weixin.qq.com/s/I6urCBHbcfZCNaWw1iZmnA,我们分析下gomock的源码实现。两者原理相似,但是也有不同。
1,gomock的用法
gomock 有两种模式,反射模式和源码模式,都是通过接口生成对应的mock代码,测试的时候直接调用mock对象,相对于gostub不会进行地址替换更安全,且是并发安全的。
反射模式
通过构建一个程序用反射理解接口生成一个mock类文件,它通过两个非标志参数生效:导入路径和用逗号分隔的符号列表(多个interface)。
mockgen -destination mock_sql_driver.go database/sql/driver Conn,Driver
源码模式
通过一个包含interface定义的文件生成mock类文件,它通过 -source 标识生效,-imports 和 -aux_files 标识在这种模式下也是有用的。
mockgen -source=exp1/foo.go -destination=exp1/mock/mock_foo.go
mock控制器
mock控制器通过NewController接口生成,是mock生态系统的顶层控制,它定义了mock对象的作用域和生命周期,以及它们的期望。多个协程同时调用控制器的方法是安全的。
当用例结束后,控制器会检查所有剩余期望的调用是否满足条件。
mock对象的行为注入
对于mock对象的行为注入,控制器是通过map来维护的,一个方法对应map的一项。因为一个方法在一个用例中可能调用多次,所以map的值类型是数组切片。当mock对象进行行为注入时,控制器会将行为Add。当该方法被调用时,控制器会将该行为Remove。
行为调用的保序
默认情况下,行为调用顺序可以和mock对象行为注入顺序不一致,即不保序。如果要保序,有两种方法:
通过After关键字来实现保序
通过InOrder关键字来实现保序
关键字InOrder是After的语法糖,源码如下:
// InOrder declares that the given calls should occur in order.
func InOrder(calls ...*Call) {
for i := 1; i < len(calls); i++ {
calls[i].After(calls[i-1])
}
}
具体使用一个例子
package fooimport ("testing""time""github.com/golang/mock/gomock"mock_foo "gomock_study/exp1/mock")func TestSUT(t *testing.T) {ctl := gomock.NewController(t)defer ctl.Finish()m := mock_foo.NewMockFoo(ctl)bar := m.EXPECT().Bar(gomock.Eq(99)).DoAndReturn(func(_ int) int {time.Sleep(1 * time.Second)return 101}).AnyTimes()m.EXPECT().Bar1(gomock.Any()).After(bar)// Does not make any assertions. Returns 103 when Bar is invoked with 101.m.EXPECT().Bar(gomock.Eq(101)).Return(103).AnyTimes()SUT(m)type args struct {f Foo}tests := []struct {name stringargs args}{// TODO: Add test cases.{name: "case1",args: args{f: m,},},}for _, tt := range tests {t.Run(tt.name, func(t *testing.T) {SUT(tt.args.f)})}}
2,gomock源码分析
比如上面的例子,接口定义如下
package fooimport "fmt"type Foo interface {Bar(x int) intBar1(x int) int}func SUT(f Foo) {// ...if 99 == f.Bar(88) {fmt.Print("ok")}}
生成的代码如下
// Code generated by MockGen. DO NOT EDIT.// Source: exp1/foo.go// Package mock_foo is a generated GoMock package.package mock_fooimport (reflect "reflect"gomock "github.com/golang/mock/gomock")// MockFoo is a mock of Foo interface.type MockFoo struct {ctrl *gomock.Controllerrecorder *MockFooMockRecorder}// MockFooMockRecorder is the mock recorder for MockFoo.type MockFooMockRecorder struct {mock *MockFoo}// NewMockFoo creates a new mock instance.func NewMockFoo(ctrl *gomock.Controller) *MockFoo {mock := &MockFoo{ctrl: ctrl}mock.recorder = &MockFooMockRecorder{mock}return mock}// EXPECT returns an object that allows the caller to indicate expected use.func (m *MockFoo) EXPECT() *MockFooMockRecorder {return m.recorder}// Bar mocks base method.func (m *MockFoo) Bar(x int) int {m.ctrl.T.Helper()ret := m.ctrl.Call(m, "Bar", x)ret0, _ := ret[0].(int)return ret0}// Bar indicates an expected call of Bar.func (mr *MockFooMockRecorder) Bar(x interface{}) *gomock.Call {mr.mock.ctrl.T.Helper()return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bar", reflect.TypeOf((*MockFoo)(nil).Bar), x)}// Bar1 mocks base method.func (m *MockFoo) Bar1(x int) int {m.ctrl.T.Helper()ret := m.ctrl.Call(m, "Bar1", x)ret0, _ := ret[0].(int)return ret0}// Bar1 indicates an expected call of Bar1.func (mr *MockFooMockRecorder) Bar1(x interface{}) *gomock.Call {mr.mock.ctrl.T.Helper()return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bar1", reflect.TypeOf((*MockFoo)(nil).Bar1), x)}
我们可以看到,对于每个接口都会生成一个对应的结构体,并且实现了接口的方法。
同时还定义了一个结构体的recoder,recorder实现了同名方法。为什么要定义一个recorder呢?
我们可以看到生成的结构体有两个属性,一个就是全局数据管理的controller,另一个就是recorder
type MockFoo struct {ctrl *gomock.Controllerrecorder *MockFooMockRecorder}
recorder里面定义了这个结构体的指针
type MockFooMockRecorder struct {mock *MockFoo}
gomock的测试分两步
1,写打桩方法:
bar := m.EXPECT().Bar(gomock.Eq(99)).DoAndReturn(func(_ int) int {time.Sleep(1 * time.Second)return 101}).AnyTimes()
2,进行函数调用:
SUT(tt.args.f)
里面调用了接口的一个函数
if 99 == f.Bar(88)
我们写打桩方法的时候为啥要加上EXPECT()呢?
我们先看下EXPECT()内部的实现
func (m *MockFoo) EXPECT() *MockFooMockRecorder {return m.recorder}
其实是返回了生成结构体的recorder对象。
那么就不难理解了,打桩的时候是调用的recorder的方法,函数调用的时候调用的是生成结构体的方法。
下面我们看下函数内部的具体实现
1,打桩方法
// Bar indicates an expected call of Bar.func (mr *MockFooMockRecorder) Bar(x interface{}) *gomock.Call {mr.mock.ctrl.T.Helper()return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bar", reflect.TypeOf((*MockFoo)(nil).Bar), x)}
打桩的时候只是将对象和对应的方法注入存储起来。
2,结构体方法
// Bar mocks base method.func (m *MockFoo) Bar(x int) int {m.ctrl.T.Helper()ret := m.ctrl.Call(m, "Bar", x)ret0, _ := ret[0].(int)return ret0}
真正调用的时候是用来controller的call方法进行revoke
下面看下方法注入的具体实现
func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {ctrl.T.Helper()call := newCall(ctrl.T, receiver, method, methodType, args...)ctrl.mu.Lock()defer ctrl.mu.Unlock()ctrl.expectedCalls.Add(call)return call}
先构建了一个Call对象,然后将它存储在controller的expectedCalls里面
我们先看下Controller的定义
type Controller struct {// T should only be called within a generated mock. It is not intended to// be used in user code and may be changed in future versions. T is the// TestReporter passed in when creating the Controller via NewController.// If the TestReporter does not implement a TestHelper it will be wrapped// with a nopTestHelper.T TestHelpermu sync.MutexexpectedCalls *callSetfinished bool}
Controller里最重要的属性就是expectedCalls,里面是两个map,map的key是对象和对象对应的方法,值就是Call的指针一个slice,第一个map是期望的调用,第二个map是超过期望调用次数的调用的一个存储
type callSet struct {Calls that are still expected.expected map[callSetKey][]*CallCalls that have been exhausted.exhausted map[callSetKey][]*Call}
具体看下callSetKey的定义,里面存储了接口和函数名
type callSetKey struct {receiver interface{}fname string}
然后我们看看call的定义,里面比较复杂,主要存储了函数调用相关的信息
type Call struct {t TestHelper for triggering test failures on invalid call setupreceiver interface{} the receiver of the method callmethod string the name of the methodmethodType reflect.Type the type of the methodargs []Matcher the argsorigin string file and line number of call setuppreReqs []*Call // prerequisite calls// ExpectationsminCalls, maxCalls intnumCalls int // actual number made// actions are called when this Call is called. Each action gets the args and// can set the return values by returning a non-nil slice. Actions run in the// order they are created.actions []func([]interface{}) []interface{}}
我们回来看打桩的过程,在构建了Call对象以后呢,会将Call对象存入到expectedCalls,用到的add方法其实就是一个append操作
// Add adds a new expected call.func (cs callSet) Add(call *Call) {key := callSetKey{call.receiver, call.method}m := cs.expectedif call.exhausted() {m = cs.exhausted}m[key] = append(m[key], call)}
在完成打桩依赖注入后,我们看看具体函数调用的过程,它调了Call方法
func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {ctrl.T.Helper()// Nest this code so we can use defer to make sure the lock is released.actions := func() []func([]interface{}) []interface{} {ctrl.T.Helper()ctrl.mu.Lock()defer ctrl.mu.Unlock()expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)if err != nil {// callerInfo's skip should be updated if the number of calls between the user's test// and this line changes, i.e. this code is wrapped in another anonymous function.// 0 is us, 1 is controller.Call(), 2 is the generated mock, and 3 is the user's test.origin := callerInfo(3)ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)}// Two things happen here:// * the matching call no longer needs to check prerequite calls,// * and the prerequite calls are no longer expected, so remove them.preReqCalls := expected.dropPrereqs()for _, preReqCall := range preReqCalls {ctrl.expectedCalls.Remove(preReqCall)}actions := expected.call()if expected.exhausted() {ctrl.expectedCalls.Remove(expected)}return actions}()var rets []interface{}for _, action := range actions {if r := action(args); r != nil {rets = r}}return rets}
1,通过
ctrl.expectedCalls.FindMatch(receiver, method, args)
2,找到我们打桩的时候注入的方法,然后执行方法,获取返回值
if r := action(args); r != nil
可以看下FindMatch的具体实现
func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {key := callSetKey{receiver, method}// Search through the expected calls.expected := cs.expected[key]var callsErrors bytes.Bufferfor _, call := range expected {err := call.matches(args)if err != nil {_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)} else {return call, nil}}// If we haven't found a match then search through the exhausted calls so we// get useful error messages.exhausted := cs.exhausted[key]for _, call := range exhausted {if err := call.matches(args); err != nil {_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)continue}_, _ = fmt.Fprintf(&callsErrors, "all expected calls for method %q have been exhausted", method,)}if len(expected)+len(exhausted) == 0 {_, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)}return nil, fmt.Errorf(callsErrors.String())}
其实就是通过key和参数对比,找到注入的实现。






