diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index c3c325c54b..99c2544878 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -185,7 +185,7 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - inputs, err := parseInputFlags(inputFlags) + inputs, err := parseInputFlags(inputFlags, schema) if err != nil { return err } @@ -361,7 +361,7 @@ func writeDataURLOutput(outputString string, outputPath string, addExtension boo return nil } -func parseInputFlags(inputs []string) (predict.Inputs, error) { +func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error) { keyVals := map[string][]string{} for _, input := range inputs { var name, value string @@ -383,7 +383,7 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) { keyVals[name] = append(keyVals[name], value) } - return predict.NewInputs(keyVals), nil + return predict.NewInputs(keyVals, schema) } func addSetupTimeoutFlag(cmd *cobra.Command) { diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 7965c79431..9e21709f87 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -1,11 +1,13 @@ package predict import ( + "encoding/json" "fmt" "os" "path/filepath" "strings" + "github.com/getkin/kin-openapi/openapi3" "github.com/mitchellh/go-homedir" "github.com/vincent-petithory/dataurl" @@ -13,22 +15,69 @@ import ( ) type Input struct { - String *string - File *string - Array *[]any + String *string + File *string + Array *[]any + ChatMessage *json.RawMessage } type Inputs map[string]Input -func NewInputs(keyVals map[string][]string) Inputs { +var jsonSerializableSchemas = map[string]bool{ + "#/components/schemas/CommonChatSchemaDeveloperMessage": true, + "#/components/schemas/CommonChatSchemaSystemMessage": true, + "#/components/schemas/CommonChatSchemaUserMessage": true, + "#/components/schemas/CommonChatSchemaAssistantMessage": true, + "#/components/schemas/CommonChatSchemaToolMessage": true, + "#/components/schemas/CommonChatSchemaFunctionMessage": true, +} + +func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error) { + var inputComponent *openapi3.SchemaRef + for name, component := range schema.Components.Schemas { + if name == "Input" { + inputComponent = component + break + } + } + input := Inputs{} for key, vals := range keyVals { if len(vals) == 1 { val := vals[0] - if strings.HasPrefix(val, "@") { + + // Check if we should explicitly parse the JSON based on a known schema + if inputComponent != nil { + properties, err := inputComponent.JSONLookup("properties") + if err != nil { + return input, err + } + propertiesSchemas := properties.(openapi3.Schemas) + messages, err := propertiesSchemas.JSONLookup("messages") + // If there is an error it means messages was not found, this is valid for an OpenAPI schema. + if err == nil { + messagesSchemas := messages.(*openapi3.Schema) + found := false + for _, schemaRef := range messagesSchemas.Items.Value.AnyOf { + if _, ok := jsonSerializableSchemas[schemaRef.Ref]; ok { + found = true + message := json.RawMessage(val) + input[key] = Input{ChatMessage: &message} + break + } + } + if found { + continue + } + } + } + + switch { + case strings.HasPrefix(val, "@"): val = val[1:] input[key] = Input{File: &val} - } else { + + default: input[key] = Input{String: &val} } } else if len(vals) > 1 { @@ -39,7 +88,7 @@ func NewInputs(keyVals map[string][]string) Inputs { input[key] = Input{Array: &anyVals} } } - return input + return input, nil } func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { @@ -86,6 +135,8 @@ func (inputs *Inputs) toMap() (map[string]any, error) { } } keyVals[key] = dataURLs + case input.ChatMessage != nil: + keyVals[key] = *input.ChatMessage } } return keyVals, nil diff --git a/pkg/predict/input_test.go b/pkg/predict/input_test.go new file mode 100644 index 0000000000..0b9c536a26 --- /dev/null +++ b/pkg/predict/input_test.go @@ -0,0 +1,38 @@ +package predict + +import ( + "encoding/json" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +func TestNewInputsChatMessage(t *testing.T) { + chatMessage := `[{"role": "user", "content": "hello"}]` + key := "Key" + expected := json.RawMessage(chatMessage) + keyVals := map[string][]string{ + key: {chatMessage}, + } + openapiBody := `{"components":{"schemas":{"CommonChatSchemaAssistantMessage":{"properties":{"audio":{"$ref":"#/components/schemas/CommonChatSchemaAudio"},"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaRefuslaContentPart"}]},"type":"array"}],"title":"Content"},"function_call":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"},"tool_calls":{"items":{"$ref":"#/components/schemas/CommonChatSchemaToolCall"},"title":"Tool Calls","type":"array"}},"title":"CommonChatSchemaAssistantMessage","type":"object"},"CommonChatSchemaAudio":{"properties":{"id":{"title":"Id","type":"string"}},"required":["id"],"title":"CommonChatSchemaAudio","type":"object"},"CommonChatSchemaAudioContentPart":{"properties":{"input_audio":{"$ref":"#/components/schemas/CommonChatSchemaInputAudio"},"type":{"title":"Type","type":"string"}},"required":["type","input_audio"],"title":"CommonChatSchemaAudioContentPart","type":"object"},"CommonChatSchemaDeveloperMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaDeveloperMessage","type":"object"},"CommonChatSchemaFunction":{"properties":{"arguments":{"title":"Arguments","type":"string"},"name":{"title":"Name","type":"string"}},"required":["name","arguments"],"title":"CommonChatSchemaFunction","type":"object"},"CommonChatSchemaFunctionMessage":{"properties":{"content":{"title":"Content","type":"string"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaFunctionMessage","type":"object"},"CommonChatSchemaImageContentPart":{"properties":{"image_url":{"$ref":"#/components/schemas/CommonChatSchemaImageURL"},"type":{"title":"Type","type":"string"}},"required":["type","image_url"],"title":"CommonChatSchemaImageContentPart","type":"object"},"CommonChatSchemaImageURL":{"properties":{"detail":{"title":"Detail","type":"string"},"url":{"title":"Url","type":"string"}},"required":["url","detail"],"title":"CommonChatSchemaImageURL","type":"object"},"CommonChatSchemaInputAudio":{"properties":{"data":{"title":"Data","type":"string"},"format":{"title":"Format","type":"string"}},"required":["data","format"],"title":"CommonChatSchemaInputAudio","type":"object"},"CommonChatSchemaRefuslaContentPart":{"properties":{"refusal":{"title":"Refusal","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","refusal"],"title":"CommonChatSchemaRefuslaContentPart","type":"object"},"CommonChatSchemaSystemMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaSystemMessage","type":"object"},"CommonChatSchemaTextContentPart":{"properties":{"text":{"title":"Text","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","text"],"title":"CommonChatSchemaTextContentPart","type":"object"},"CommonChatSchemaToolCall":{"properties":{"function":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"id":{"title":"Id","type":"string"},"type":{"title":"Type","type":"string"}},"required":["id","type","function"],"title":"CommonChatSchemaToolCall","type":"object"},"CommonChatSchemaToolMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"role":{"title":"Role","type":"string"},"tool_call_id":{"title":"Tool Call Id","type":"string"}},"required":["role","content","tool_call_id"],"title":"CommonChatSchemaToolMessage","type":"object"},"CommonChatSchemaUserMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaImageContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaAudioContentPart"}]},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaUserMessage","type":"object"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"title":"Detail","type":"array"}},"title":"HTTPValidationError","type":"object"},"Input":{"properties":{"messages":{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaDeveloperMessage"},{"$ref":"#/components/schemas/CommonChatSchemaSystemMessage"},{"$ref":"#/components/schemas/CommonChatSchemaUserMessage"},{"$ref":"#/components/schemas/CommonChatSchemaAssistantMessage"},{"$ref":"#/components/schemas/CommonChatSchemaToolMessage"},{"$ref":"#/components/schemas/CommonChatSchemaFunctionMessage"}]},"title":"Messages","type":"array","x-order":0}},"required":["messages"],"title":"Input","type":"object"},"Output":{"title":"Output","type":"string"},"PredictionRequest":{"properties":{"created_at":{"format":"date-time","title":"Created At","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"output_file_prefix":{"title":"Output File Prefix","type":"string"},"webhook":{"format":"uri","maxLength":65536,"minLength":1,"title":"Webhook","type":"string"},"webhook_events_filter":{"default":["start","output","logs","completed"],"items":{"$ref":"#/components/schemas/WebhookEvent"},"type":"array"}},"title":"PredictionRequest","type":"object"},"PredictionResponse":{"properties":{"completed_at":{"format":"date-time","title":"Completed At","type":"string"},"created_at":{"format":"date-time","title":"Created At","type":"string"},"error":{"title":"Error","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"logs":{"default":"","title":"Logs","type":"string"},"metrics":{"title":"Metrics","type":"object"},"output":{"$ref":"#/components/schemas/Output"},"started_at":{"format":"date-time","title":"Started At","type":"string"},"status":{"$ref":"#/components/schemas/Status"},"version":{"title":"Version","type":"string"}},"title":"PredictionResponse","type":"object"},"Status":{"description":"An enumeration.","enum":["starting","processing","succeeded","canceled","failed"],"title":"Status","type":"string"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"title":"Location","type":"array"},"msg":{"title":"Message","type":"string"},"type":{"title":"Error Type","type":"string"}},"required":["loc","msg","type"],"title":"ValidationError","type":"object"},"WebhookEvent":{"description":"An enumeration.","enum":["start","output","logs","completed"],"title":"WebhookEvent","type":"string"}}},"info":{"title":"Cog","version":"0.1.0"},"openapi":"3.0.2","paths":{"/":{"get":{"operationId":"root__get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Root Get"}}},"description":"Successful Response"}},"summary":"Root"}},"/health-check":{"get":{"operationId":"healthcheck_health_check_get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Healthcheck Health Check Get"}}},"description":"Successful Response"}},"summary":"Healthcheck"}},"/predictions":{"post":{"description":"Run a single prediction on the model","operationId":"predict_predictions_post","parameters":[{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionRequest"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict"}},"/predictions/{prediction_id}":{"put":{"description":"Run a single prediction on the model (idempotent creation).","operationId":"predict_idempotent_predictions__prediction_id__put","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}},{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/PredictionRequest"}],"title":"Prediction Request"}}},"required":true},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict Idempotent"}},"/predictions/{prediction_id}/cancel":{"post":{"description":"Cancel a running prediction","operationId":"cancel_predictions__prediction_id__cancel_post","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}}],"responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Cancel Predictions Prediction Id Cancel Post"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Cancel"}},"/shutdown":{"post":{"operationId":"start_shutdown_shutdown_post","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Start Shutdown Shutdown Post"}}},"description":"Successful Response"}},"summary":"Start Shutdown"}}}}` + schema, err := openapi3.NewLoader().LoadFromData([]byte(openapiBody)) + require.NoError(t, err) + inputs, err := NewInputs(keyVals, schema) + require.NoError(t, err) + require.Equal(t, expected, *inputs[key].ChatMessage) +} + +func TestNewInputsWithoutMessages(t *testing.T) { + expected := "world" + key := "s" + keyVals := map[string][]string{ + key: {expected}, + } + openapiBody := `{"components":{"schemas":{"CommonChatSchemaAssistantMessage":{"properties":{"audio":{"$ref":"#/components/schemas/CommonChatSchemaAudio"},"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaRefuslaContentPart"}]},"type":"array"}],"title":"Content"},"function_call":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"},"tool_calls":{"items":{"$ref":"#/components/schemas/CommonChatSchemaToolCall"},"title":"Tool Calls","type":"array"}},"title":"CommonChatSchemaAssistantMessage","type":"object"},"CommonChatSchemaAudio":{"properties":{"id":{"title":"Id","type":"string"}},"required":["id"],"title":"CommonChatSchemaAudio","type":"object"},"CommonChatSchemaAudioContentPart":{"properties":{"input_audio":{"$ref":"#/components/schemas/CommonChatSchemaInputAudio"},"type":{"title":"Type","type":"string"}},"required":["type","input_audio"],"title":"CommonChatSchemaAudioContentPart","type":"object"},"CommonChatSchemaDeveloperMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaDeveloperMessage","type":"object"},"CommonChatSchemaFunction":{"properties":{"arguments":{"title":"Arguments","type":"string"},"name":{"title":"Name","type":"string"}},"required":["name","arguments"],"title":"CommonChatSchemaFunction","type":"object"},"CommonChatSchemaFunctionMessage":{"properties":{"content":{"title":"Content","type":"string"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaFunctionMessage","type":"object"},"CommonChatSchemaImageContentPart":{"properties":{"image_url":{"$ref":"#/components/schemas/CommonChatSchemaImageURL"},"type":{"title":"Type","type":"string"}},"required":["type","image_url"],"title":"CommonChatSchemaImageContentPart","type":"object"},"CommonChatSchemaImageURL":{"properties":{"detail":{"title":"Detail","type":"string"},"url":{"title":"Url","type":"string"}},"required":["url","detail"],"title":"CommonChatSchemaImageURL","type":"object"},"CommonChatSchemaInputAudio":{"properties":{"data":{"title":"Data","type":"string"},"format":{"title":"Format","type":"string"}},"required":["data","format"],"title":"CommonChatSchemaInputAudio","type":"object"},"CommonChatSchemaRefuslaContentPart":{"properties":{"refusal":{"title":"Refusal","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","refusal"],"title":"CommonChatSchemaRefuslaContentPart","type":"object"},"CommonChatSchemaSystemMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaSystemMessage","type":"object"},"CommonChatSchemaTextContentPart":{"properties":{"text":{"title":"Text","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","text"],"title":"CommonChatSchemaTextContentPart","type":"object"},"CommonChatSchemaToolCall":{"properties":{"function":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"id":{"title":"Id","type":"string"},"type":{"title":"Type","type":"string"}},"required":["id","type","function"],"title":"CommonChatSchemaToolCall","type":"object"},"CommonChatSchemaToolMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"role":{"title":"Role","type":"string"},"tool_call_id":{"title":"Tool Call Id","type":"string"}},"required":["role","content","tool_call_id"],"title":"CommonChatSchemaToolMessage","type":"object"},"CommonChatSchemaUserMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaImageContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaAudioContentPart"}]},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaUserMessage","type":"object"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"title":"Detail","type":"array"}},"title":"HTTPValidationError","type":"object"},"Input":{"properties":{},"required":[],"title":"Input","type":"object"},"Output":{"title":"Output","type":"string"},"PredictionRequest":{"properties":{"created_at":{"format":"date-time","title":"Created At","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"output_file_prefix":{"title":"Output File Prefix","type":"string"},"webhook":{"format":"uri","maxLength":65536,"minLength":1,"title":"Webhook","type":"string"},"webhook_events_filter":{"default":["start","output","logs","completed"],"items":{"$ref":"#/components/schemas/WebhookEvent"},"type":"array"}},"title":"PredictionRequest","type":"object"},"PredictionResponse":{"properties":{"completed_at":{"format":"date-time","title":"Completed At","type":"string"},"created_at":{"format":"date-time","title":"Created At","type":"string"},"error":{"title":"Error","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"logs":{"default":"","title":"Logs","type":"string"},"metrics":{"title":"Metrics","type":"object"},"output":{"$ref":"#/components/schemas/Output"},"started_at":{"format":"date-time","title":"Started At","type":"string"},"status":{"$ref":"#/components/schemas/Status"},"version":{"title":"Version","type":"string"}},"title":"PredictionResponse","type":"object"},"Status":{"description":"An enumeration.","enum":["starting","processing","succeeded","canceled","failed"],"title":"Status","type":"string"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"title":"Location","type":"array"},"msg":{"title":"Message","type":"string"},"type":{"title":"Error Type","type":"string"}},"required":["loc","msg","type"],"title":"ValidationError","type":"object"},"WebhookEvent":{"description":"An enumeration.","enum":["start","output","logs","completed"],"title":"WebhookEvent","type":"string"}}},"info":{"title":"Cog","version":"0.1.0"},"openapi":"3.0.2","paths":{"/":{"get":{"operationId":"root__get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Root Get"}}},"description":"Successful Response"}},"summary":"Root"}},"/health-check":{"get":{"operationId":"healthcheck_health_check_get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Healthcheck Health Check Get"}}},"description":"Successful Response"}},"summary":"Healthcheck"}},"/predictions":{"post":{"description":"Run a single prediction on the model","operationId":"predict_predictions_post","parameters":[{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionRequest"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict"}},"/predictions/{prediction_id}":{"put":{"description":"Run a single prediction on the model (idempotent creation).","operationId":"predict_idempotent_predictions__prediction_id__put","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}},{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/PredictionRequest"}],"title":"Prediction Request"}}},"required":true},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict Idempotent"}},"/predictions/{prediction_id}/cancel":{"post":{"description":"Cancel a running prediction","operationId":"cancel_predictions__prediction_id__cancel_post","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}}],"responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Cancel Predictions Prediction Id Cancel Post"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Cancel"}},"/shutdown":{"post":{"operationId":"start_shutdown_shutdown_post","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Start Shutdown Shutdown Post"}}},"description":"Successful Response"}},"summary":"Start Shutdown"}}}}` + schema, err := openapi3.NewLoader().LoadFromData([]byte(openapiBody)) + require.NoError(t, err) + inputs, err := NewInputs(keyVals, schema) + require.NoError(t, err) + require.Equal(t, expected, *inputs[key].String) +} diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 72f1399cd0..7e8403d522 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -7,6 +7,7 @@ from .server.scope import current_scope, emit_metric from .types import ( AsyncConcatenateIterator, + ChatMessage, ConcatenateIterator, ExperimentalFeatureWarning, File, @@ -30,6 +31,7 @@ "AsyncConcatenateIterator", "BaseModel", "BasePredictor", + "ChatMessage", "ConcatenateIterator", "ExperimentalFeatureWarning", "File", diff --git a/python/cog/predictor.py b/python/cog/predictor.py index a9b32f7553..3c60f4604c 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -35,6 +35,7 @@ from .code_xforms import load_module_from_string, strip_model_source_code from .types import ( PYDANTIC_V2, + ChatMessage, Input, Weights, ) @@ -48,11 +49,12 @@ log = structlog.get_logger("cog.server.predictor") -ALLOWED_INPUT_TYPES: List[Type[Any]] = [ +ALLOWED_INPUT_TYPES: List[Union[Type[Any], Type[ChatMessage]]] = [ str, int, float, bool, + ChatMessage, CogFile, CogPath, CogSecret, diff --git a/python/cog/types.py b/python/cog/types.py index 8b0de604cd..15bc98c95b 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -64,6 +64,114 @@ class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many- max: NotRequired[int] +class CommonChatSchemaTextContentPart(TypedDict): + type: str + text: str + + +class CommonChatSchemaImageURL(TypedDict): + url: str + detail: Optional[str] + + +class CommonChatSchemaImageContentPart(TypedDict): + type: str + image_url: CommonChatSchemaImageURL + + +class CommonChatSchemaInputAudio(TypedDict): + data: str + format: str + + +class CommonChatSchemaAudioContentPart(TypedDict): + type: str + input_audio: CommonChatSchemaInputAudio + + +class CommonChatSchemaRefuslaContentPart(TypedDict): + type: str + refusal: str + + +class CommonChatSchemaAudio(TypedDict): + id: str + + +class CommonChatSchemaFunction(TypedDict): + name: str + arguments: str + + +class CommonChatSchemaToolCall(TypedDict): + id: str + type: str + function: CommonChatSchemaFunction + + +class CommonChatSchemaDeveloperMessage(TypedDict, total=False): + content: Union[str, List[str]] + role: str + name: Optional[str] + + +class CommonChatSchemaSystemMessage(TypedDict, total=False): + content: Union[str, List[str]] + role: str + name: Optional[str] + + +class CommonChatSchemaUserMessage(TypedDict, total=False): + content: Union[ + str, + List[ + Union[ + CommonChatSchemaTextContentPart, + CommonChatSchemaImageContentPart, + CommonChatSchemaAudioContentPart, + ] + ], + ] + role: str + name: Optional[str] + + +class CommonChatSchemaAssistantMessage(TypedDict, total=False): + content: Union[ + str, + List[ + Union[CommonChatSchemaTextContentPart, CommonChatSchemaRefuslaContentPart] + ], + ] + role: str + name: Optional[str] + audio: Optional[CommonChatSchemaAudio] + tool_calls: Optional[List[CommonChatSchemaToolCall]] + function_call: Optional[CommonChatSchemaFunction] + + +class CommonChatSchemaToolMessage(TypedDict): + role: str + content: Union[str, List[str]] + tool_call_id: str + + +class CommonChatSchemaFunctionMessage(TypedDict, total=False): + role: str + content: Optional[str] + name: str + + +ChatMessage = Union[ + CommonChatSchemaDeveloperMessage, + CommonChatSchemaSystemMessage, + CommonChatSchemaUserMessage, + CommonChatSchemaAssistantMessage, + CommonChatSchemaToolMessage, + CommonChatSchemaFunctionMessage, +] + + def Input( # pylint: disable=invalid-name, too-many-arguments default: Any = ..., description: Optional[str] = None, diff --git a/test-integration/test_integration/fixtures/chat-message-project/cog.yaml b/test-integration/test_integration/fixtures/chat-message-project/cog.yaml new file mode 100644 index 0000000000..e357cab833 --- /dev/null +++ b/test-integration/test_integration/fixtures/chat-message-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.9" +predict: predict.py:Predictor diff --git a/test-integration/test_integration/fixtures/chat-message-project/predict.py b/test-integration/test_integration/fixtures/chat-message-project/predict.py new file mode 100644 index 0000000000..371edbdda6 --- /dev/null +++ b/test-integration/test_integration/fixtures/chat-message-project/predict.py @@ -0,0 +1,8 @@ +from cog import BasePredictor, ChatMessage + + +class Predictor(BasePredictor): + + def predict(self, messages: list[ChatMessage]) -> str: + print(messages) + return f"HELLO {messages[0]['role']}" diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 459f09f03e..26e90751dc 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -7,6 +7,7 @@ import httpx import pytest +import requests from .util import cog_server_http_run @@ -369,3 +370,17 @@ async def make_request(i: int) -> httpx.Response: for i, task in enumerate(tasks): assert task.result().status_code == 200 assert task.result().json()["output"] == f"wake up sleepyhead{i}" + + +def test_predict_chat_message(): + with cog_server_http_run( + Path(__file__).parent / "fixtures" / "chat-message-project" + ) as addr: + response = requests.post( + addr + "/predictions", + json={"input": {"messages": [{"role": "User", "content": "Hello There!"}]}}, + timeout=3.0, + ) + response.raise_for_status() + body = response.json() + assert body["output"] == "HELLO User" diff --git a/tox.ini b/tox.ini index 848b138b83..f32787ed6a 100644 --- a/tox.ini +++ b/tox.ini @@ -71,4 +71,5 @@ deps = pytest-rerunfailures pytest-timeout pytest-xdist + requests commands = pytest {posargs:-n auto -vv --reruns 3}