From 21764a711affabffc3a8631f10ac53c3df9344d2 Mon Sep 17 00:00:00 2001
From: Gerardo Lecaros <10088504+glecaros@users.noreply.github.com>
Date: Thu, 3 Oct 2024 22:14:20 -0700
Subject: [PATCH] Fixing serialization of all models that have defaults. (#20)

---
Review notes: line structure restored to the standard patch layout. In
python/samples/sample_text_input.py, log() now prints instead of calling
itself and accepts several values, the argument-count check and the provider
lookup match the three required positional arguments, the input files are
read as UTF-8, and docstrings were added. The blob id on that file's index
line predates these edits.

 python/pyproject.toml                      |   2 +-
 python/rtclient/models.py                  |  26 ++--
 python/rtclient/util/model_helpers.py      |  10 +-
 python/rtclient/util/model_helpers_test.py |  26 ++--
 python/samples/download-wheel.ps1          |   2 +-
 python/samples/download-wheel.sh           |   2 +-
 python/samples/sample_text_input.py        | 185 +++++++++++++++++++++
 7 files changed, 223 insertions(+), 30 deletions(-)
 create mode 100644 python/samples/sample_text_input.py

diff --git a/python/pyproject.toml b/python/pyproject.toml
index d8163c7..67672a6 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rtclient"
-version = "0.4.3"
+version = "0.4.4"
 description = "A client for the RT API"
 authors = ["Microsoft Corporation"]
 
diff --git a/python/rtclient/models.py b/python/rtclient/models.py
index 3a5f19d..9e35ff0 100644
--- a/python/rtclient/models.py
+++ b/python/rtclient/models.py
@@ -12,18 +12,18 @@
     model_serializer,
 )
 
-from rtclient.util.model_helpers import ModelWithType
+from rtclient.util.model_helpers import ModelWithDefaults
 
 Voice = Literal["alloy", "shimmer", "echo"]
 AudioFormat = Literal["pcm16", "g711-ulaw", "g711-alaw"]
 Modality = Literal["text", "audio"]
 
 
-class NoTurnDetection(ModelWithType):
+class NoTurnDetection(ModelWithDefaults):
     type: Literal["none"] = "none"
 
 
-class ServerVAD(ModelWithType):
+class ServerVAD(ModelWithDefaults):
     type: Literal["server_vad"] = "server_vad"
     threshold: Optional[Annotated[float, Field(strict=True, ge=0.0, le=1.0)]] = None
     prefix_padding_ms: Optional[int] = None
@@ -33,7 +33,7 @@ class ServerVAD(ModelWithType):
 TurnDetection = Annotated[Union[NoTurnDetection, ServerVAD], Field(discriminator="type")]
 
 
-class FunctionToolChoice(ModelWithType):
+class FunctionToolChoice(ModelWithDefaults):
     type: Literal["function"] = "function"
     function: str
 
@@ -47,7 +47,7 @@ class InputAudioTranscription(BaseModel):
     model: Literal["whisper-1"]
 
 
-class ClientMessageBase(ModelWithType):
+class ClientMessageBase(ModelWithDefaults):
     _is_azure: bool = False
     event_id: Optional[str] = None
 
@@ -125,18 +125,18 @@ class InputAudioBufferClearMessage(ClientMessageBase):
 MessageItemType = Literal["message"]
 
 
-class InputTextContentPart(ModelWithType):
+class InputTextContentPart(ModelWithDefaults):
     type: Literal["input_text"] = "input_text"
     text: str
 
 
-class InputAudioContentPart(ModelWithType):
+class InputAudioContentPart(ModelWithDefaults):
     type: Literal["input_audio"] = "input_audio"
     audio: str
     transcript: Optional[str] = None
 
 
-class OutputTextContentPart(ModelWithType):
+class OutputTextContentPart(ModelWithDefaults):
     type: Literal["text"] = "text"
     text: str
 
@@ -148,7 +148,7 @@ class OutputTextContentPart(ModelWithType):
 ItemParamStatus = Literal["completed", "incomplete"]
 
 
-class SystemMessageItem(BaseModel):
+class SystemMessageItem(ModelWithDefaults):
     type: MessageItemType = "message"
     role: Literal["system"] = "system"
     id: Optional[str] = None
@@ -156,7 +156,7 @@ class SystemMessageItem(BaseModel):
     status: Optional[ItemParamStatus] = None
 
 
-class UserMessageItem(BaseModel):
+class UserMessageItem(ModelWithDefaults):
     type: MessageItemType = "message"
     role: Literal["user"] = "user"
     id: Optional[str] = None
@@ -164,7 +164,7 @@ class UserMessageItem(BaseModel):
     status: Optional[ItemParamStatus] = None
 
 
-class AssistantMessageItem(BaseModel):
+class AssistantMessageItem(ModelWithDefaults):
     type: MessageItemType = "message"
     role: Literal["assistant"] = "assistant"
     id: Optional[str] = None
@@ -175,7 +175,7 @@ class AssistantMessageItem(BaseModel):
 MessageItem = Annotated[Union[SystemMessageItem, UserMessageItem, AssistantMessageItem], Field(discriminator="role")]
 
 
-class FunctionCallItem(ModelWithType):
+class FunctionCallItem(ModelWithDefaults):
     type: Literal["function_call"] = "function_call"
     id: Optional[str] = None
     name: str
@@ -184,7 +184,7 @@ class FunctionCallItem(ModelWithType):
     status: Optional[ItemParamStatus] = None
 
 
-class FunctionCallOutputItem(ModelWithType):
+class FunctionCallOutputItem(ModelWithDefaults):
     type: Literal["function_call_output"] = "function_call_output"
     id: Optional[str] = None
     call_id: str
diff --git a/python/rtclient/util/model_helpers.py b/python/rtclient/util/model_helpers.py
index 585b4ef..679b66f 100644
--- a/python/rtclient/util/model_helpers.py
+++ b/python/rtclient/util/model_helpers.py
@@ -4,9 +4,11 @@
 from pydantic import BaseModel, model_validator
 
 
-class ModelWithType(BaseModel):
+class ModelWithDefaults(BaseModel):
     @model_validator(mode="after")
-    def _add_type(self):
-        if "type" in self.model_fields:
-            self.type = self.model_fields["type"].default
+    def _add_defaults(self):
+        for field in self.model_fields:
+            if self.model_fields[field].default is not None:
+                if not hasattr(self, field) or getattr(self, field) == self.model_fields[field].default:
+                    setattr(self, field, self.model_fields[field].default)
         return self
diff --git a/python/rtclient/util/model_helpers_test.py b/python/rtclient/util/model_helpers_test.py
index 18fb7d5..d09f410 100644
--- a/python/rtclient/util/model_helpers_test.py
+++ b/python/rtclient/util/model_helpers_test.py
@@ -1,17 +1,23 @@
-from typing import Literal
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 
-from model_helpers import ModelWithType
+from typing import Optional
 
+from model_helpers import ModelWithDefaults
 
-class ModelWithType(ModelWithType):
-    type: Literal["object_type"] = "object_type"
 
+class Bar(ModelWithDefaults):
+    foo: Optional[int] = None
+    bar: Optional[float] = 3.14
+    baz: int = 42
 
-def test_with_type_field():
-    instance = ModelWithType()
-    assert instance.type == "object_type"
 
+def test_with_defaults():
+    instance = Bar()
+    assert instance.foo is None
+    assert instance.baz == 42
 
-def test_serialize_with_type_field():
-    instance = ModelWithType()
-    assert instance.model_dump() == {"type": "object_type"}
+
+def test_serialize_with_defaults():
+    instance = Bar()
+    assert instance.model_dump(exclude_unset=True) == {"bar": 3.14, "baz": 42}
diff --git a/python/samples/download-wheel.ps1 b/python/samples/download-wheel.ps1
index 5f07d64..eca06c7 100755
--- a/python/samples/download-wheel.ps1
+++ b/python/samples/download-wheel.ps1
@@ -1,7 +1,7 @@
 $Owner = "Azure-Samples"
 $Repo = "aoai-realtime-audio-sdk"
 $Filter = "-py3-none-any.whl"
-$Release = "py/v0.4.3"
+$Release = "py/v0.4.4"
 $OutputDir = "."
 
 $apiUrl = "https://api.github.com/repos/$Owner/$Repo/releases/tags/$Release"
diff --git a/python/samples/download-wheel.sh b/python/samples/download-wheel.sh
index ed28617..1953e3e 100755
--- a/python/samples/download-wheel.sh
+++ b/python/samples/download-wheel.sh
@@ -3,7 +3,7 @@
 OWNER="Azure-Samples"
 REPO="aoai-realtime-audio-sdk"
 FILTER="-py3-none-any.whl"
-RELEASE="py/v0.4.3"
+RELEASE="py/v0.4.4"
 OUTPUT_DIR="."
 
 API_URL="https://api.github.com/repos/$OWNER/$REPO/releases/tags/$RELEASE"
diff --git a/python/samples/sample_text_input.py b/python/samples/sample_text_input.py
new file mode 100644
index 0000000..5c5ac82
--- /dev/null
+++ b/python/samples/sample_text_input.py
@@ -0,0 +1,185 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import asyncio
+import base64
+import os
+import sys
+import time
+
+import numpy as np
+import soundfile as sf
+from azure.core.credentials import AzureKeyCredential
+from dotenv import load_dotenv
+
+from rtclient import InputTextContentPart, RTClient, RTInputItem, RTOutputItem, RTResponse, UserMessageItem
+
+start_time = time.time()
+
+
+def log(*messages):
+    """Print the given values, space-separated, prefixed with the elapsed time in milliseconds."""
+    elapsed_time_ms = int((time.time() - start_time) * 1000)
+    text = " ".join(str(message) for message in messages)
+    print(f"{elapsed_time_ms} [ms]: {text}", flush=True)
+
+
+async def receive_control(client: RTClient):
+    """Log the type of each control message until the stream ends or yields None."""
+    async for control in client.control_messages():
+        if control is not None:
+            log(f"Received a control message: {control.type}")
+        else:
+            break
+
+
+async def receive_item(item: RTOutputItem, out_dir: str):
+    """Accumulate the chunks of an output item and write each kind of content to a file in out_dir."""
+    prefix = f"[response={item.response_id}][item={item.id}]"
+    audio_data = None
+    audio_transcript = None
+    text_data = None
+    arguments = None
+    async for chunk in item:
+        if chunk.type == "audio_transcript":
+            audio_transcript = (audio_transcript or "") + chunk.data
+        elif chunk.type == "audio":
+            if audio_data is None:
+                audio_data = bytearray()
+            audio_bytes = base64.b64decode(chunk.data)
+            audio_data.extend(audio_bytes)
+        elif chunk.type == "tool_call_arguments":
+            arguments = (arguments or "") + chunk.data
+        elif chunk.type == "text":
+            text_data = (text_data or "") + chunk.data
+    if text_data is not None:
+        log(prefix, f"Text: {text_data}")
+        with open(os.path.join(out_dir, f"{item.id}.text.txt"), "w", encoding="utf-8") as out:
+            out.write(text_data)
+    if audio_data is not None:
+        log(prefix, f"Audio received with length: {len(audio_data)}")
+        with open(os.path.join(out_dir, f"{item.id}.wav"), "wb") as out:
+            audio_array = np.frombuffer(audio_data, dtype=np.int16)
+            sf.write(out, audio_array, samplerate=24000)
+    if audio_transcript is not None:
+        log(prefix, f"Audio Transcript: {audio_transcript}")
+        with open(os.path.join(out_dir, f"{item.id}.audio_transcript.txt"), "w", encoding="utf-8") as out:
+            out.write(audio_transcript)
+    if arguments is not None:
+        log(prefix, f"Tool Call Arguments: {arguments}")
+        with open(os.path.join(out_dir, f"{item.id}.tool.streamed.json"), "w", encoding="utf-8") as out:
+            out.write(arguments)
+
+
+async def receive_response(client: RTClient, response: RTResponse, out_dir: str):
+    """Start a receive_item task for every item of the response, then close the client."""
+    prefix = f"[response={response.id}]"
+    async for item in response:
+        log(prefix, f"Received item {item.id}")
+        asyncio.create_task(receive_item(item, out_dir))
+    log(prefix, "Response completed")
+    await client.close()
+
+
+async def receive_input_item(item: RTInputItem):
+    """Await an input item and log its previous id, transcript and audio bounds."""
+    prefix = f"[input_item={item.id}]"
+    await item
+    log(prefix, f"Previous Id: {item.previous_id}")
+    log(prefix, f"Transcript: {item.transcript}")
+    log(prefix, f"Audio Start [ms]: {item.audio_start_ms}")
+    log(prefix, f"Audio End [ms]: {item.audio_end_ms}")
+
+
+async def receive_items(client: RTClient, out_dir: str):
+    """Start a task for each item from the client: responses and input items get separate handlers."""
+    async for item in client.items():
+        if isinstance(item, RTResponse):
+            asyncio.create_task(receive_response(client, item, out_dir))
+        else:
+            asyncio.create_task(receive_input_item(item))
+
+
+async def receive_messages(client: RTClient, out_dir: str):
+    """Consume items and control messages concurrently."""
+    await asyncio.gather(
+        receive_items(client, out_dir),
+        receive_control(client),
+    )
+
+
+async def run(client: RTClient, instructions_file_path: str, user_message_file_path: str, out_dir: str):
+    """Configure the session from the instructions file, send the user message and handle the response."""
+    with open(instructions_file_path, encoding="utf-8") as instructions_file:
+        instructions = instructions_file.read()
+    with open(user_message_file_path, encoding="utf-8") as user_message_file:
+        user_message = user_message_file.read()
+    log("Configuring Session...")
+    await client.configure(
+        instructions=instructions,
+    )
+    log("Done")
+    log("Sending User Message...")
+    await client.send_item(UserMessageItem(content=[InputTextContentPart(text=user_message)]))
+    log("Done")
+    await client.generate_response()
+    await receive_messages(client, out_dir)
+
+
+def get_env_var(var_name: str) -> str:
+    """Return the value of the environment variable, raising OSError when it is unset or empty."""
+    value = os.environ.get(var_name)
+    if not value:
+        raise OSError(f"Environment variable '{var_name}' is not set or is empty.")
+    return value
+
+
+async def with_azure_openai(instructions_file_path: str, user_message_file_path: str, out_dir: str):
+    """Run the sample with a client built from the AZURE_OPENAI_* environment variables."""
+    endpoint = get_env_var("AZURE_OPENAI_ENDPOINT")
+    key = get_env_var("AZURE_OPENAI_API_KEY")
+    deployment = get_env_var("AZURE_OPENAI_DEPLOYMENT")
+    async with RTClient(url=endpoint, key_credential=AzureKeyCredential(key), azure_deployment=deployment) as client:
+        await run(client, instructions_file_path, user_message_file_path, out_dir)
+
+
+async def with_openai(instructions_file_path: str, user_message_file_path: str, out_dir: str):
+    """Run the sample with a client built from OPENAI_API_KEY and OPENAI_MODEL."""
+    key = get_env_var("OPENAI_API_KEY")
+    model = get_env_var("OPENAI_MODEL")
+    async with RTClient(key_credential=AzureKeyCredential(key), model=model) as client:
+        await run(client, instructions_file_path, user_message_file_path, out_dir)
+
+
+if __name__ == "__main__":
+    load_dotenv()
+    if len(sys.argv) < 4:
+        log(f"Usage: python {sys.argv[0]} <instructions_file> <user_message_file> <out_dir> [azure|openai]")
+        log("If the last argument is not provided, it will default to azure")
+        sys.exit(1)
+
+    instructions_file_path = sys.argv[1]
+    user_message_file_path = sys.argv[2]
+    out_dir = sys.argv[3]
+    provider = sys.argv[4] if len(sys.argv) > 4 else "azure"
+
+    if not os.path.isfile(instructions_file_path):
+        log(f"File {instructions_file_path} does not exist")
+        sys.exit(1)
+
+    if not os.path.isfile(user_message_file_path):
+        log(f"File {user_message_file_path} does not exist")
+        sys.exit(1)
+
+    if not os.path.isdir(out_dir):
+        log(f"Directory {out_dir} does not exist")
+        sys.exit(1)
+
+    if provider not in ["azure", "openai"]:
+        log(f"Provider {provider} needs to be one of 'azure' or 'openai'")
+        sys.exit(1)
+
+    if provider == "azure":
+        asyncio.run(with_azure_openai(instructions_file_path, user_message_file_path, out_dir))
+    else:
+        asyncio.run(with_openai(instructions_file_path, user_message_file_path, out_dir))