Skip to content

Commit

Permalink
Simplify with "slices" and "maps" instead of "sort"
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandear committed Jan 14, 2025
1 parent 0f55ecb commit 22f013f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 29 deletions.
17 changes: 5 additions & 12 deletions pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
"go/ast"
"go/types"
"io"
"maps"
"os"
"path/filepath"
"regexp"
"sort"
"slices"
"strings"
"text/template"
"unicode"
Expand Down Expand Up @@ -391,26 +392,18 @@ func (g *Generator) expecterName() string {
return g.mockName() + "_Expecter"
}

func (g *Generator) sortedImportNames() (importNames []string) {
for name := range g.nameToPackagePath {
importNames = append(importNames, name)
}
sort.Strings(importNames)
return
}

func (g *Generator) generateImports(ctx context.Context) {
log := zerolog.Ctx(ctx)

log.Debug().Msgf("generating imports")

pkgPath := g.nameToPackagePath[g.iface.Pkg.Name()]
// Sort by import name so that we get a deterministic order
for _, name := range g.sortedImportNames() {
logImport := log.With().Str(logging.LogKeyImport, g.nameToPackagePath[name]).Logger()
for _, name := range slices.Sorted(maps.Keys(g.nameToPackagePath)) {
path := g.nameToPackagePath[name]
logImport := log.With().Str(logging.LogKeyImport, path).Logger()
logImport.Debug().Msgf("found import")

path := g.nameToPackagePath[name]
if !g.config.KeepTree && g.config.InPackage && path == pkgPath {
logImport.Debug().Msgf("import (%s) equals interface's package path (%s), skipping", path, pkgPath)
continue
Expand Down
22 changes: 5 additions & 17 deletions pkg/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"go/types"
"os"
"path/filepath"
"sort"
"slices"
"strings"

"github.com/rs/zerolog"
Expand Down Expand Up @@ -224,13 +224,15 @@ func (p *Parser) Find(name string) (*Interface, error) {
}

func (p *Parser) Interfaces() []*Interface {
ifaces := make(sortableIFaceList, 0)
var ifaces []*Interface
for _, entry := range p.files {
declaredIfaces := entry.interfaces
ifaces = p.packageInterfaces(entry.pkg.Types, entry.fileName, declaredIfaces, ifaces)
}

sort.Sort(ifaces)
slices.SortFunc(ifaces, func(a, b *Interface) int {
return strings.Compare(a.Name, b.Name)
})
return ifaces
}

Expand Down Expand Up @@ -335,20 +337,6 @@ func (iface *Interface) Methods() []*Method {
return methods
}

type sortableIFaceList []*Interface

func (s sortableIFaceList) Len() int {
return len(s)
}

func (s sortableIFaceList) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}

func (s sortableIFaceList) Less(i, j int) bool {
return strings.Compare(s[i].Name, s[j].Name) == -1
}

type NodeVisitor struct {
declaredInterfaces []string
disableFuncMocks bool
Expand Down

0 comments on commit 22f013f

Please sign in to comment.