Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChatMessage as a first class cog type #2186

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
return err
}

inputs, err := parseInputFlags(inputFlags)
inputs, err := parseInputFlags(inputFlags, schema)
if err != nil {
return err
}
Expand Down Expand Up @@ -361,7 +361,7 @@ func writeDataURLOutput(outputString string, outputPath string, addExtension boo
return nil
}

func parseInputFlags(inputs []string) (predict.Inputs, error) {
func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error) {
keyVals := map[string][]string{}
for _, input := range inputs {
var name, value string
Expand All @@ -383,7 +383,7 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) {
keyVals[name] = append(keyVals[name], value)
}

return predict.NewInputs(keyVals), nil
return predict.NewInputs(keyVals, schema)
}

func addSetupTimeoutFlag(cmd *cobra.Command) {
Expand Down
65 changes: 58 additions & 7 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,83 @@
package predict

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/getkin/kin-openapi/openapi3"
"github.com/mitchellh/go-homedir"
"github.com/vincent-petithory/dataurl"

"github.com/replicate/cog/pkg/util/mime"
)

type Input struct {
String *string
File *string
Array *[]any
String *string
File *string
Array *[]any
ChatMessage *json.RawMessage
}

type Inputs map[string]Input

func NewInputs(keyVals map[string][]string) Inputs {
// jsonSerializableSchemas is the set of OpenAPI component references that
// NewInputs recognises as chat-message schemas. When the items of the
// "messages" input property reference one of these via anyOf, the raw input
// value is kept as JSON (Input.ChatMessage) rather than being treated as a
// plain string or an "@file" reference.
// The map is used as a set: only key presence is checked, the bool value is
// not read.
var jsonSerializableSchemas = map[string]bool{
"#/components/schemas/CommonChatSchemaDeveloperMessage": true,
"#/components/schemas/CommonChatSchemaSystemMessage": true,
"#/components/schemas/CommonChatSchemaUserMessage": true,
"#/components/schemas/CommonChatSchemaAssistantMessage": true,
"#/components/schemas/CommonChatSchemaToolMessage": true,
"#/components/schemas/CommonChatSchemaFunctionMessage": true,
}

func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error) {
var inputComponent *openapi3.SchemaRef
for name, component := range schema.Components.Schemas {
if name == "Input" {
inputComponent = component
break
}
}

input := Inputs{}
for key, vals := range keyVals {
if len(vals) == 1 {
val := vals[0]
if strings.HasPrefix(val, "@") {

// Check if we should explicitly parse the JSON based on a known schema
if inputComponent != nil {
properties, err := inputComponent.JSONLookup("properties")
if err != nil {
return input, err
}
propertiesSchemas := properties.(openapi3.Schemas)
messages, err := propertiesSchemas.JSONLookup("messages")
// If there is an error it means messages was not found, this is valid for an OpenAPI schema.
if err == nil {
messagesSchemas := messages.(*openapi3.Schema)
found := false
for _, schemaRef := range messagesSchemas.Items.Value.AnyOf {
if _, ok := jsonSerializableSchemas[schemaRef.Ref]; ok {
found = true
message := json.RawMessage(val)
input[key] = Input{ChatMessage: &message}
break
}
}
if found {
continue
}
}
}

switch {
case strings.HasPrefix(val, "@"):
val = val[1:]
input[key] = Input{File: &val}
} else {

default:
input[key] = Input{String: &val}
}
} else if len(vals) > 1 {
Expand All @@ -39,7 +88,7 @@ func NewInputs(keyVals map[string][]string) Inputs {
input[key] = Input{Array: &anyVals}
}
}
return input
return input, nil
}

func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
Expand Down Expand Up @@ -86,6 +135,8 @@ func (inputs *Inputs) toMap() (map[string]any, error) {
}
}
keyVals[key] = dataURLs
case input.ChatMessage != nil:
keyVals[key] = *input.ChatMessage
}
}
return keyVals, nil
Expand Down
38 changes: 38 additions & 0 deletions pkg/predict/input_test.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions python/cog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .server.scope import current_scope, emit_metric
from .types import (
AsyncConcatenateIterator,
ChatMessage,
ConcatenateIterator,
ExperimentalFeatureWarning,
File,
Expand All @@ -30,6 +31,7 @@
"AsyncConcatenateIterator",
"BaseModel",
"BasePredictor",
"ChatMessage",
"ConcatenateIterator",
"ExperimentalFeatureWarning",
"File",
Expand Down
4 changes: 3 additions & 1 deletion python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .code_xforms import load_module_from_string, strip_model_source_code
from .types import (
PYDANTIC_V2,
ChatMessage,
Input,
Weights,
)
Expand All @@ -48,11 +49,12 @@

log = structlog.get_logger("cog.server.predictor")

ALLOWED_INPUT_TYPES: List[Type[Any]] = [
ALLOWED_INPUT_TYPES: List[Union[Type[Any], Type[ChatMessage]]] = [
str,
int,
float,
bool,
ChatMessage,
CogFile,
CogPath,
CogSecret,
Expand Down
108 changes: 108 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,114 @@ class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many-
max: NotRequired[int]


# Typed dictionaries describing the chat message shapes that make up the
# `ChatMessage` input type.
# NOTE(review): the field layout appears to follow the OpenAI chat-completions
# message format -- confirm against that specification.


# Content part with a `type` string and a `text` string.
class CommonChatSchemaTextContentPart(TypedDict):
    type: str
    text: str


# Image reference used by image content parts.
# `detail` may be None, but because this TypedDict is total the key itself is
# still required. NOTE(review): confirm that requiring the key is intended.
class CommonChatSchemaImageURL(TypedDict):
    url: str
    detail: Optional[str]


# Content part with a `type` string and a nested `image_url` dictionary.
class CommonChatSchemaImageContentPart(TypedDict):
    type: str
    image_url: CommonChatSchemaImageURL


# Audio payload used by audio content parts: `data` and `format` strings.
class CommonChatSchemaInputAudio(TypedDict):
    data: str
    format: str


# Content part with a `type` string and a nested `input_audio` dictionary.
class CommonChatSchemaAudioContentPart(TypedDict):
    type: str
    input_audio: CommonChatSchemaInputAudio


# Content part with a `type` string and a `refusal` string.
class CommonChatSchemaRefusalContentPart(TypedDict):
    type: str
    refusal: str


# Backward-compatible alias for the original, misspelled class name
# ("Refusla"). Prefer `CommonChatSchemaRefusalContentPart` in new code.
CommonChatSchemaRefuslaContentPart = CommonChatSchemaRefusalContentPart


# Audio reference carried by assistant messages; holds only an `id`.
class CommonChatSchemaAudio(TypedDict):
    id: str


# Function invocation: a `name` and its `arguments`, both strings.
class CommonChatSchemaFunction(TypedDict):
    name: str
    arguments: str


# Tool call entry: `id`, `type` and the `function` being invoked.
class CommonChatSchemaToolCall(TypedDict):
    id: str
    type: str
    function: CommonChatSchemaFunction


# Developer message. `total=False` makes every key optional.
class CommonChatSchemaDeveloperMessage(TypedDict, total=False):
    content: Union[str, List[str]]
    role: str
    name: Optional[str]


# System message. `total=False` makes every key optional.
class CommonChatSchemaSystemMessage(TypedDict, total=False):
    content: Union[str, List[str]]
    role: str
    name: Optional[str]


# User message. `content` is either a plain string or a list of text, image
# and audio content parts. `total=False` makes every key optional.
class CommonChatSchemaUserMessage(TypedDict, total=False):
    content: Union[
        str,
        List[
            Union[
                CommonChatSchemaTextContentPart,
                CommonChatSchemaImageContentPart,
                CommonChatSchemaAudioContentPart,
            ]
        ],
    ]
    role: str
    name: Optional[str]


# Assistant message. `content` is either a plain string or a list of text and
# refusal content parts. `total=False` makes every key optional.
class CommonChatSchemaAssistantMessage(TypedDict, total=False):
    content: Union[
        str,
        List[
            Union[CommonChatSchemaTextContentPart, CommonChatSchemaRefusalContentPart]
        ],
    ]
    role: str
    name: Optional[str]
    audio: Optional[CommonChatSchemaAudio]
    tool_calls: Optional[List[CommonChatSchemaToolCall]]
    function_call: Optional[CommonChatSchemaFunction]


# Tool message. This TypedDict is total, so all three keys are required.
class CommonChatSchemaToolMessage(TypedDict):
    role: str
    content: Union[str, List[str]]
    tool_call_id: str


# Function message. `total=False` makes every key optional.
class CommonChatSchemaFunctionMessage(TypedDict, total=False):
    role: str
    content: Optional[str]
    name: str


# Union of all message dictionaries defined above.
ChatMessage = Union[
    CommonChatSchemaDeveloperMessage,
    CommonChatSchemaSystemMessage,
    CommonChatSchemaUserMessage,
    CommonChatSchemaAssistantMessage,
    CommonChatSchemaToolMessage,
    CommonChatSchemaFunctionMessage,
]


def Input( # pylint: disable=invalid-name, too-many-arguments
default: Any = ...,
description: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.9"
predict: predict.py:Predictor
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cog import BasePredictor, ChatMessage


class Predictor(BasePredictor):
    # Integration-test fixture: replies with the role of the first message.

    def predict(self, messages: list[ChatMessage]) -> str:
        # Echo the received messages to stdout before building the reply.
        print(messages)
        first_message = messages[0]
        return "HELLO {}".format(first_message["role"])
15 changes: 15 additions & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import httpx
import pytest
import requests

from .util import cog_server_http_run

Expand Down Expand Up @@ -369,3 +370,17 @@ async def make_request(i: int) -> httpx.Response:
for i, task in enumerate(tasks):
assert task.result().status_code == 200
assert task.result().json()["output"] == f"wake up sleepyhead{i}"


def test_predict_chat_message():
    # Posts a single chat message to the chat-message fixture project and
    # checks that the predictor answers with the role of that message.
    fixture_dir = Path(__file__).parent / "fixtures" / "chat-message-project"
    payload = {"input": {"messages": [{"role": "User", "content": "Hello There!"}]}}
    with cog_server_http_run(fixture_dir) as addr:
        response = requests.post(addr + "/predictions", json=payload, timeout=3.0)
        # Fail early on any non-2xx status before inspecting the body.
        response.raise_for_status()
        assert response.json()["output"] == "HELLO User"
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ deps =
pytest-rerunfailures
pytest-timeout
pytest-xdist
requests
commands = pytest {posargs:-n auto -vv --reruns 3}
Loading