Skip to content

Commit a4116cb

Browse files
avoid data races in Arguments.Diff
Fixes a concurrency issue that would lead to testify mocks producing data races detected by go test -race. These data races would occur whenever a mock pointer argument was concurrently modified. The reason being that Arguments.Diff uses the %v format specifier to get a presentable string for the argument. This also traverses the pointed-to data structure, which would lead to the data race. Signed-off-by: Peter Gardfjäll <[email protected]>
1 parent 8d4dcbb commit a4116cb

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

mock/mock.go

+23-2
Original file line numberDiff line numberDiff line change
@@ -947,15 +947,31 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
947947
actualFmt = "(Missing)"
948948
} else {
949949
actual = objects[i]
950-
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
950+
// Note: avoid %v format specifier for pointer arguments. The %v format
951+
// specifier traverses the data structure, and for situations where the
952+
// argument is a pointer (that may be updated concurrently) this can result
953+
// in the mock code causing a data race when running go test -race.
954+
if isPtr(actual) {
955+
actualFmt = fmt.Sprintf("(%[1]T=%[1]p)", &actual)
956+
} else {
957+
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
958+
}
951959
}
952960

953961
if len(args) <= i {
954962
expected = "(Missing)"
955963
expectedFmt = "(Missing)"
956964
} else {
957965
expected = args[i]
958-
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
966+
// Note: avoid %v format specifier for pointer arguments. The %v format
967+
// specifier traverses the data structure, and for situations where the
968+
// argument is a pointer (that may be updated concurrently) this can result
969+
// in the mock code causing a data race when running go test -race.
970+
if isPtr(expected) {
971+
expectedFmt = fmt.Sprintf("(%[1]T=%[1]p)", expected)
972+
} else {
973+
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
974+
}
959975
}
960976

961977
if matcher, ok := expected.(argumentMatcher); ok {
@@ -1250,3 +1266,8 @@ func funcName(opt interface{}) string {
12501266
n := runtime.FuncForPC(reflect.ValueOf(opt).Pointer()).Name()
12511267
return strings.TrimSuffix(path.Base(n), path.Ext(n))
12521268
}
1269+
1270+
// isPtr indicates if the supplied value is a pointer.
1271+
func isPtr(v interface{}) bool {
1272+
return reflect.ValueOf(v).Kind() == reflect.Ptr
1273+
}

mock/mock_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -1911,6 +1911,42 @@ func Test_MockReturnAndCalledConcurrent(t *testing.T) {
19111911
wg.Wait()
19121912
}
19131913

1914+
type argType struct{ Question string }
1915+
1916+
type pointerArgMock struct{ Mock }
1917+
1918+
func (m *pointerArgMock) Question(arg *argType) int {
1919+
args := m.Called(arg)
1920+
return args.Int(0)
1921+
}
1922+
1923+
// Exercises calling a mock with a pointer value that gets modified concurrently. Prior to fix
1924+
// TODO:pr this would fail when running go test with the -race flag, due to Arguments.Diff printing
1925+
// the format with specifier %v which traverses the pointed to data structure (that is being
1926+
// concurrently modified by another goroutine).
1927+
func Test_CallMockWithConcurrentlyModifiedPointerArg(t *testing.T) {
1928+
m := &pointerArgMock{}
1929+
m.On("Question", Anything).Return(42)
1930+
1931+
ptrArg := &argType{Question: "What's the meaning of life?"}
1932+
1933+
// Emulates a situation where the pointer value gets concurrently updated by another thread.
1934+
wg := sync.WaitGroup{}
1935+
wg.Add(1)
1936+
go func() {
1937+
defer wg.Done()
1938+
ptrArg.Question = "What is 7 * 6?"
1939+
}()
1940+
1941+
// This is where we would get a data race since Arguments.Diff would traverse the pointed to
1942+
// struct while being updated. Something go test -race would identify as a data race.
1943+
value := m.Question(ptrArg)
1944+
assert.Equal(t, 42, value)
1945+
wg.Wait()
1946+
1947+
m.AssertExpectations(t)
1948+
}
1949+
19141950
type timer struct{ Mock }
19151951

19161952
func (s *timer) GetTime(i int) string {

0 commit comments

Comments
 (0)