diff --git a/cmd/description/main.go b/cmd/description/main.go index 7136307..c9d270f 100644 --- a/cmd/description/main.go +++ b/cmd/description/main.go @@ -12,6 +12,7 @@ import ( "github.com/sashabaranov/go-openai" ghClient "github.com/ravilushqa/gpt-pullrequest-updater/github" + "github.com/ravilushqa/gpt-pullrequest-updater/jira" oAIClient "github.com/ravilushqa/gpt-pullrequest-updater/openai" ) @@ -22,6 +23,7 @@ var opts struct { Repo string `long:"repo" env:"REPO" description:"GitHub repo" required:"true"` PRNumber int `long:"pr-number" env:"PR_NUMBER" description:"Pull request number" required:"true"` Test bool `long:"test" env:"TEST" description:"Test mode"` + JiraURL string `long:"jira-url" env:"JIRA_URL" description:"Jira URL"` } func main() { @@ -35,6 +37,10 @@ func main() { os.Exit(0) } + if opts.Test { + fmt.Println("Test mode") + } + if err := run(ctx); err != nil { panic(err) } @@ -54,55 +60,110 @@ func run(ctx context.Context) error { return fmt.Errorf("error getting commits: %w", err) } - var OverallDescribeCompletion string - OverallDescribeCompletion += fmt.Sprintf("Pull request title: %s, body: %s\n\n", pr.GetTitle(), pr.GetBody()) + var sumDiffs int for _, file := range diff.Files { + sumDiffs += len(*file.Patch) + } + var completion string + if sumDiffs < 4000 { + completion, err = genCompletionOnce(ctx, openAIClient, diff) + if err != nil { + return fmt.Errorf("error generating completition once: %w", err) + } + } else { + completion, err = genCompletionPerFile(ctx, openAIClient, diff, pr) + if err != nil { + return fmt.Errorf("error generating completition twice: %w", err) + } + } + + if opts.JiraURL != "" { + fmt.Println("Adding Jira ticket") + id, err := jira.ExtractJiraTicketID(*pr.Title) + if err != nil { + fmt.Printf("Error extracting Jira ticket ID: %v \n", err) + } else { + completion = fmt.Sprintf("### JIRA ticket: [%s](%s) \n\n%s", id, jira.GenerateJiraTicketURL(opts.JiraURL, id), completion) + } + } + + if opts.Test { + fmt.Println(completion) + return nil + } + + // Update the pull request description + fmt.Println("Updating pull request") + updatePr := &github.PullRequest{Body: github.String(completion)} + if _, err = githubClient.UpdatePullRequest(ctx, opts.Owner, opts.Repo, opts.PRNumber, updatePr); err != nil { + return fmt.Errorf("error updating pull request: %w", err) + } + + return nil +} + +func genCompletionOnce(ctx context.Context, client *oAIClient.Client, diff *github.CommitsComparison) (string, error) { + fmt.Println("Generating completion once") + messages := make([]openai.ChatCompletionMessage, 0, len(diff.Files)) + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: oAIClient.PromptDescribeChanges, + }) + for _, file := range diff.Files { if file.Patch == nil { continue } + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: *file.Patch, + }) + } + + fmt.Println("Sending prompt to OpenAI") + completion, err := client.ChatCompletion(ctx, messages) + if err != nil { + return "", fmt.Errorf("error completing prompt: %w", err) + } + + return completion, nil +} + +func genCompletionPerFile(ctx context.Context, client *oAIClient.Client, diff *github.CommitsComparison, pr *github.PullRequest) (string, error) { + fmt.Println("Generating completion per file") + OverallDescribeCompletion := fmt.Sprintf("Pull request title: %s, body: %s\n\n", pr.GetTitle(), pr.GetBody()) + + for i, file := range diff.Files { prompt := fmt.Sprintf(oAIClient.PromptDescribeChanges, *file.Patch) if len(prompt) > 4096 { prompt = fmt.Sprintf("%s...", prompt[:4093]) } - completion, err := openAIClient.ChatCompletion(ctx, []openai.ChatCompletionMessage{ + fmt.Printf("Sending prompt to OpenAI for file %d/%d\n", i+1, len(diff.Files)) + completion, err := client.ChatCompletion(ctx, []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, Content: prompt, }, }) if err != nil { - return fmt.Errorf("error getting review: %w", err) + return "", fmt.Errorf("error getting review: %w", err) } OverallDescribeCompletion += fmt.Sprintf("File: %s \nDescription: %s \n\n", file.GetFilename(), completion) } - overallCompletion, err := openAIClient.ChatCompletion(ctx, []openai.ChatCompletionMessage{ + fmt.Println("Sending final prompt to OpenAI") + overallCompletion, err := client.ChatCompletion(ctx, []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, Content: fmt.Sprintf(oAIClient.PromptOverallDescribe, OverallDescribeCompletion), }, }) if err != nil { - return fmt.Errorf("error getting overall review: %w", err) - } - - if opts.Test { - fmt.Println(OverallDescribeCompletion) - fmt.Println("=====================================") - fmt.Println(overallCompletion) - - return nil + return "", fmt.Errorf("error getting overall review: %w", err) } - // Update the pull request description - updatePr := &github.PullRequest{Body: github.String(overallCompletion)} - if _, err = githubClient.UpdatePullRequest(ctx, opts.Owner, opts.Repo, opts.PRNumber, updatePr); err != nil { - return fmt.Errorf("error updating pull request: %w", err) - } - - return nil + return overallCompletion, nil } diff --git a/cmd/review/main.go b/cmd/review/main.go index 1190b84..ed0f881 100644 --- a/cmd/review/main.go +++ b/cmd/review/main.go @@ -57,11 +57,7 @@ func run(ctx context.Context) error { var OverallReviewCompletion string for _, file := range diff.Files { - if file.GetStatus() == "removed" || file.GetStatus() == "renamed" { - continue - } - - if file.Patch == nil { + if file.Patch == nil || file.GetStatus() == "removed" || file.GetStatus() == "renamed" { continue } diff --git a/jira/jira.go b/jira/jira.go new file mode 100644 index 0000000..c57ef94 --- /dev/null +++ b/jira/jira.go @@ -0,0 +1,29 @@ +package jira + +import ( + "fmt" + "regexp" +) + +const ticketUrlFormat = "%s/browse/%s" + +// ExtractJiraTicketID returns the first JIRA ticket ID found in the input string. +func ExtractJiraTicketID(s string) (string, error) { + // This regular expression pattern matches a JIRA ticket ID (e.g. PROJ-123). + pattern := `([aA-zZ]+-\d+)` + re, err := regexp.Compile(pattern) + if err != nil { + return "", fmt.Errorf("error compiling regex: %w", err) + } + + matches := re.FindStringSubmatch(s) + if len(matches) == 0 { + return "", fmt.Errorf("no JIRA ticket ID found in the input string") + } + + return matches[0], nil +} + +func GenerateJiraTicketURL(jiraURL, ticketID string) string { + return fmt.Sprintf(ticketUrlFormat, jiraURL, ticketID) +} diff --git a/jira/jira_test.go b/jira/jira_test.go new file mode 100644 index 0000000..7d9548a --- /dev/null +++ b/jira/jira_test.go @@ -0,0 +1,55 @@ +package jira + +import "testing" + +func TestExtractJiraTicketID(t *testing.T) { + testCases := []struct { + name string + input string + expected string + expectError bool + }{ + { + name: "Valid ticket ID", + input: "This is a sample text with a JIRA ticket ID: PROJ-123, let's extract it.", + expected: "PROJ-123", + expectError: false, + }, + { + name: "No ticket ID", + input: "This is a sample text without a JIRA ticket ID.", + expectError: true, + }, + { + name: "Multiple ticket IDs", + input: "This text has multiple JIRA ticket IDs: PROJ-123, TASK-456, and BUG-789.", + expected: "PROJ-123", + expectError: false, + }, + { + name: "Valid ticket ID. Lowercase.", + input: "This is an invalid JIRA ticket ID: Proj-123.", + expected: "Proj-123", + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := ExtractJiraTicketID(tc.input) + if tc.expectError { + if err == nil { + t.Errorf("expected an error, but got none") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if result != tc.expected { + t.Errorf("expected result '%s', but got '%s'", tc.expected, result) + } + } + }) + } +}