Skip to content

Commit

Permalink
Handle streamed function calls (#1118)
Browse files Browse the repository at this point in the history
* Handle streamed function calls

* apply black formatting

* rm unnecessary stdout print

* bug fix

---------

Co-authored-by: Davor Runje <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
3 people authored Jan 8, 2024
1 parent 1c4ae3d commit 78a2d84
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,22 +287,42 @@ def yes_or_no_filter(context, response):

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions or tools, then
# iterate over the chunks of the response
if params.get("stream", False) and "messages" in params and "functions" not in params and "tools" not in params:
# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0

# Set the terminal text color to green
print("\033[32m", end="")

# Prepare for potential function call
full_function_call = None
# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
function_call_chunk = choice.delta.function_call
finish_reasons[choice.index] = choice.finish_reason

# Handle function call
if function_call_chunk:
if hasattr(function_call_chunk, "name") and function_call_chunk.name:
if full_function_call is None:
full_function_call = {"name": "", "arguments": ""}
full_function_call["name"] += function_call_chunk.name
completion_tokens += 1
if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments:
full_function_call["arguments"] += function_call_chunk.arguments
completion_tokens += 1
if choice.finish_reason == "function_call":
# Need something here? I don't think so.
pass
if not content:
continue
# End handle function call

# If content is present, print it to the terminal and update response variables
if content is not None:
print(content, end="", flush=True)
Expand Down Expand Up @@ -336,7 +356,7 @@ def _completions_create(self, client, params):
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=None
role="assistant", content=response_contents[i], function_call=full_function_call
),
logprobs=None,
)
Expand All @@ -346,17 +366,17 @@ def _completions_create(self, client, params):
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=None
role="assistant", content=response_contents[i], function_call=full_function_call
),
)

response.choices.append(choice)
else:
# If streaming is not enabled, using functions, or tools, send a regular chat completion request
# Functions and Tools are not supported, so ensure streaming is disabled
# If streaming is not enabled, send a regular chat completion request
params = params.copy()
params["stream"] = False
response = completions.create(**params)

return response

def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
Expand Down

0 comments on commit 78a2d84

Please sign in to comment.