Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: save/load test for dotnet agents #5284

Merged
merged 14 commits into from
Feb 6, 2025
6 changes: 4 additions & 2 deletions dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentProxy.cs

using System.Text.Json;

namespace Microsoft.AutoGen.Contracts;

/// <summary>
Expand Down Expand Up @@ -55,7 +57,7 @@ private T ExecuteAndUnwrap<T>(Func<IAgentRuntime, ValueTask<T>> delegate_)
/// </summary>
/// <param name="state">A dictionary representing the state of the agent. Must be JSON serializable.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public ValueTask LoadStateAsync(IDictionary<string, object> state)
public ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
return this.runtime.LoadAgentStateAsync(this.Id, state);
}
Expand All @@ -64,7 +66,7 @@ public ValueTask LoadStateAsync(IDictionary<string, object> state)
/// Saves the state of the agent. The result must be JSON serializable.
/// </summary>
/// <returns>A task representing the asynchronous operation, returning a dictionary containing the saved state.</returns>
public ValueTask<IDictionary<string, object>> SaveStateAsync()
public ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
return this.runtime.SaveAgentStateAsync(this.Id);
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentRuntime.cs

using StateDict = System.Collections.Generic.IDictionary<string, object>;
using StateDict = System.Collections.Generic.IDictionary<string, System.Text.Json.JsonElement>;

namespace Microsoft.AutoGen.Contracts;

Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ISaveState.cs

using StateDict = System.Collections.Generic.IDictionary<string, object>;
using StateDict = System.Collections.Generic.IDictionary<string, System.Text.Json.JsonElement>;

namespace Microsoft.AutoGen.Contracts;

Expand Down
39 changes: 22 additions & 17 deletions dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// GrpcAgentRuntime.cs

using System.Collections.Concurrent;
using System.Text.Json;
using Grpc.Core;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Protobuf;
Expand Down Expand Up @@ -319,13 +320,13 @@ public async ValueTask PublishMessageAsync(object message, TopicId topic, Contra
public ValueTask<Contracts.AgentId> GetAgentAsync(string agent, string key = "default", bool lazy = true)
=> this.GetAgentAsync(new Contracts.AgentId(agent, key), lazy);

public async ValueTask<IDictionary<string, object>> SaveAgentStateAsync(Contracts.AgentId agentId)
public async ValueTask<IDictionary<string, JsonElement>> SaveAgentStateAsync(Contracts.AgentId agentId)
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
return await agent.SaveStateAsync();
}

public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary<string, object> state)
public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary<string, JsonElement> state)
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(state);
Expand Down Expand Up @@ -375,37 +376,41 @@ public ValueTask<AgentProxy> TryGetAgentProxyAsync(Contracts.AgentId agentId)
return ValueTask.FromResult(new AgentProxy(agentId, this));
}

public async ValueTask<IDictionary<string, object>> SaveStateAsync()
{
Dictionary<string, object> state = new();
foreach (var agent in this._agentsContainer.LiveAgents)
{
state[agent.Id.ToString()] = await agent.SaveStateAsync();
}

return state;
}

public async ValueTask LoadStateAsync(IDictionary<string, object> state)
public async ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
HashSet<AgentType> registeredTypes = this._agentsContainer.RegisteredAgentTypes;

foreach (var agentIdStr in state.Keys)
{
Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr);
if (state[agentIdStr] is not IDictionary<string, object> agentStateDict)

if (state[agentIdStr].ValueKind != JsonValueKind.Object)
{
throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary<string, object>)}: {state[agentIdStr].GetType()}");
throw new Exception($"Agent state for {agentId} is not a valid JSON object.");
}

var agentState = JsonSerializer.Deserialize<IDictionary<string, JsonElement>>(state[agentIdStr].GetRawText())
?? throw new Exception($"Failed to deserialize state for {agentId}.");

if (registeredTypes.Contains(agentId.Type))
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(agentStateDict);
await agent.LoadStateAsync(agentState);
}
}
}

public async ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
Dictionary<string, JsonElement> state = new();
foreach (var agent in this._agentsContainer.LiveAgents)
{
var agentState = await agent.SaveStateAsync();
state[agent.Id.ToString()] = JsonSerializer.SerializeToElement(agentState); // Ensure JSON serialization
}
return state;
}

public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default)
{
switch (message.MessageCase)
Expand Down
7 changes: 4 additions & 3 deletions dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Diagnostics;
using System.Reflection;
using System.Text.Json;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -92,11 +93,11 @@ private Dictionary<Type, HandlerInvoker> ReflectInvokers()
return null;
}

public virtual ValueTask<IDictionary<string, object>> SaveStateAsync()
public virtual ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
return ValueTask.FromResult<IDictionary<string, object>>(new Dictionary<string, object>());
return ValueTask.FromResult<IDictionary<string, JsonElement>>(new Dictionary<string, JsonElement>());
}
public virtual ValueTask LoadStateAsync(IDictionary<string, object> state)
public virtual ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
return ValueTask.CompletedTask;
}
Expand Down
24 changes: 15 additions & 9 deletions dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Concurrent;
using System.Diagnostics;
using System.Text.Json;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Hosting;

Expand Down Expand Up @@ -152,13 +153,13 @@ public async ValueTask<AgentMetadata> GetAgentMetadataAsync(AgentId agentId)
return agent.Metadata;
}

public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary<string, object> state)
public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary<string, JsonElement> state)
{
IHostableAgent agent = await this.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(state);
}

public async ValueTask<IDictionary<string, object>> SaveAgentStateAsync(AgentId agentId)
public async ValueTask<IDictionary<string, JsonElement>> SaveAgentStateAsync(AgentId agentId)
{
IHostableAgent agent = await this.EnsureAgentAsync(agentId);
return await agent.SaveStateAsync();
Expand Down Expand Up @@ -187,16 +188,21 @@ public ValueTask RemoveSubscriptionAsync(string subscriptionId)
return ValueTask.CompletedTask;
}

public async ValueTask LoadStateAsync(IDictionary<string, object> state)
public async ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
foreach (var agentIdStr in state.Keys)
{
AgentId agentId = AgentId.FromStr(agentIdStr);
if (state[agentIdStr] is not IDictionary<string, object> agentState)

if (state[agentIdStr].ValueKind != JsonValueKind.Object)
{
throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary<string, object>)}: {state[agentIdStr].GetType()}");
throw new Exception($"Agent state for {agentId} is not a valid JSON object.");
}

// Deserialize before using
var agentState = JsonSerializer.Deserialize<IDictionary<string, JsonElement>>(state[agentIdStr].GetRawText())
?? throw new Exception($"Failed to deserialize state for {agentId}.");

if (this.agentFactories.ContainsKey(agentId.Type))
{
IHostableAgent agent = await this.EnsureAgentAsync(agentId);
Expand All @@ -205,14 +211,14 @@ public async ValueTask LoadStateAsync(IDictionary<string, object> state)
}
}

public async ValueTask<IDictionary<string, object>> SaveStateAsync()
public async ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
Dictionary<string, object> state = new();
Dictionary<string, JsonElement> state = new();
foreach (var agentId in this.agentInstances.Keys)
{
state[agentId.ToString()] = await this.agentInstances[agentId].SaveStateAsync();
var agentState = await this.agentInstances[agentId].SaveStateAsync();
state[agentId.ToString()] = JsonSerializer.SerializeToElement(agentState);
}

return state;
}

Expand Down
68 changes: 68 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentRuntimeTests.cs
using FluentAssertions;
using System.Text.Json;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Logging;
using Xunit;
Expand Down Expand Up @@ -80,4 +81,71 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
agent.Text.Source.Should().Be("TestTopic");
agent.Text.Content.Should().Be("SelfMessage");
}

[Fact]
public async Task RuntimeShouldSaveLoadStateCorrectlyTest()
{
var runtime = new InProcessRuntime();
await runtime.StartAsync();

Logger<BaseAgent> logger = new(new LoggerFactory());
SubscribedSaveLoadAgent agent = null!;

await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
{
agent = new SubscribedSaveLoadAgent(id, runtime, logger);
return ValueTask.FromResult(agent);
});

// Ensure the agent id is registered
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);

// Validate agent ID
agentId.Should().Be(agent.Id, "Agent ID should match the registered agent");

await runtime.RegisterImplicitAgentSubscriptionsAsync<SubscribedSaveLoadAgent>("MyAgent");

var topicType = "TestTopic";

await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true);

await runtime.RunUntilIdleAsync();

agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed.");

// Save the state
var savedState = await runtime.SaveStateAsync();

// Ensure saved state contains the agent's state
savedState.Should().ContainKey(agentId.ToString());

// Ensure the agent's state is stored as a valid JSON object
savedState[agentId.ToString()].ValueKind.Should().Be(JsonValueKind.Object, "Agent state should be stored as a JSON object");

// Serialize and Deserialize the state to simulate persistence
string json = JsonSerializer.Serialize(savedState);
json.Should().NotBeNullOrEmpty("Serialized state should not be empty");

var deserializedState = JsonSerializer.Deserialize<IDictionary<string, JsonElement>>(json)
?? throw new Exception("Deserialized state is unexpectedly null");

deserializedState.Should().ContainKey(agentId.ToString());

// Load the saved state back into a new runtime instance
var newRuntime = new InProcessRuntime();
await newRuntime.StartAsync();
await newRuntime.LoadStateAsync(deserializedState);

// Ensure the agent exists in the new runtime
AgentId newAgentId = await newRuntime.GetAgentAsync("MyAgent", lazy: false);
newAgentId.Should().Be(agentId, "Loaded agent ID should match original agent ID");

// Retrieve the agent's saved state
var restoredState = await newRuntime.SaveAgentStateAsync(newAgentId);
restoredState.Should().ContainKey("TestTopic");

// Ensure "TestTopic" contains the correct message
restoredState["TestTopic"].ValueKind.Should().Be(JsonValueKind.String, "Expected 'TestTopic' to be a string");
restoredState["TestTopic"].GetString().Should().Be("test", "Agent state should contain the original message");
}
}
23 changes: 1 addition & 22 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
return ValueTask.FromResult(agent);
});

// Ensure the agent is actually created
// Ensure the agent id is registered
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);

// Validate agent ID
Expand Down Expand Up @@ -146,25 +146,4 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>

Assert.True(agent.ReceivedItems.Count == 1);
}

[Fact]
public async Task AgentShouldSaveStateCorrectlyTest()
{
var runtime = new InProcessRuntime();
await runtime.StartAsync();

Logger<BaseAgent> logger = new(new LoggerFactory());
TestAgent agent = new TestAgent(new AgentId("TestType", "TestKey"), runtime, logger);

var state = await agent.SaveStateAsync();

// Ensure state is a dictionary
state.Should().NotBeNull();
state.Should().BeOfType<Dictionary<string, object>>();
state.Should().BeEmpty("Default SaveStateAsync should return an empty dictionary.");

// Add a sample value and verify it updates correctly
state["testKey"] = "testValue";
state.Should().ContainKey("testKey").WhoseValue.Should().Be("testValue");
}
}
35 changes: 34 additions & 1 deletion dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Logging;
using System.Text.Json;

namespace Microsoft.AutoGen.Core.Tests;

Expand Down Expand Up @@ -59,7 +60,7 @@ public ValueTask<string> HandleAsync(RpcTextMessage item, MessageContext message
/// Key: source
/// Value: message
/// </summary>
private readonly Dictionary<string, object> _receivedMessages = new();
protected Dictionary<string, object> _receivedMessages = new();
public Dictionary<string, object> ReceivedMessages => _receivedMessages;
}

Expand All @@ -73,6 +74,38 @@ public SubscribedAgent(AgentId id,
}
}

[TypeSubscription("TestTopic")]
public class SubscribedSaveLoadAgent : TestAgent
{
public SubscribedSaveLoadAgent(AgentId id,
IAgentRuntime runtime,
Logger<BaseAgent>? logger = null) : base(id, runtime, logger)
{
}

public override ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
var jsonSafeDictionary = _receivedMessages.ToDictionary(
kvp => kvp.Key,
kvp => JsonSerializer.SerializeToElement(kvp.Value) // Convert each object to JsonElement
);

return ValueTask.FromResult<IDictionary<string, JsonElement>>(jsonSafeDictionary);
}

public override ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
_receivedMessages.Clear();

foreach (var kvp in state)
{
_receivedMessages[kvp.Key] = kvp.Value.Deserialize<object>() ?? throw new Exception($"Failed to deserialize key: {kvp.Key}");
}

return ValueTask.CompletedTask;
}
}

/// <summary>
/// The test agent showing an agent that subscribes to itself.
/// </summary>
Expand Down
Loading