diff --git a/.github/workflows/gpt_pullrequest_updater.yml b/.github/workflows/gpt_pullrequest_updater.yml index 0fb4d68..0f1184a 100644 --- a/.github/workflows/gpt_pullrequest_updater.yml +++ b/.github/workflows/gpt_pullrequest_updater.yml @@ -36,6 +36,7 @@ jobs: OWNER: ${{ github.repository_owner }} REPO: ${{ github.event.repository.name }} PR_NUMBER: ${{ github.event.number }} + OPENAI_MODEL: gpt-4 - name: Review Pull Request if: github.event.action == 'opened' @@ -47,3 +48,4 @@ jobs: OWNER: ${{ github.repository_owner }} REPO: ${{ github.event.repository.name }} PR_NUMBER: ${{ github.event.number }} + OPENAI_MODEL: gpt-4 diff --git a/cmd/description/main.go b/cmd/description/main.go index 38a122a..aef8e06 100644 --- a/cmd/description/main.go +++ b/cmd/description/main.go @@ -22,6 +22,7 @@ var opts struct { Owner string `long:"owner" env:"OWNER" description:"GitHub owner" required:"true"` Repo string `long:"repo" env:"REPO" description:"GitHub repo" required:"true"` PRNumber int `long:"pr-number" env:"PR_NUMBER" description:"Pull request number" required:"true"` + OpenAIModel string `long:"openai-model" env:"OPENAI_MODEL" description:"OpenAI model" default:"gpt-3.5-turbo"` Test bool `long:"test" env:"TEST" description:"Test mode"` JiraURL string `long:"jira-url" env:"JIRA_URL" description:"Jira URL. Example: https://jira.atlassian.com"` } @@ -43,7 +44,7 @@ func main() { } func run(ctx context.Context) error { - openAIClient := oAIClient.NewClient(opts.OpenAIToken) + openAIClient := oAIClient.NewClient(opts.OpenAIToken, opts.OpenAIModel) githubClient := ghClient.NewClient(ctx, opts.GithubToken) pr, err := githubClient.GetPullRequest(ctx, opts.Owner, opts.Repo, opts.PRNumber) diff --git a/cmd/review/main.go b/cmd/review/main.go index 096cfac..74ccde7 100644 --- a/cmd/review/main.go +++ b/cmd/review/main.go @@ -20,6 +20,7 @@ var opts struct { Owner string `long:"owner" env:"OWNER" description:"GitHub owner" required:"true"` Repo string `long:"repo" env:"REPO" description:"GitHub repo" required:"true"` PRNumber int `long:"pr-number" env:"PR_NUMBER" description:"Pull request number" required:"true"` + OpenAIModel string `long:"openai-model" env:"OPENAI_MODEL" description:"OpenAI model" default:"gpt-3.5-turbo"` Test bool `long:"test" env:"TEST" description:"Test mode"` } @@ -40,7 +41,7 @@ func main() { } func run(ctx context.Context) error { - openAIClient := oAIClient.NewClient(opts.OpenAIToken) + openAIClient := oAIClient.NewClient(opts.OpenAIToken, opts.OpenAIModel) githubClient := ghClient.NewClient(ctx, opts.GithubToken) pr, err := githubClient.GetPullRequest(ctx, opts.Owner, opts.Repo, opts.PRNumber) diff --git a/openai/openai.go b/openai/openai.go index b21af60..4081ad6 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -21,19 +21,21 @@ var PromptDescribeOverall string type Client struct { client *openai.Client + model string } -func NewClient(token string) *Client { +func NewClient(token, model string) *Client { return &Client{ client: openai.NewClient(token), + model: model, } } -func (o *Client) ChatCompletion(ctx context.Context, messages []openai.ChatCompletionMessage) (string, error) { - resp, err := o.client.CreateChatCompletion( +func (c *Client) ChatCompletion(ctx context.Context, messages []openai.ChatCompletionMessage) (string, error) { + resp, err := c.client.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, + Model: c.model, Messages: messages, Temperature: 0.1, }, @@ -47,10 +49,10 @@ func (o *Client) ChatCompletion(ctx context.Context, messages []openai.ChatCompl fmt.Println("Retrying after 1 minute") // retry once after 1 minute time.Sleep(time.Minute) - resp, err = o.client.CreateChatCompletion( + resp, err = c.client.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, + Model: c.model, Messages: messages, Temperature: 0.1, },