From 68478f86689dc532cea09f50a0318385278cb955 Mon Sep 17 00:00:00 2001 From: Beyang Liu Date: Mon, 12 Aug 2024 17:47:38 -0700 Subject: [PATCH 1/4] wip: idf index --- .../graphqlbackend/repository_reindex.go | 14 +++++++ .../internal/context/resolvers/context.go | 12 ++++++ cmd/frontend/internal/search/idf/idf.go | 37 +++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 cmd/frontend/internal/search/idf/idf.go diff --git a/cmd/frontend/graphqlbackend/repository_reindex.go b/cmd/frontend/graphqlbackend/repository_reindex.go index 66e88a868797..653a574024f6 100644 --- a/cmd/frontend/graphqlbackend/repository_reindex.go +++ b/cmd/frontend/graphqlbackend/repository_reindex.go @@ -2,9 +2,11 @@ package graphqlbackend import ( "context" + "fmt" "github.com/graph-gophers/graphql-go" + "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf" "github.com/sourcegraph/sourcegraph/internal/auth" "github.com/sourcegraph/sourcegraph/internal/search/zoekt" ) @@ -13,11 +15,23 @@ import ( func (r *schemaResolver) ReindexRepository(ctx context.Context, args *struct { Repository graphql.ID }) (*EmptyResponse, error) { + // MARK(beyang): this is triggered by the "Reindex now" button on a page like https://sourcegraph.test:3443/github.com/hashicorp/errwrap/-/settings/index + fmt.Printf("# schemaResolver.ReindexRepository\n") + // 🚨 SECURITY: There is no reason why non-site-admins would need to run this operation. if err := auth.CheckCurrentUserIsSiteAdmin(ctx, r.db); err != nil { return nil, err } + repoID, err := UnmarshalRepositoryID(args.Repository) + if err != nil { + return nil, err + } + fmt.Printf("# schemaResolver.ReindexRepository repoID %s -> repoID %d\n", args.Repository, repoID) + if err := idf.Update(ctx, repoID, "foobar"); err != nil { + return nil, err + } + repo, err := r.repositoryByID(ctx, args.Repository) if err != nil { return nil, err diff --git a/cmd/frontend/internal/context/resolvers/context.go b/cmd/frontend/internal/context/resolvers/context.go index 207767ced54b..519095900bda 100644 --- a/cmd/frontend/internal/context/resolvers/context.go +++ b/cmd/frontend/internal/context/resolvers/context.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net/http" "time" @@ -13,9 +14,11 @@ import ( "github.com/sourcegraph/conc/iter" "github.com/sourcegraph/conc/pool" "github.com/sourcegraph/log" + "github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend" "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/cody" "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/codycontext" + "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf" "github.com/sourcegraph/sourcegraph/internal/api" "github.com/sourcegraph/sourcegraph/internal/conf" "github.com/sourcegraph/sourcegraph/internal/database" @@ -193,6 +196,15 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo repoNameIDs[i] = types.RepoIDName{ID: repoID, Name: repo.Name} } + for _, repoID := range repoIDs { + val, err := idf.Get(ctx, repoID) + if err != nil { + fmt.Printf("Unexpected error getting idf index value for repo %v: %v\n", repoID, err) + continue + } + fmt.Printf("# Got idf index value for repo %v: %v\n", repoID, string(val)) + } + fileChunks, err := r.contextClient.GetCodyContext(ctx, codycontext.GetContextArgs{ Repos: repoNameIDs, Query: args.Query, diff --git a/cmd/frontend/internal/search/idf/idf.go b/cmd/frontend/internal/search/idf/idf.go new file mode 100644 index 000000000000..51d5ab94ebb3 --- /dev/null +++ b/cmd/frontend/internal/search/idf/idf.go @@ -0,0 +1,37 @@ +// TODO(beyang): should probably move this elsewhere +package idf + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/sourcegraph/sourcegraph/internal/api" + "github.com/sourcegraph/sourcegraph/internal/rcache" + "github.com/sourcegraph/sourcegraph/internal/redispool" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + + +var redisCache = rcache.NewWithTTL(redispool.Cache, "idf-index", 10*24*60*60) + + +func Update(ctx context.Context, repoID api.RepoID, value interface{}) error { + fmt.Printf("# idf.Update(%v)", repoID) + b, err := json.Marshal(value) + if err != nil { + return errors.Wrap(err, "idf.Update") + } + redisCache.Set(fmt.Sprintf("repo:%v", repoID), b) + return nil +} + + +func Get(ctx context.Context, repoID api.RepoID) ([]byte, error) { + fmt.Printf("# idf.Get(%v)", repoID) + b, ok := redisCache.Get(fmt.Sprintf("repo:%v", repoID)) + if !ok { + return nil, fmt.Errorf("idf.Get: repo %s not found", string(repoID)) + } + return b, nil +} From 357e33486b61bb9802439ba39f9428132fcd1358 Mon Sep 17 00:00:00 2001 From: Beyang Liu Date: Tue, 13 Aug 2024 15:10:50 -0700 Subject: [PATCH 2/4] add temp package doc --- cmd/frontend/internal/search/idf/idf.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/frontend/internal/search/idf/idf.go b/cmd/frontend/internal/search/idf/idf.go index 51d5ab94ebb3..7381ef9a3ef6 100644 --- a/cmd/frontend/internal/search/idf/idf.go +++ b/cmd/frontend/internal/search/idf/idf.go @@ -1,3 +1,5 @@ +// Package idf computes and stores the inverse document frequency (IDF) of a set of repositories. +// // TODO(beyang): should probably move this elsewhere package idf From 9b5a3bb928d410c834ee3e15d202f8a688508a46 Mon Sep 17 00:00:00 2001 From: Beyang Liu Date: Wed, 14 Aug 2024 13:55:14 -0700 Subject: [PATCH 3/4] store, then use basic idf table --- .../graphqlbackend/repository_reindex.go | 9 +- cmd/frontend/internal/codycontext/context.go | 47 +++++++- .../internal/context/resolvers/context.go | 12 +- cmd/frontend/internal/search/idf/idf.go | 110 ++++++++++++++++-- cmd/frontend/internal/search/idf/tokenize.go | 63 ++++++++++ .../internal/search/idf/tokenize_test.go | 69 +++++++++++ 6 files changed, 282 insertions(+), 28 deletions(-) create mode 100644 cmd/frontend/internal/search/idf/tokenize.go create mode 100644 cmd/frontend/internal/search/idf/tokenize_test.go diff --git a/cmd/frontend/graphqlbackend/repository_reindex.go b/cmd/frontend/graphqlbackend/repository_reindex.go index 653a574024f6..352445b70786 100644 --- a/cmd/frontend/graphqlbackend/repository_reindex.go +++ b/cmd/frontend/graphqlbackend/repository_reindex.go @@ -23,17 +23,12 @@ func (r *schemaResolver) ReindexRepository(ctx context.Context, args *struct { return nil, err } - repoID, err := UnmarshalRepositoryID(args.Repository) + repo, err := r.repositoryByID(ctx, args.Repository) if err != nil { return nil, err } - fmt.Printf("# schemaResolver.ReindexRepository repoID %s -> repoID %d\n", args.Repository, repoID) - if err := idf.Update(ctx, repoID, "foobar"); err != nil { - return nil, err - } - repo, err := r.repositoryByID(ctx, args.Repository) - if err != nil { + if err := idf.Update(ctx, repo.RepoName()); err != nil { return nil, err } diff --git a/cmd/frontend/internal/codycontext/context.go b/cmd/frontend/internal/codycontext/context.go index a0906e58561d..f168455285e2 100644 --- a/cmd/frontend/internal/codycontext/context.go +++ b/cmd/frontend/internal/codycontext/context.go @@ -6,12 +6,15 @@ import ( "strings" "sync" + lg "log" + "github.com/grafana/regexp" "github.com/sourcegraph/conc/pool" "github.com/sourcegraph/log" "go.opentelemetry.io/otel/attribute" "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/cody" + "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf" "github.com/sourcegraph/sourcegraph/internal/api" "github.com/sourcegraph/sourcegraph/internal/conf" "github.com/sourcegraph/sourcegraph/internal/database" @@ -82,6 +85,7 @@ type CodyContextClient struct { type GetContextArgs struct { Repos []types.RepoIDName + RepoStats map[api.RepoName]*idf.StatsProvider Query string CodeResultsCount int32 TextResultsCount int32 @@ -138,13 +142,15 @@ func (c *CodyContextClient) GetCodyContext(ctx context.Context, args GetContextA embeddingsArgs := GetContextArgs{ Repos: embeddingRepos, + RepoStats: args.RepoStats, Query: args.Query, CodeResultsCount: int32(float32(args.CodeResultsCount) * embeddingsResultRatio), TextResultsCount: int32(float32(args.TextResultsCount) * embeddingsResultRatio), } keywordArgs := GetContextArgs{ - Repos: keywordRepos, - Query: args.Query, + Repos: keywordRepos, + RepoStats: args.RepoStats, + Query: args.Query, // Assign the remaining result budget to keyword search CodeResultsCount: args.CodeResultsCount - embeddingsArgs.CodeResultsCount, TextResultsCount: args.TextResultsCount - embeddingsArgs.TextResultsCount, @@ -277,7 +283,9 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte // mini-HACK: pass in the scope using repo: filters. In an ideal world, we // would not be using query text manipulation for this and would be using // the job structs directly. - keywordQuery := fmt.Sprintf(`repo:%s %s %s`, reposAsRegexp(args.Repos), getKeywordContextExcludeFilePathsQuery(), args.Query) + transformedQuery := getTransformedQuery(args) + lg.Printf("# userQuery -> transformedQuery: %q -> %q", args.Query, transformedQuery) + keywordQuery := fmt.Sprintf(`repo:%s %s %s`, reposAsRegexp(args.Repos), getKeywordContextExcludeFilePathsQuery(), transformedQuery) ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -371,3 +379,36 @@ func fileMatchToContextMatch(fm *result.FileMatch) FileChunkContext { StartLine: startLine, } } + +func getTransformedQuery(args GetContextArgs) string { + if args.RepoStats == nil { + lg.Printf("# no stats set") + return args.Query + } + + for _, repo := range args.Repos { + if _, ok := args.RepoStats[repo.Name]; !ok { + // Don't transform query if one of the repositories lacks an IDF table + lg.Printf("# didn't find stats for repo %s", repo.Name) + return args.Query + } + } + + // TODO(beyang): the semantics of what we want to do here aren't super clear. + // Do we want to preserve the wholeness of the "words" the user types in (the tokens are camelcased tokenized). + // Probably, otherwise the transformed query will yield noisier results. + queryToks := idf.Tokenize(args.Query) + var filteredToks []string + + const idfThresh = 0.2 + for _, qtok := range queryToks { + + for _, stats := range args.RepoStats { + if stats.GetIDF(qtok) < idfThresh { + continue + } + filteredToks = append(filteredToks, qtok) + } + } + return strings.Join(filteredToks, " ") +} diff --git a/cmd/frontend/internal/context/resolvers/context.go b/cmd/frontend/internal/context/resolvers/context.go index 519095900bda..5698b434ede9 100644 --- a/cmd/frontend/internal/context/resolvers/context.go +++ b/cmd/frontend/internal/context/resolvers/context.go @@ -4,8 +4,8 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" + lg "log" "net/http" "time" @@ -186,6 +186,7 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo } repoNameIDs := make([]types.RepoIDName, len(repoIDs)) + repoStats := make(map[api.RepoName]*idf.StatsProvider) for i, repoID := range repoIDs { repo, ok := repos[repoID] if !ok { @@ -194,19 +195,18 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo } repoNameIDs[i] = types.RepoIDName{ID: repoID, Name: repo.Name} - } - for _, repoID := range repoIDs { - val, err := idf.Get(ctx, repoID) + stats, err := idf.Get(ctx, repo.Name) if err != nil { - fmt.Printf("Unexpected error getting idf index value for repo %v: %v\n", repoID, err) + lg.Printf("Unexpected error getting idf index value for repo %v: %v", repoID, err) continue } - fmt.Printf("# Got idf index value for repo %v: %v\n", repoID, string(val)) + repoStats[repo.Name] = stats } fileChunks, err := r.contextClient.GetCodyContext(ctx, codycontext.GetContextArgs{ Repos: repoNameIDs, + RepoStats: repoStats, Query: args.Query, CodeResultsCount: args.CodeResultsCount, TextResultsCount: args.TextResultsCount, diff --git a/cmd/frontend/internal/search/idf/idf.go b/cmd/frontend/internal/search/idf/idf.go index 7381ef9a3ef6..3705ec70db6c 100644 --- a/cmd/frontend/internal/search/idf/idf.go +++ b/cmd/frontend/internal/search/idf/idf.go @@ -4,36 +4,122 @@ package idf import ( + "archive/tar" + "bufio" "context" "encoding/json" "fmt" + "io" + "log" + "math" + "strings" "github.com/sourcegraph/sourcegraph/internal/api" + "github.com/sourcegraph/sourcegraph/internal/gitserver" "github.com/sourcegraph/sourcegraph/internal/rcache" "github.com/sourcegraph/sourcegraph/internal/redispool" "github.com/sourcegraph/sourcegraph/lib/errors" ) - var redisCache = rcache.NewWithTTL(redispool.Cache, "idf-index", 10*24*60*60) +func Update(ctx context.Context, repoName api.RepoName) error { + fmt.Printf("# idf.Update(%v)\n", repoName) + + stats := NewStatsAggregator() + + git := gitserver.NewClient("idf-indexer") + r, err := git.ArchiveReader(ctx, repoName, gitserver.ArchiveOptions{Treeish: "HEAD", Format: gitserver.ArchiveFormatTar}) + if err != nil { + return nil + } + + tr := tar.NewReader(r) + for { + header, err := tr.Next() + if err == io.EOF { + break // End of archive + } + if err != nil { + log.Printf("Error reading next tar header: %v", err) + continue + } + + // Skip directories + if header.Typeflag == tar.TypeDir { + continue + } + + // Read the first line of the file + scanner := bufio.NewScanner(tr) + if scanner.Scan() { + stats.ProcessDoc(scanner.Text()) + } else if err := scanner.Err(); err != nil { + log.Printf("Error reading file content: %v", err) + } + } + + statsP := stats.EvalProvider() + statsBytes, err := json.Marshal(statsP) + + log.Printf("# storing stats: %s", string(statsBytes)) -func Update(ctx context.Context, repoID api.RepoID, value interface{}) error { - fmt.Printf("# idf.Update(%v)", repoID) - b, err := json.Marshal(value) if err != nil { - return errors.Wrap(err, "idf.Update") + return errors.Wrap(err, "idf.Update: failed to marshal IDF table") } - redisCache.Set(fmt.Sprintf("repo:%v", repoID), b) + + redisCache.Set(fmt.Sprintf("repo:%v", repoName), statsBytes) return nil } - -func Get(ctx context.Context, repoID api.RepoID) ([]byte, error) { - fmt.Printf("# idf.Get(%v)", repoID) - b, ok := redisCache.Get(fmt.Sprintf("repo:%v", repoID)) +func Get(ctx context.Context, repoName api.RepoName) (*StatsProvider, error) { + fmt.Printf("# idf.Get(%v)", repoName) + b, ok := redisCache.Get(fmt.Sprintf("repo:%v", repoName)) if !ok { - return nil, fmt.Errorf("idf.Get: repo %s not found", string(repoID)) + return nil, nil + } + + var stats StatsProvider + if err := json.Unmarshal(b, &stats); err != nil { + return nil, errors.Wrap(err, "idf.Get: failed to unmarshal IDF table") + } + + log.Printf("# fetching stats: %v", stats) + + return &stats, nil +} + +type StatsAggregator struct { + TermToDocCt map[string]int + DoctCt int +} + +func NewStatsAggregator() *StatsAggregator { + return &StatsAggregator{ + TermToDocCt: make(map[string]int), + } +} + +func (s *StatsAggregator) ProcessDoc(text string) { + for _, tok := range Tokenize(text) { + term := strings.ToLower((tok)) + s.TermToDocCt[term]++ + } + s.DoctCt++ +} + +func (s *StatsAggregator) EvalProvider() StatsProvider { + idf := make(map[string]float32) + for term, docCt := range s.TermToDocCt { + idf[term] = float32(math.Log(float64(s.DoctCt) / (1.0 + float64(docCt)))) } - return b, nil + return StatsProvider{IDF: idf} +} + +type StatsProvider struct { + IDF map[string]float32 +} + +func (s *StatsProvider) GetIDF(term string) float32 { + return s.IDF[strings.ToLower(term)] } diff --git a/cmd/frontend/internal/search/idf/tokenize.go b/cmd/frontend/internal/search/idf/tokenize.go new file mode 100644 index 000000000000..910a28f7fc09 --- /dev/null +++ b/cmd/frontend/internal/search/idf/tokenize.go @@ -0,0 +1,63 @@ +package idf + +import ( + "regexp" + "strings" +) + +var ( + camelStartRe = regexp.MustCompile(`^[A-Za-z][^A-Z]+`) + capStartRe = regexp.MustCompile(`^[A-Z][A-Z0-9]*`) +) + +func tokenizeCamelCase(s string) []string { + remainder := s + var toks []string + for len(remainder) > 0 { + if found := camelStartRe.FindString(remainder); found != "" { + toks = append(toks, found) + remainder = remainder[len(found):] + continue + } + if found := capStartRe.FindString(remainder); found != "" { + if len(found) == 1 || len(found) == len(remainder) { + toks = append(toks, found) + remainder = remainder[len(found):] + } else { + toks = append(toks, found[:len(found)-1]) + remainder = remainder[len(found)-1:] + } + continue + } + remainder = remainder[1:] + } + return toks +} + +func tokenizeSnakeCase(s string) []string { + return strings.Split(s, "_") +} + +var ( + sepRe = regexp.MustCompile(`([[:punct:]]|\s)+`) +) + +func TokenizeWord(w string) []string { + var toks []string + for _, part := range tokenizeSnakeCase(w) { + toks = append(toks, tokenizeCamelCase(part)...) + } + return toks +} + +func Tokenize(s string) []string { + var toks []string + for _, word := range Words(s) { + toks = append(toks, TokenizeWord(word)...) + } + return toks +} + +func Words(s string) []string { + return sepRe.Split(s, -1) +} diff --git a/cmd/frontend/internal/search/idf/tokenize_test.go b/cmd/frontend/internal/search/idf/tokenize_test.go new file mode 100644 index 000000000000..d03b37295de1 --- /dev/null +++ b/cmd/frontend/internal/search/idf/tokenize_test.go @@ -0,0 +1,69 @@ +package idf + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestTokenizeCamelCase(t *testing.T) { + type testCase struct { + s string + expToks []string + } + cases := []testCase{ + { + s: "FooBar", + expToks: []string{"Foo", "Bar"}, + }, + { + s: "fooBarBaz", + expToks: []string{"foo", "Bar", "Baz"}, + }, + { + s: "HTMLParser", + expToks: []string{"HTML", "Parser"}, + }, + { + s: "parseHTML", + expToks: []string{"parse", "HTML"}, + }, + { + s: "HTML5Parser", + expToks: []string{"HTML5", "Parser"}, + }, + { + s: "parseHTML5", + expToks: []string{"parse", "HTML5"}, + }, + } + for _, c := range cases { + toks := tokenizeCamelCase(c.s) + if diff := cmp.Diff(toks, c.expToks); diff != "" { + t.Errorf(diff) + } + } +} + +func TestTokenize(t *testing.T) { + type testCase struct { + s string + expToks []string + } + cases := []testCase{ + { + s: "camelCase.snake_case + _weird_.", + expToks: []string{"camel", "Case", "snake", "case", "weird"}, + }, + { + s: "two words camelCase--!:@withPunctuation and_snake_case", + expToks: []string{"two", "words", "camel", "Case", "with", "Punctuation", "and", "snake", "case"}, + }, + } + for _, c := range cases { + toks := Tokenize(c.s) + if diff := cmp.Diff(c.expToks, toks); diff != "" { + t.Errorf(diff) + } + } +} From 02ad6ea1eb1b36627748c01c1991698e098319bc Mon Sep 17 00:00:00 2001 From: Rishabh Mehrotra Date: Fri, 16 Aug 2024 01:57:07 +0100 Subject: [PATCH 4/4] editing idf index to remove tokenization and edited matching logic for term expansion --- cmd/frontend/internal/codycontext/context.go | 40 ++++++++++++----- cmd/frontend/internal/search/idf/idf.go | 45 ++++++++++++++++++-- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/cmd/frontend/internal/codycontext/context.go b/cmd/frontend/internal/codycontext/context.go index f168455285e2..24a0b53398c0 100644 --- a/cmd/frontend/internal/codycontext/context.go +++ b/cmd/frontend/internal/codycontext/context.go @@ -3,6 +3,7 @@ package codycontext import ( "context" "fmt" + "sort" "strings" "sync" @@ -283,8 +284,10 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte // mini-HACK: pass in the scope using repo: filters. In an ideal world, we // would not be using query text manipulation for this and would be using // the job structs directly. - transformedQuery := getTransformedQuery(args) + var maxTermsPerWord = 5 + transformedQuery := getTransformedQuery(args, maxTermsPerWord) lg.Printf("# userQuery -> transformedQuery: %q -> %q", args.Query, transformedQuery) + fmt.Printf("# userQuery -> transformedQuery: %q -> %q", args.Query, transformedQuery) keywordQuery := fmt.Sprintf(`repo:%s %s %s`, reposAsRegexp(args.Repos), getKeywordContextExcludeFilePathsQuery(), transformedQuery) ctx, cancel := context.WithCancel(ctx) @@ -380,7 +383,7 @@ func fileMatchToContextMatch(fm *result.FileMatch) FileChunkContext { } } -func getTransformedQuery(args GetContextArgs) string { +func getTransformedQuery(args GetContextArgs, maxTermsPerWord int) string { if args.RepoStats == nil { lg.Printf("# no stats set") return args.Query @@ -394,21 +397,36 @@ func getTransformedQuery(args GetContextArgs) string { } } - // TODO(beyang): the semantics of what we want to do here aren't super clear. - // Do we want to preserve the wholeness of the "words" the user types in (the tokens are camelcased tokenized). - // Probably, otherwise the transformed query will yield noisier results. - queryToks := idf.Tokenize(args.Query) + // TODO(rishabh): currently we are just picking up top-k vocab terms based on idf scores, but we can do a better semantic ranking of terms + // current matching is fairly limited based on substring matching, but perhaps stemming/lemmatization might be considered? + var filteredToks []string + // var maxTermsPerWord = 5 - const idfThresh = 0.2 - for _, qtok := range queryToks { + type termScore struct { + term string + score float32 + } + for _, word := range strings.Fields(args.Query) { + if len(word) < 4 { + continue + } + var matches []termScore for _, stats := range args.RepoStats { - if stats.GetIDF(qtok) < idfThresh { - continue + for term, score := range stats.GetTerms() { + if strings.Contains(term, word) && len(term) > 4 && score > 3 { + matches = append(matches, termScore{term: term, score: score}) + } } - filteredToks = append(filteredToks, qtok) + } + sort.Slice(matches, func(i, j int) bool { + return matches[i].score > matches[j].score + }) + for i := 0; i < min(maxTermsPerWord, len(matches)); i++ { + filteredToks = append(filteredToks, matches[i].term) } } + return strings.Join(filteredToks, " ") } diff --git a/cmd/frontend/internal/search/idf/idf.go b/cmd/frontend/internal/search/idf/idf.go index 3705ec70db6c..b5347984713d 100644 --- a/cmd/frontend/internal/search/idf/idf.go +++ b/cmd/frontend/internal/search/idf/idf.go @@ -12,7 +12,9 @@ import ( "io" "log" "math" + "path" "strings" + "unicode" "github.com/sourcegraph/sourcegraph/internal/api" "github.com/sourcegraph/sourcegraph/internal/gitserver" @@ -34,6 +36,13 @@ func Update(ctx context.Context, repoName api.RepoName) error { return nil } + permissibleExtensions := map[string]bool{ + ".py": true, ".js": true, ".ts": true, ".java": true, ".cpp": true, + ".c": true, ".cs": true, ".go": true, ".rb": true, ".rs": true, + ".php": true, ".html": true, ".css": true, ".scss": true, ".md": true, + ".sh": true, ".swift": true, ".kt": true, ".m": true, + } + tr := tar.NewReader(r) for { header, err := tr.Next() @@ -50,6 +59,13 @@ func Update(ctx context.Context, repoName api.RepoName) error { continue } + // Check if the file has a permissible extension + ext := strings.ToLower(path.Ext(header.Name)) + + if !permissibleExtensions[ext] { + continue + } + // Read the first line of the file scanner := bufio.NewScanner(tr) if scanner.Scan() { @@ -100,10 +116,29 @@ func NewStatsAggregator() *StatsAggregator { } } +func isValidWord(word string) bool { + if len(word) < 3 || len(word) > 20 { + return false + } + hasLetter := false + for _, char := range word { + if !unicode.IsLetter(char) && !unicode.IsNumber(char) { + return false + } + if unicode.IsLetter(char) { + hasLetter = true + } + } + return hasLetter +} + func (s *StatsAggregator) ProcessDoc(text string) { - for _, tok := range Tokenize(text) { - term := strings.ToLower((tok)) - s.TermToDocCt[term]++ + words := strings.Fields(text) + for _, word := range words { + // word = strings.ToLower(word) + if isValidWord(word) { + s.TermToDocCt[word]++ + } } s.DoctCt++ } @@ -123,3 +158,7 @@ type StatsProvider struct { func (s *StatsProvider) GetIDF(term string) float32 { return s.IDF[strings.ToLower(term)] } + +func (s *StatsProvider) GetTerms() map[string]float32 { + return s.IDF +}