Skip to content

Commit bee0656

Browse files
NullpointerWW
and
W
authored
CompletionBatchingRequestSupport (sashabaranov#220)
* completionBatchingRequestSupport * lint fix * fix Run test fail * fix TestClientReturnsRequestBuilderErrors fail * fix Codecov check * ignore TestClientReturnsRequestBuilderErrors lint * fix lint again * lint again*2 * replace checkPromptType implementation * remove nil check --------- Co-authored-by: W <[email protected]>
1 parent b542086 commit bee0656

File tree

4 files changed

+47
-6
lines changed

4 files changed

+47
-6
lines changed

completion.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ import (
77
)
88

99
var (
10-
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
11-
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
10+
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
11+
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
12+
ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Promp only supports string and []string") //nolint:lll
1213
)
1314

1415
// GPT3 Defines the models provided by OpenAI to use when generating
@@ -77,10 +78,16 @@ func checkEndpointSupportsModel(endpoint, model string) bool {
7778
return !disabledModelsForEndpoints[endpoint][model]
7879
}
7980

81+
func checkPromptType(prompt any) bool {
82+
_, isString := prompt.(string)
83+
_, isStringSlice := prompt.([]string)
84+
return isString || isStringSlice
85+
}
86+
8087
// CompletionRequest represents a request structure for completion API.
8188
type CompletionRequest struct {
8289
Model string `json:"model"`
83-
Prompt string `json:"prompt,omitempty"`
90+
Prompt any `json:"prompt,omitempty"`
8491
Suffix string `json:"suffix,omitempty"`
8592
MaxTokens int `json:"max_tokens,omitempty"`
8693
Temperature float32 `json:"temperature,omitempty"`
@@ -143,6 +150,11 @@ func (c *Client) CreateCompletion(
143150
return
144151
}
145152

153+
if !checkPromptType(request.Prompt) {
154+
err = ErrCompletionRequestPromptTypeNotSupported
155+
return
156+
}
157+
146158
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
147159
if err != nil {
148160
return

completion_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
9898
// generate a random string of length completionReq.Length
9999
completionStr := strings.Repeat("a", completionReq.MaxTokens)
100100
if completionReq.Echo {
101-
completionStr = completionReq.Prompt + completionStr
101+
completionStr = completionReq.Prompt.(string) + completionStr
102102
}
103103
res.Choices = append(res.Choices, CompletionChoice{
104104
Text: completionStr,
105105
Index: i,
106106
})
107107
}
108-
inputTokens := numTokens(completionReq.Prompt) * completionReq.N
108+
inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N
109109
completionTokens := completionReq.MaxTokens * completionReq.N
110110
res.Usage = Usage{
111111
PromptTokens: inputTokens,

request_builder_test.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
5151

5252
ctx := context.Background()
5353

54-
_, err = client.CreateCompletion(ctx, CompletionRequest{})
54+
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
5555
if !errors.Is(err, errTestRequestBuilderFailed) {
5656
t.Fatalf("Did not return error when request builder failed: %v", err)
5757
}
@@ -146,3 +146,27 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
146146
t.Fatalf("Did not return error when request builder failed: %v", err)
147147
}
148148
}
149+
150+
func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) {
151+
var err error
152+
ts := test.NewTestServer().OpenAITestServer()
153+
ts.Start()
154+
defer ts.Close()
155+
156+
config := DefaultConfig(test.GetTestToken())
157+
config.BaseURL = ts.URL + "/v1"
158+
client := NewClientWithConfig(config)
159+
client.requestBuilder = &failingRequestBuilder{}
160+
161+
ctx := context.Background()
162+
163+
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
164+
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
165+
t.Fatalf("Did not return error when request builder failed: %v", err)
166+
}
167+
168+
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
169+
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
170+
t.Fatalf("Did not return error when request builder failed: %v", err)
171+
}
172+
}

stream.go

+5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ func (c *Client) CreateCompletionStream(
2828
return
2929
}
3030

31+
if !checkPromptType(request.Prompt) {
32+
err = ErrCompletionRequestPromptTypeNotSupported
33+
return
34+
}
35+
3136
request.Stream = true
3237
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
3338
if err != nil {

0 commit comments

Comments
 (0)