Skip to content

Commit

Permalink
Merge branch 'main' into dotnet_unit
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang authored Jan 30, 2025
2 parents 28f953f + fca1de9 commit cf8c7d0
Show file tree
Hide file tree
Showing 15 changed files with 679 additions and 9 deletions.
14 changes: 14 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Contracts/AgentId.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text.RegularExpressions;

namespace Microsoft.AutoGen.Contracts;

Expand All @@ -16,6 +17,9 @@ namespace Microsoft.AutoGen.Contracts;
[DebuggerDisplay($"AgentId(type=\"{nameof(Type)}\", key=\"{nameof(Key)}\")")]
public struct AgentId
{
private static readonly Regex TypeRegex = new(@"^[a-zA-Z_][a-zA-Z0-9_]*$", RegexOptions.Compiled);
private static readonly Regex KeyRegex = new(@"^[\x20-\x7E]+$", RegexOptions.Compiled); // ASCII 32-126

/// <summary>
/// An identifier that associates an agent with a specific factory function.
/// Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_).
Expand All @@ -35,6 +39,16 @@ public struct AgentId
/// <param name="key">Agent instance identifier.</param>
public AgentId(string type, string key)
{
if (string.IsNullOrWhiteSpace(type) || !TypeRegex.IsMatch(type))
{
throw new ArgumentException($"Invalid AgentId type: '{type}'. Must be alphanumeric (a-z, 0-9, _) and cannot start with a number or contain spaces.");
}

if (string.IsNullOrWhiteSpace(key) || !KeyRegex.IsMatch(key))
{
throw new ArgumentException($"Invalid AgentId key: '{key}'. Must only contain ASCII characters 32-126.");
}

Type = type;
Key = key;
}
Expand Down
109 changes: 109 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentIdTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentIdTests.cs
using FluentAssertions;
using Microsoft.AutoGen.Contracts;
using Xunit;

namespace Microsoft.AutoGen.Core.Tests;

public class AgentIdTests()
{
[Fact]
public void AgentIdShouldInitializeCorrectlyTest()
{
var agentId = new AgentId("TestType", "TestKey");

agentId.Type.Should().Be("TestType");
agentId.Key.Should().Be("TestKey");
}

[Fact]
public void AgentIdShouldConvertFromTupleTest()
{
var agentTuple = ("TupleType", "TupleKey");
var agentId = new AgentId(agentTuple);

agentId.Type.Should().Be("TupleType");
agentId.Key.Should().Be("TupleKey");
}

[Fact]
public void AgentIdShouldParseFromStringTest()
{
var agentId = AgentId.FromStr("ParsedType/ParsedKey");

agentId.Type.Should().Be("ParsedType");
agentId.Key.Should().Be("ParsedKey");
}

[Fact]
public void AgentIdShouldCompareEqualityCorrectlyTest()
{
var agentId1 = new AgentId("SameType", "SameKey");
var agentId2 = new AgentId("SameType", "SameKey");
var agentId3 = new AgentId("DifferentType", "DifferentKey");

agentId1.Should().Be(agentId2);
agentId1.Should().NotBe(agentId3);
(agentId1 == agentId2).Should().BeTrue();
(agentId1 != agentId3).Should().BeTrue();
}

[Fact]
public void AgentIdShouldGenerateCorrectHashCodeTest()
{
var agentId1 = new AgentId("HashType", "HashKey");
var agentId2 = new AgentId("HashType", "HashKey");
var agentId3 = new AgentId("DifferentType", "DifferentKey");

agentId1.GetHashCode().Should().Be(agentId2.GetHashCode());
agentId1.GetHashCode().Should().NotBe(agentId3.GetHashCode());
}

[Fact]
public void AgentIdShouldConvertExplicitlyFromStringTest()
{
var agentId = (AgentId)"ConvertedType/ConvertedKey";

agentId.Type.Should().Be("ConvertedType");
agentId.Key.Should().Be("ConvertedKey");
}

[Fact]
public void AgentIdShouldReturnCorrectToStringTest()
{
var agentId = new AgentId("ToStringType", "ToStringKey");

agentId.ToString().Should().Be("ToStringType/ToStringKey");
}

[Fact]
public void AgentIdShouldCompareInequalityCorrectlyTest()
{
var agentId1 = new AgentId("Type1", "Key1");
var agentId2 = new AgentId("Type2", "Key2");

(agentId1 != agentId2).Should().BeTrue();
}

[Fact]
public void AgentIdShouldRejectInvalidNamesTest()
{
// Invalid: 'Type' cannot start with a number and must only contain a-z, 0-9, or underscores.
Action invalidType = () => new AgentId("123InvalidType", "ValidKey");
invalidType.Should().Throw<ArgumentException>("Agent type cannot start with a number and must only contain alphanumeric letters or underscores.");

Action invalidTypeWithSpaces = () => new AgentId("Invalid Type", "ValidKey");
invalidTypeWithSpaces.Should().Throw<ArgumentException>("Agent type cannot contain spaces.");

Action invalidTypeWithSpecialChars = () => new AgentId("Invalid@Type", "ValidKey");
invalidTypeWithSpecialChars.Should().Throw<ArgumentException>("Agent type cannot contain special characters.");

// Invalid: 'Key' must contain only ASCII characters 32 (space) to 126 (~).
Action invalidKey = () => new AgentId("ValidType", "InvalidKey💀");
invalidKey.Should().Throw<ArgumentException>("Agent key must only contain ASCII characters between 32 (space) and 126 (~).");

Action validCase = () => new AgentId("Valid_Type", "Valid_Key_123");
validCase.Should().NotThrow("This is a correctly formatted AgentId.");
}
}
20 changes: 20 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentMetaDataTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentMetaDataTests.cs
using FluentAssertions;
using Microsoft.AutoGen.Contracts;
using Xunit;

namespace Microsoft.AutoGen.Core.Tests;

public class AgentMetadataTests()
{
[Fact]
public void AgentMetadataShouldInitializeCorrectlyTest()
{
var metadata = new AgentMetadata("TestType", "TestKey", "TestDescription");

metadata.Type.Should().Be("TestType");
metadata.Key.Should().Be("TestKey");
metadata.Description.Should().Be("TestDescription");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ class ModelFamily:
O1 = "o1"
GPT_4 = "gpt-4"
GPT_35 = "gpt-35"
R1 = "r1"
UNKNOWN = "unknown"

ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]

def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
Expand Down
35 changes: 35 additions & 0 deletions python/packages/autogen-core/src/autogen_core/models/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,25 @@


class SystemMessage(BaseModel):
"""System message contains instructions for the model coming from the developer.
.. note::
Open AI is moving away from using 'system' role in favor of 'developer' role.
See `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html#definitions>`_ for more details.
However, the 'system' role is still allowed in their API and will be automatically converted to 'developer' role
on the server side.
So, you can use `SystemMessage` for developer messages.
"""

content: str
type: Literal["SystemMessage"] = "SystemMessage"


class UserMessage(BaseModel):
"""User message contains input from end users, or a catch-all for data provided to the model."""

content: Union[str, List[Union[str, Image]]]

# Name of the agent that sent this message
Expand All @@ -22,6 +36,8 @@ class UserMessage(BaseModel):


class AssistantMessage(BaseModel):
"""Assistant message are sampled from the language model."""

content: Union[str, List[FunctionCall]]

# Name of the agent that sent this message
Expand All @@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):


class FunctionExecutionResult(BaseModel):
"""Function execution result contains the output of a function call."""

content: str
call_id: str


class FunctionExecutionResultMessage(BaseModel):
"""Function execution result message contains the output of multiple function calls."""

content: List[FunctionExecutionResult]

type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
Expand Down Expand Up @@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):


class CreateResult(BaseModel):
"""Create result contains the output of a model completion."""

finish_reason: FinishReasons
"""The reason the model finished generating the completion."""

content: Union[str, List[FunctionCall]]
"""The output of the model completion."""

usage: RequestUsage
"""The usage of tokens in the prompt and completion."""

cached: bool
"""Whether the completion was generated from a cached response."""

logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
"""The logprobs of the tokens in the completion."""

thought: Optional[str] = None
"""The reasoning text for the completion if available. Used for reasoning models
and additional text content besides function calls."""
1 change: 1 addition & 0 deletions python/packages/autogen-ext/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ dev = [
"autogen_test_utils",
"langchain-experimental",
"pandas-stubs>=2.2.3.241126",
"httpx>=0.28.1",
]

[tool.ruff]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import warnings
from typing import Tuple


def parse_r1_content(content: str) -> Tuple[str | None, str]:
"""Parse the content of an R1-style message that contains a `<think>...</think>` field."""
# Find the start and end of the think field
think_start = content.find("<think>")
think_end = content.find("</think>")

if think_start == -1 or think_end == -1:
warnings.warn(
"Could not find <think>..</think> field in model response content. " "No thought was extracted.",
UserWarning,
stacklevel=2,
)
return None, content

if think_end < think_start:
warnings.warn(
"Found </think> before <think> in model response content. " "No thought was extracted.",
UserWarning,
stacklevel=2,
)
return None, content

# Extract the think field
thought = content[think_start + len("<think>") : think_end].strip()

# Extract the rest of the content, skipping the think field.
content = content[think_end + len("</think>") :].strip()

return thought, content
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FinishReasons,
FunctionExecutionResultMessage,
LLMMessage,
ModelFamily,
ModelInfo,
RequestUsage,
SystemMessage,
Expand Down Expand Up @@ -55,6 +56,8 @@
AzureAIChatCompletionClientConfig,
)

from .._utils.parse_r1_content import parse_r1_content

create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]

Expand Down Expand Up @@ -354,11 +357,17 @@ async def create(
finish_reason = choice.finish_reason # type: ignore
content = choice.message.content or ""

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

response = CreateResult(
finish_reason=finish_reason, # type: ignore
content=content,
usage=usage,
cached=False,
thought=thought,
)

self.add_usage(usage)
Expand Down Expand Up @@ -464,11 +473,17 @@ async def create_stream(
prompt_tokens=prompt_tokens,
)

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

result = CreateResult(
finish_reason=finish_reason,
content=content,
usage=usage,
cached=False,
thought=thought,
)

self.add_usage(usage)
Expand Down
Loading

0 comments on commit cf8c7d0

Please sign in to comment.