diff --git a/cmd/mockery.go b/cmd/mockery.go index 96dc6499..e14302c8 100644 --- a/cmd/mockery.go +++ b/cmd/mockery.go @@ -46,20 +46,10 @@ func NewRootCmd() (*cobra.Command, error) { pFlags := cmd.PersistentFlags() pFlags.StringVar(&cfgFile, "config", "", "config file to use") - pFlags.String("dir", "", "directory to search for interfaces") - pFlags.BoolP("recursive", "r", false, "recurse search into sub-directories") - pFlags.StringArray("exclude", nil, "prefixes of subdirectories and files to exclude from search") - pFlags.Bool("all", false, "generates mocks for all found interfaces in all sub-directories") - pFlags.String("note", "", "comment to insert into prologue of each generated file") - pFlags.String("cpuprofile", "", "write cpu profile to file") pFlags.Bool("version", false, "prints the installed version of mockery") pFlags.String("tags", "", "space-separated list of additional build tags to load packages") pFlags.String("mock-build-tags", "", "set the build tags of the generated mocks. Read more about the format: https://pkg.go.dev/cmd/go#hdr-Build_constraints") - pFlags.String("filename", "", "name of generated file (only works with -name and no regex)") - pFlags.String("structname", "", "name of generated struct (only works with -name and no regex)") pFlags.String("log-level", "info", "Level of logging") - pFlags.String("srcpkg", "", "source pkg to search for interfaces") - pFlags.BoolP("dry-run", "d", false, "Do a dry run, don't modify any files") pFlags.String("boilerplate-file", "", "File to read a boilerplate text from. Text should be a go block comment, i.e. /* ... */") pFlags.Bool("unroll-variadic", true, "For functions with variadic arguments, do not unroll the arguments into the underlying testify call. Instead, pass variadic slice as-is.") diff --git a/mockery-tools.env b/mockery-tools.env index 2f9b2300..91f1d5b1 100644 --- a/mockery-tools.env +++ b/mockery-tools.env @@ -1 +1 @@ -VERSION=v3.0.0-alpha.7 +VERSION=v3.0.0-alpha.8 diff --git a/pkg/config.go b/pkg/config.go index 26adb071..441a8648 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -15,13 +15,13 @@ import ( "strings" "github.com/chigopher/pathlib" - "github.com/huandu/xstrings" "github.com/jinzhu/copier" "github.com/mitchellh/mapstructure" "github.com/rs/zerolog" "github.com/spf13/viper" "github.com/vektra/mockery/v3/pkg/logging" "github.com/vektra/mockery/v3/pkg/stackerr" + mockeryTemplate "github.com/vektra/mockery/v3/pkg/template" "golang.org/x/tools/go/packages" "gopkg.in/yaml.v3" ) @@ -66,7 +66,6 @@ func NewConfigFromViper(v *viper.Viper) (*Config, error) { v.SetDefault("formatter", "goimports") v.SetDefault("mockname", "Mock{{.InterfaceName}}") v.SetDefault("pkgname", "{{.SrcPackageName}}") - v.SetDefault("dry-run", false) v.SetDefault("log-level", "info") if err := v.UnmarshalExact(c); err != nil { @@ -757,50 +756,6 @@ func (c *Config) TagName(name string) string { var ErrInfiniteLoop = fmt.Errorf("infinite loop in template variables detected") -// Functions available in the template for manipulating -// -// Since the map and its functions are stateless, it exists as -// a package var rather than being initialized on every call -// in [parseConfigTemplates] and [generator.printTemplate] -var templateFuncMap = template.FuncMap{ - // String inspection and manipulation - "contains": strings.Contains, - "hasPrefix": strings.HasPrefix, - "hasSuffix": strings.HasSuffix, - "join": strings.Join, - "replace": strings.Replace, - "replaceAll": strings.ReplaceAll, - "split": strings.Split, - "splitAfter": strings.SplitAfter, - "splitAfterN": strings.SplitAfterN, - "trim": strings.Trim, - "trimLeft": strings.TrimLeft, - "trimPrefix": strings.TrimPrefix, - "trimRight": strings.TrimRight, - "trimSpace": strings.TrimSpace, - "trimSuffix": strings.TrimSuffix, - "lower": strings.ToLower, - "upper": strings.ToUpper, - "camelcase": xstrings.ToCamelCase, - "snakecase": xstrings.ToSnakeCase, - "kebabcase": xstrings.ToKebabCase, - "firstLower": xstrings.FirstRuneToLower, - "firstUpper": xstrings.FirstRuneToUpper, - - // Regular expression matching - "matchString": regexp.MatchString, - "quoteMeta": regexp.QuoteMeta, - - // Filepath manipulation - "base": filepath.Base, - "clean": filepath.Clean, - "dir": filepath.Dir, - - // Basic access to reading environment variables - "expandEnv": os.ExpandEnv, - "getenv": os.Getenv, -} - // ParseTemplates parses various templated strings // in the config struct into their fully defined values. This mutates // the config object passed. An *Interface object can be supplied to satisfy @@ -891,7 +846,7 @@ func (c *Config) ParseTemplates(ctx context.Context, iface *Interface, srcPkg *p for name, attributePointer := range templateMap { oldVal := *attributePointer - attributeTempl, err := template.New("interface-template").Funcs(templateFuncMap).Parse(*attributePointer) + attributeTempl, err := template.New("config-template").Funcs(mockeryTemplate.StringManipulationFuncs).Parse(*attributePointer) if err != nil { return fmt.Errorf("failed to parse %s template: %w", name, err) } diff --git a/pkg/mockery.templ b/pkg/mockery.templ index 5bd8e137..d8d2a98c 100644 --- a/pkg/mockery.templ +++ b/pkg/mockery.templ @@ -10,7 +10,7 @@ package {{.PkgName}} import ( {{- range .Imports}} - {{. | ImportStatement}} + {{. | importStatement}} {{- end}} mock "github.com/stretchr/testify/mock" ) @@ -71,19 +71,19 @@ func (_mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{$method.Name}}({ {{- $calledString = "" }} {{- end }} - {{- $lastParam := index $method.Params (len $method.Params | Add -1 )}} + {{- $lastParam := index $method.Params (len $method.Params | add -1 )}} if len({{ $lastParam.Var.Name }}) > 0 { {{- if ne (len $method.Returns) 0}}tmpRet = {{ end }}_mock.Called({{- if (index $mock.TemplateData "unroll-variadic") }}{{ $method.ArgCallList }}{{- else }}{{ $method.ArgCallListNoEllipsis }}{{- end }}) } else { - {{- if ne (len $method.Returns) 0}}tmpRet = {{ end }}_mock.Called({{- if (index $mock.TemplateData "unroll-variadic") }}{{ $method.ArgCallListSlice 0 (len $method.Params | Add -1 )}}{{- else }}{{ $method.ArgCallListSliceNoEllipsis 0 (len $method.Params | Add -1 )}}{{- end }}) + {{- if ne (len $method.Returns) 0}}tmpRet = {{ end }}_mock.Called({{- if (index $mock.TemplateData "unroll-variadic") }}{{ $method.ArgCallListSlice 0 (len $method.Params | add -1 )}}{{- else }}{{ $method.ArgCallListSliceNoEllipsis 0 (len $method.Params | add -1 )}}{{- end }}) } {{- else }} {{- $calledString = printf "_mock.Called(%s)" $method.ArgCallList }} {{- end }} {{- else }} - {{- $lastParam := (index $method.Params (len $method.Params | Add -1)) }} + {{- $lastParam := (index $method.Params (len $method.Params | add -1)) }} {{- $variadicArgsName := $lastParam.Var.Name }} - {{- $strippedTypeString := TrimPrefix "..." $lastParam.TypeStringEllipsis }} + {{- $strippedTypeString := trimPrefix "..." $lastParam.TypeStringEllipsis }} {{- if and (ne $strippedTypeString "interface{}") (ne $strippedTypeString "any") }} // {{ $strippedTypeString }} @@ -95,7 +95,7 @@ func (_mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{$method.Name}}({ {{- end }} var _ca []interface{} {{- if gt (len $method.Params) 1 }} - _ca = append(_ca, {{ $method.ArgCallListSlice 0 (len $method.Params | Add -1) }}) + _ca = append(_ca, {{ $method.ArgCallListSlice 0 (len $method.Params | add -1) }}) {{- end }} _ca = append(_ca, {{ $variadicArgsName }}...) {{- $calledString = "_mock.Called(_ca...)" }} @@ -136,7 +136,7 @@ func (_mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{$method.Name}}({ } {{- end }} {{/* END RETURN RANGE */}} {{- end }} - return {{ range $retIdx, $ret := $method.Returns }}r{{ $retIdx }}{{ if ne $retIdx (len $method.Returns | Add -1) }}, {{ end }}{{ end }} + return {{ range $retIdx, $ret := $method.Returns }}r{{ $retIdx }}{{ if ne $retIdx (len $method.Returns | add -1) }}, {{ end }}{{ end }} } {{/* CREATE EXPECTER METHOD */}} @@ -161,7 +161,7 @@ func (_e *{{ $expecterNameInstantiated }}) {{ $method.Name }}({{ range $method.P {{- else }} append([]interface{}{ {{- range $i, $param := $method.Params }} - {{- if (lt $i (len $method.Params | Add -1 ))}} {{ $param.Var.Name }}, + {{- if (lt $i (len $method.Params | add -1 ))}} {{ $param.Var.Name }}, {{- else }} }, {{ $param.Var.Name }}... {{- end }} {{- end}} )... @@ -173,8 +173,8 @@ func (_c *{{ $ExpecterCallNameInstantiated }}) Run(run func({{ $method.ArgList } {{- if not $method.IsVariadic }} run({{range $i, $param := $method.Params }}args[{{$i}}].({{ $param.TypeString}}),{{end}}) {{- else}} - {{- $variadicParam := index $method.Params (len $method.Params | Add -1) }} - {{- $nonVariadicParams := slice $method.Params 0 (len $method.Params | Add -1 )}} + {{- $variadicParam := index $method.Params (len $method.Params | add -1) }} + {{- $nonVariadicParams := slice $method.Params 0 (len $method.Params | add -1 )}} variadicArgs := make([]{{ $variadicParam.TypeStringVariadicUnderlying }}, len(args) - {{len $nonVariadicParams}}) for i, a := range args[{{len $nonVariadicParams}}:] { if a != nil { diff --git a/pkg/moq.templ b/pkg/moq.templ index 8447d2a9..5a52bb77 100644 --- a/pkg/moq.templ +++ b/pkg/moq.templ @@ -5,9 +5,11 @@ package {{.PkgName}} import ( {{- range .Imports}} - {{. | ImportStatement}} + {{. | importStatement}} {{- end}} +{{- if .Mocks | mocksSomeMethod }} "sync" +{{- end }} "fmt" ) @@ -54,7 +56,7 @@ var _ {{$.SrcPkgQualifier}}{{.InterfaceName -}} type {{.MockName}} {{- if .TypeParams -}} [{{- range $index, $param := .TypeParams}} - {{- if $index}}, {{end}}{{$param.Name | Exported}} {{$param.TypeString}} + {{- if $index}}, {{end}}{{$param.Name | exported}} {{$param.TypeString}} {{- end -}}] {{- end }} struct { {{- range .Methods}} @@ -67,14 +69,14 @@ type {{.MockName}} // {{.Name}} holds details about calls to the {{.Name}} method. {{.Name}} []struct { {{- range .Params}} - // {{.Name | Exported}} is the {{.Name}} argument value. - {{.Name | Exported}} {{.TypeString}} + // {{.Name | exported}} is the {{.Name}} argument value. + {{.Name | exported}} {{.TypeString}} {{- end}} } {{- end}} } {{- range .Methods}} - lock{{.Name}} {{$.Imports | SyncPkgQualifier}}.RWMutex + lock{{.Name}} {{$.Imports | syncPkgQualifier}}.RWMutex {{- end}} } {{range .Methods}} @@ -87,11 +89,11 @@ func (mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{.Name}}({{.ArgLis {{- end}} callInfo := struct { {{- range .Params}} - {{.Name | Exported}} {{.TypeString}} + {{.Name | exported}} {{.TypeString}} {{- end}} }{ {{- range .Params}} - {{.Name | Exported}}: {{.Name}}, + {{.Name | exported}}: {{.Name}}, {{- end}} } mock.lock{{.Name}}.Lock() @@ -125,12 +127,12 @@ func (mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{.Name}}({{.ArgLis // len(mocked{{$mock.InterfaceName}}.{{.Name}}Calls()) func (mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{.Name}}Calls() []struct { {{- range .Params}} - {{.Name | Exported}} {{.TypeString}} + {{.Name | exported}} {{.TypeString}} {{- end}} } { var calls []struct { {{- range .Params}} - {{.Name | Exported}} {{.TypeString}} + {{.Name | exported}} {{.TypeString}} {{- end}} } mock.lock{{.Name}}.RLock() diff --git a/pkg/template/template.go b/pkg/template/template.go index 003ccc70..ae714ff5 100644 --- a/pkg/template/template.go +++ b/pkg/template/template.go @@ -20,7 +20,15 @@ type Template struct { // New returns a new instance of Template. func New(templateString string, name string) (Template, error) { - tmpl, err := template.New(name).Funcs(templateFuncs).Parse(templateString) + mergedFuncMap := template.FuncMap{} + for key, val := range StringManipulationFuncs { + mergedFuncMap[key] = val + } + for key, val := range TemplateMockFuncs { + mergedFuncMap[key] = val + } + + tmpl, err := template.New(name).Funcs(mergedFuncMap).Parse(templateString) if err != nil { return Template{}, err } @@ -54,14 +62,14 @@ func exported(s string) string { return strings.ToUpper(s[0:1]) + s[1:] } -var templateFuncs = template.FuncMap{ - "ImportStatement": func(imprt *registry.Package) string { +var TemplateMockFuncs = template.FuncMap{ + "importStatement": func(imprt *registry.Package) string { if imprt.Alias == "" { return `"` + imprt.Path() + `"` } return imprt.Alias + ` "` + imprt.Path() + `"` }, - "SyncPkgQualifier": func(imports []*registry.Package) string { + "syncPkgQualifier": func(imports []*registry.Package) string { for _, imprt := range imports { if imprt.Path() == "sync" { return imprt.Qualifier() @@ -70,9 +78,9 @@ var templateFuncs = template.FuncMap{ return "sync" }, - "Exported": exported, + "exported": exported, - "MocksSomeMethod": func(mocks []MockData) bool { + "mocksSomeMethod": func(mocks []MockData) bool { for _, m := range mocks { if len(m.Methods) > 0 { return true @@ -81,7 +89,7 @@ var templateFuncs = template.FuncMap{ return false }, - "TypeConstraintTest": func(m MockData) string { + "typeConstraintTest": func(m MockData) string { if len(m.TypeParams) == 0 { return "" } @@ -97,45 +105,48 @@ var templateFuncs = template.FuncMap{ s += "]" return s }, +} + +var StringManipulationFuncs = template.FuncMap{ // String inspection and manipulation. Note that the first argument is replaced // as the last argument in some functions in order to support chained // template pipelines. - "Contains": func(substr string, s string) bool { return strings.Contains(s, substr) }, - "HasPrefix": func(prefix string, s string) bool { return strings.HasPrefix(s, prefix) }, - "HasSuffix": func(suffix string, s string) bool { return strings.HasSuffix(s, suffix) }, - "Join": func(sep string, elems []string) string { return strings.Join(elems, sep) }, - "Replace": func(old string, new string, n int, s string) string { return strings.Replace(s, old, new, n) }, - "ReplaceAll": func(old string, new string, s string) string { return strings.ReplaceAll(s, old, new) }, - "Split": func(sep string, s string) []string { return strings.Split(s, sep) }, - "SplitAfter": func(sep string, s string) []string { return strings.SplitAfter(s, sep) }, - "SplitAfterN": func(sep string, n int, s string) []string { return strings.SplitAfterN(s, sep, n) }, - "Trim": func(cutset string, s string) string { return strings.Trim(s, cutset) }, - "TrimLeft": func(cutset string, s string) string { return strings.TrimLeft(s, cutset) }, - "TrimPrefix": func(prefix string, s string) string { return strings.TrimPrefix(s, prefix) }, - "TrimRight": func(cutset string, s string) string { return strings.TrimRight(s, cutset) }, - "TrimSpace": strings.TrimSpace, - "TrimSuffix": func(suffix string, s string) string { return strings.TrimSuffix(s, suffix) }, - "Lower": strings.ToLower, - "Upper": strings.ToUpper, - "Camelcase": xstrings.ToCamelCase, - "Snakecase": xstrings.ToSnakeCase, - "Kebabcase": xstrings.ToKebabCase, - "FirstLower": xstrings.FirstRuneToLower, - "FirstUpper": xstrings.FirstRuneToUpper, + "contains": func(substr string, s string) bool { return strings.Contains(s, substr) }, + "hasPrefix": func(prefix string, s string) bool { return strings.HasPrefix(s, prefix) }, + "hasSuffix": func(suffix string, s string) bool { return strings.HasSuffix(s, suffix) }, + "join": func(sep string, elems []string) string { return strings.Join(elems, sep) }, + "replace": func(old string, new string, n int, s string) string { return strings.Replace(s, old, new, n) }, + "replaceAll": func(old string, new string, s string) string { return strings.ReplaceAll(s, old, new) }, + "split": func(sep string, s string) []string { return strings.Split(s, sep) }, + "splitAfter": func(sep string, s string) []string { return strings.SplitAfter(s, sep) }, + "splitAfterN": func(sep string, n int, s string) []string { return strings.SplitAfterN(s, sep, n) }, + "trim": func(cutset string, s string) string { return strings.Trim(s, cutset) }, + "trimLeft": func(cutset string, s string) string { return strings.TrimLeft(s, cutset) }, + "trimPrefix": func(prefix string, s string) string { return strings.TrimPrefix(s, prefix) }, + "trimRight": func(cutset string, s string) string { return strings.TrimRight(s, cutset) }, + "trimSpace": strings.TrimSpace, + "trimSuffix": func(suffix string, s string) string { return strings.TrimSuffix(s, suffix) }, + "lower": strings.ToLower, + "upper": strings.ToUpper, + "camelcase": xstrings.ToCamelCase, + "snakecase": xstrings.ToSnakeCase, + "kebabcase": xstrings.ToKebabCase, + "firstLower": xstrings.FirstRuneToLower, + "firstUpper": xstrings.FirstRuneToUpper, // Regular expression matching - "MatchString": regexp.MatchString, - "QuoteMeta": regexp.QuoteMeta, + "matchString": regexp.MatchString, + "quoteMeta": regexp.QuoteMeta, // Filepath manipulation - "Base": filepath.Base, - "Clean": filepath.Clean, - "Dir": filepath.Dir, + "base": filepath.Base, + "clean": filepath.Clean, + "dir": filepath.Dir, // Basic access to reading environment variables - "ExpandEnv": os.ExpandEnv, - "Getenv": os.Getenv, + "expandEnv": os.ExpandEnv, + "getenv": os.Getenv, // Arithmetic - "Add": func(i1, i2 int) int { return i1 + i2 }, + "add": func(i1, i2 int) int { return i1 + i2 }, } diff --git a/pkg/template/template_test.go b/pkg/template/template_test.go index b8c80a60..3252daa6 100644 --- a/pkg/template/template_test.go +++ b/pkg/template/template_test.go @@ -7,9 +7,9 @@ import ( "github.com/vektra/mockery/v3/pkg/registry" ) -func TestTemplateFuncs(t *testing.T) { +func TestTemplateMockFuncs(t *testing.T) { t.Run("Exported", func(t *testing.T) { - f := templateFuncs["Exported"].(func(string) string) + f := TemplateMockFuncs["exported"].(func(string) string) if f("") != "" { t.Errorf("Exported(...) want: ``; got: `%s`", f("")) } @@ -19,7 +19,7 @@ func TestTemplateFuncs(t *testing.T) { }) t.Run("ImportStatement", func(t *testing.T) { - f := templateFuncs["ImportStatement"].(func(*registry.Package) string) + f := TemplateMockFuncs["ImportStatement"].(func(*registry.Package) string) pkg := registry.NewPackage(types.NewPackage("xyz", "xyz")) if f(pkg) != `"xyz"` { t.Errorf("ImportStatement(...): want: `\"xyz\"`; got: `%s`", f(pkg)) @@ -32,7 +32,7 @@ func TestTemplateFuncs(t *testing.T) { }) t.Run("SyncPkgQualifier", func(t *testing.T) { - f := templateFuncs["SyncPkgQualifier"].(func([]*registry.Package) string) + f := TemplateMockFuncs["SyncPkgQualifier"].(func([]*registry.Package) string) if f(nil) != "sync" { t.Errorf("SyncPkgQualifier(...): want: `sync`; got: `%s`", f(nil)) }