Skip to content

Commit

Permalink
Fixing serialization of all models that have defaults. (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
glecaros authored Oct 4, 2024
1 parent 2ee1005 commit 21764a7
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 30 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "rtclient"
version = "0.4.3"
version = "0.4.4"
description = "A client for the RT API"
authors = ["Microsoft Corporation"]

Expand Down
26 changes: 13 additions & 13 deletions python/rtclient/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
model_serializer,
)

from rtclient.util.model_helpers import ModelWithType
from rtclient.util.model_helpers import ModelWithDefaults

Voice = Literal["alloy", "shimmer", "echo"]
AudioFormat = Literal["pcm16", "g711-ulaw", "g711-alaw"]
Modality = Literal["text", "audio"]


class NoTurnDetection(ModelWithType):
class NoTurnDetection(ModelWithDefaults):
type: Literal["none"] = "none"


class ServerVAD(ModelWithType):
class ServerVAD(ModelWithDefaults):
type: Literal["server_vad"] = "server_vad"
threshold: Optional[Annotated[float, Field(strict=True, ge=0.0, le=1.0)]] = None
prefix_padding_ms: Optional[int] = None
Expand All @@ -33,7 +33,7 @@ class ServerVAD(ModelWithType):
TurnDetection = Annotated[Union[NoTurnDetection, ServerVAD], Field(discriminator="type")]


class FunctionToolChoice(ModelWithType):
class FunctionToolChoice(ModelWithDefaults):
type: Literal["function"] = "function"
function: str

Expand All @@ -47,7 +47,7 @@ class InputAudioTranscription(BaseModel):
model: Literal["whisper-1"]


class ClientMessageBase(ModelWithType):
class ClientMessageBase(ModelWithDefaults):
_is_azure: bool = False
event_id: Optional[str] = None

Expand Down Expand Up @@ -125,18 +125,18 @@ class InputAudioBufferClearMessage(ClientMessageBase):
MessageItemType = Literal["message"]


class InputTextContentPart(ModelWithType):
class InputTextContentPart(ModelWithDefaults):
type: Literal["input_text"] = "input_text"
text: str


class InputAudioContentPart(ModelWithType):
class InputAudioContentPart(ModelWithDefaults):
type: Literal["input_audio"] = "input_audio"
audio: str
transcript: Optional[str] = None


class OutputTextContentPart(ModelWithType):
class OutputTextContentPart(ModelWithDefaults):
type: Literal["text"] = "text"
text: str

Expand All @@ -148,23 +148,23 @@ class OutputTextContentPart(ModelWithType):
ItemParamStatus = Literal["completed", "incomplete"]


class SystemMessageItem(BaseModel):
class SystemMessageItem(ModelWithDefaults):
type: MessageItemType = "message"
role: Literal["system"] = "system"
id: Optional[str] = None
content: list[SystemContentPart]
status: Optional[ItemParamStatus] = None


class UserMessageItem(BaseModel):
class UserMessageItem(ModelWithDefaults):
type: MessageItemType = "message"
role: Literal["user"] = "user"
id: Optional[str] = None
content: list[UserContentPart]
status: Optional[ItemParamStatus] = None


class AssistantMessageItem(BaseModel):
class AssistantMessageItem(ModelWithDefaults):
type: MessageItemType = "message"
role: Literal["assistant"] = "assistant"
id: Optional[str] = None
Expand All @@ -175,7 +175,7 @@ class AssistantMessageItem(BaseModel):
MessageItem = Annotated[Union[SystemMessageItem, UserMessageItem, AssistantMessageItem], Field(discriminator="role")]


class FunctionCallItem(ModelWithType):
class FunctionCallItem(ModelWithDefaults):
type: Literal["function_call"] = "function_call"
id: Optional[str] = None
name: str
Expand All @@ -184,7 +184,7 @@ class FunctionCallItem(ModelWithType):
status: Optional[ItemParamStatus] = None


class FunctionCallOutputItem(ModelWithType):
class FunctionCallOutputItem(ModelWithDefaults):
type: Literal["function_call_output"] = "function_call_output"
id: Optional[str] = None
call_id: str
Expand Down
10 changes: 6 additions & 4 deletions python/rtclient/util/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from pydantic import BaseModel, model_validator


class ModelWithType(BaseModel):
class ModelWithDefaults(BaseModel):
@model_validator(mode="after")
def _add_type(self):
if "type" in self.model_fields:
self.type = self.model_fields["type"].default
def _add_defaults(self):
for field in self.model_fields:
if self.model_fields[field].default is not None:
if not hasattr(self, field) or getattr(self, field) == self.model_fields[field].default:
setattr(self, field, self.model_fields[field].default)
return self
26 changes: 16 additions & 10 deletions python/rtclient/util/model_helpers_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from typing import Literal
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from model_helpers import ModelWithType
from typing import Optional

from model_helpers import ModelWithDefaults

class ModelWithType(ModelWithType):
type: Literal["object_type"] = "object_type"

class Bar(ModelWithDefaults):
foo: Optional[int] = None
bar: Optional[float] = 3.14
baz: int = 42

def test_with_type_field():
instance = ModelWithType()
assert instance.type == "object_type"

def test_with_defaults():
instance = Bar()
assert instance.foo is None
assert instance.baz == 42

def test_serialize_with_type_field():
instance = ModelWithType()
assert instance.model_dump() == {"type": "object_type"}

def test_serialize_with_defaults():
instance = Bar()
assert instance.model_dump(exclude_unset=True) == {"bar": 3.14, "baz": 42}
2 changes: 1 addition & 1 deletion python/samples/download-wheel.ps1
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
$Owner = "Azure-Samples"
$Repo = "aoai-realtime-audio-sdk"
$Filter = "-py3-none-any.whl"
$Release = "py/v0.4.3"
$Release = "py/v0.4.4"
$OutputDir = "."

$apiUrl = "https://api.github.com/repos/$Owner/$Repo/releases/tags/$Release"
Expand Down
2 changes: 1 addition & 1 deletion python/samples/download-wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
OWNER="Azure-Samples"
REPO="aoai-realtime-audio-sdk"
FILTER="-py3-none-any.whl"
RELEASE="py/v0.4.3"
RELEASE="py/v0.4.4"
OUTPUT_DIR="."

API_URL="https://api.github.com/repos/$OWNER/$REPO/releases/tags/$RELEASE"
Expand Down
172 changes: 172 additions & 0 deletions python/samples/sample_text_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import base64
import os
import sys
import time

import numpy as np
import soundfile as sf
from azure.core.credentials import AzureKeyCredential
from dotenv import load_dotenv

from rtclient import InputTextContentPart, RTClient, RTInputItem, RTOutputItem, RTResponse, UserMessageItem

# Reference point for the elapsed-time prefix printed by log().
start_time = time.time()


def log(message, *details):
    """Print a message prefixed with the milliseconds elapsed since start-up.

    Args:
        message: First value printed after the elapsed-time prefix.
        *details: Optional extra values printed after ``message``, separated
            by single spaces, so calls such as ``log(prefix, text)`` work.
    """
    elapsed_time_ms = int((time.time() - start_time) * 1000)
    # Emit through print(); calling log() here would recurse without end.
    print(f"{elapsed_time_ms} [ms]:", message, *details, flush=True)


async def receive_control(client: RTClient):
    """Log the type of every control message until a ``None`` message arrives.

    Args:
        client: Client whose ``control_messages()`` stream is consumed.
    """
    async for control in client.control_messages():
        # A None entry ends the loop; nothing after it is read.
        if control is None:
            break
        log(f"Received a control message: {control.type}")


async def receive_item(item: RTOutputItem, out_dir: str):
    """Drain one output item and save everything it produced under ``out_dir``.

    Chunks are accumulated by ``chunk.type`` ("audio_transcript", "audio",
    "tool_call_arguments", "text"); chunks of any other type are ignored.
    Once the item is exhausted, each kind that received at least one chunk is
    logged and written to a file named after ``item.id``.

    Args:
        item: Output item to iterate; its chunks expose ``type`` and ``data``.
        out_dir: Directory that receives the output files.
    """
    prefix = f"[response={item.response_id}][item={item.id}]"
    audio_data = None
    audio_transcript = None
    text_data = None
    arguments = None
    async for chunk in item:
        if chunk.type == "audio_transcript":
            audio_transcript = (audio_transcript or "") + chunk.data
        elif chunk.type == "audio":
            if audio_data is None:
                audio_data = bytearray()
            # Audio chunks arrive base64-encoded; store the decoded bytes.
            audio_data.extend(base64.b64decode(chunk.data))
        elif chunk.type == "tool_call_arguments":
            arguments = (arguments or "") + chunk.data
        elif chunk.type == "text":
            text_data = (text_data or "") + chunk.data

    # log() is declared with a single message parameter, so the prefix is
    # folded into the message instead of being passed as a second argument.
    if text_data is not None:
        log(f"{prefix} Text: {text_data}")
        with open(os.path.join(out_dir, f"{item.id}.text.txt"), "w", encoding="utf-8") as out:
            out.write(text_data)
    if audio_data is not None:
        log(f"{prefix} Audio received with length: {len(audio_data)}")
        with open(os.path.join(out_dir, f"{item.id}.wav"), "wb") as out:
            # The bytes are reinterpreted as 16-bit samples and written at 24000 Hz.
            audio_array = np.frombuffer(audio_data, dtype=np.int16)
            sf.write(out, audio_array, samplerate=24000)
    if audio_transcript is not None:
        log(f"{prefix} Audio Transcript: {audio_transcript}")
        with open(os.path.join(out_dir, f"{item.id}.audio_transcript.txt"), "w", encoding="utf-8") as out:
            out.write(audio_transcript)
    if arguments is not None:
        log(f"{prefix} Tool Call Arguments: {arguments}")
        with open(os.path.join(out_dir, f"{item.id}.tool.streamed.json"), "w", encoding="utf-8") as out:
            out.write(arguments)


async def receive_response(client: RTClient, response: RTResponse, out_dir: str):
    """Process every item of a response, then close the client.

    Each item is handed to ``receive_item`` in its own task so items are
    drained concurrently. The tasks are kept and awaited before the client is
    closed, so no item is still writing its files when shutdown begins.

    Args:
        client: Client to close once the response has been fully handled.
        response: Response whose items are iterated.
        out_dir: Directory passed through to ``receive_item``.
    """
    prefix = f"[response={response.id}]"
    item_tasks = []
    async for item in response:
        # log() is declared with a single message parameter; fold the prefix in.
        log(f"{prefix} Received item {item.id}")
        # Hold a reference: the event loop itself only keeps weak references to tasks.
        item_tasks.append(asyncio.create_task(receive_item(item, out_dir)))
    try:
        await asyncio.gather(*item_tasks)
        log(f"{prefix} Response completed")
    finally:
        # Close even if an item task failed, so the receive loops can end.
        await client.close()


async def receive_input_item(item: RTInputItem):
    """Await an input item and log its details.

    Args:
        item: Awaitable input item; after it is awaited, its ``previous_id``,
            ``transcript``, ``audio_start_ms`` and ``audio_end_ms`` are logged.
    """
    prefix = f"[input_item={item.id}]"
    await item
    # log() is declared with a single message parameter, so the prefix is
    # folded into each message instead of being passed as a second argument.
    log(f"{prefix} Previous Id: {item.previous_id}")
    log(f"{prefix} Transcript: {item.transcript}")
    log(f"{prefix} Audio Start [ms]: {item.audio_start_ms}")
    log(f"{prefix} Audio End [ms]: {item.audio_end_ms}")


async def receive_items(client: RTClient, out_dir: str):
    """Start a background task for every item the client yields.

    Args:
        client: Client whose ``items()`` stream is consumed.
        out_dir: Directory passed through to ``receive_response``.
    """
    async for item in client.items():
        # Responses get the response handler; anything else is treated as an input item.
        if isinstance(item, RTResponse):
            handler = receive_response(client, item, out_dir)
        else:
            handler = receive_input_item(item)
        asyncio.create_task(handler)


async def receive_messages(client: RTClient, out_dir: str):
    """Run the item loop and the control-message loop concurrently until both end.

    Args:
        client: Client passed to both receive loops.
        out_dir: Directory passed through to ``receive_items``.
    """
    items_loop = receive_items(client, out_dir)
    control_loop = receive_control(client)
    await asyncio.gather(items_loop, control_loop)


async def run(client: RTClient, instructions_file_path: str, user_message_file_path: str, out_dir: str):
    """Configure the session, send one user text message and process the reply.

    Args:
        client: Connected client used for configuration, sending and receiving.
        instructions_file_path: Text file whose content becomes the session instructions.
        user_message_file_path: Text file whose content is sent as the user message.
        out_dir: Directory passed through to ``receive_messages``.
    """
    # Read both inputs as UTF-8, matching the encoding this script uses for the
    # files it writes, instead of relying on the platform default encoding.
    with open(instructions_file_path, encoding="utf-8") as instructions_file:
        instructions = instructions_file.read()
    with open(user_message_file_path, encoding="utf-8") as user_message_file:
        user_message = user_message_file.read()

    log("Configuring Session...")
    await client.configure(
        instructions=instructions,
    )
    log("Done")
    log("Sending User Message...")
    await client.send_item(UserMessageItem(content=[InputTextContentPart(text=user_message)]))
    log("Done")
    await client.generate_response()
    await receive_messages(client, out_dir)


def get_env_var(var_name: str) -> str:
    """Return the value of a required environment variable.

    Args:
        var_name: Name of the environment variable to read.

    Returns:
        The variable's non-empty value.

    Raises:
        OSError: If the variable is unset or set to an empty string.
    """
    found = os.getenv(var_name)
    if found:
        return found
    raise OSError(f"Environment variable '{var_name}' is not set or is empty.")


async def with_azure_openai(instructions_file_path: str, user_message_file_path: str, out_dir: str):
    """Run the sample using settings read from the AZURE_OPENAI_* environment variables.

    Args:
        instructions_file_path: Passed through to ``run``.
        user_message_file_path: Passed through to ``run``.
        out_dir: Passed through to ``run``.
    """
    # Read in a fixed order so the first missing variable is the one reported.
    endpoint, key, deployment = (
        get_env_var(name)
        for name in ("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_DEPLOYMENT")
    )
    credential = AzureKeyCredential(key)
    async with RTClient(url=endpoint, key_credential=credential, azure_deployment=deployment) as client:
        await run(client, instructions_file_path, user_message_file_path, out_dir)


async def with_openai(instructions_file_path: str, user_message_file_path: str, out_dir: str):
    """Run the sample using OPENAI_API_KEY and OPENAI_MODEL from the environment.

    Args:
        instructions_file_path: Passed through to ``run``.
        user_message_file_path: Passed through to ``run``.
        out_dir: Passed through to ``run``.
    """
    key = get_env_var("OPENAI_API_KEY")
    model = get_env_var("OPENAI_MODEL")
    credential = AzureKeyCredential(key)
    async with RTClient(key_credential=credential, model=model) as client:
        await run(client, instructions_file_path, user_message_file_path, out_dir)


def _parse_cli_args(argv):
    """Split ``argv`` into (instructions_file, message_file, out_dir, provider).

    Returns None when any of the three required positional arguments is
    missing. The provider is optional and defaults to "azure".
    """
    # argv[0] is the script name, so three required arguments need len(argv) >= 4.
    if len(argv) < 4:
        return None
    provider = argv[4] if len(argv) > 4 else "azure"
    return argv[1], argv[2], argv[3], provider


if __name__ == "__main__":
    load_dotenv()
    cli_args = _parse_cli_args(sys.argv)
    if cli_args is None:
        log(f"Usage: python {sys.argv[0]} <instructions_file> <message_file> <out_dir> [azure|openai]")
        log("If the last argument is not provided, it will default to azure")
        sys.exit(1)

    instructions_file_path, user_message_file_path, out_dir, provider = cli_args

    # Validate every input up front so failures are reported before connecting.
    if not os.path.isfile(instructions_file_path):
        log(f"File {instructions_file_path} does not exist")
        sys.exit(1)

    if not os.path.isfile(user_message_file_path):
        log(f"File {user_message_file_path} does not exist")
        sys.exit(1)

    if not os.path.isdir(out_dir):
        log(f"Directory {out_dir} does not exist")
        sys.exit(1)

    if provider not in ["azure", "openai"]:
        log(f"Provider {provider} needs to be one of 'azure' or 'openai'")
        sys.exit(1)

    if provider == "azure":
        asyncio.run(with_azure_openai(instructions_file_path, user_message_file_path, out_dir))
    else:
        asyncio.run(with_openai(instructions_file_path, user_message_file_path, out_dir))

0 comments on commit 21764a7

Please sign in to comment.