Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: save/load test for dotnet agents #5284

Merged
merged 14 commits into from
Feb 6, 2025
49 changes: 38 additions & 11 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
return ValueTask.FromResult(agent);
});

// Ensure the agent is actually created
// Ensure the agent id is registered
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);

// Validate agent ID
Expand Down Expand Up @@ -148,23 +148,50 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
}

[Fact]
public async Task AgentShouldSaveStateCorrectlyTest()
public async Task AgentShouldSaveLoadStateCorrectlyTest()
{
var runtime = new InProcessRuntime();
await runtime.StartAsync();

Logger<BaseAgent> logger = new(new LoggerFactory());
TestAgent agent = new TestAgent(new AgentId("TestType", "TestKey"), runtime, logger);
SubscribedSaveLoadAgent agent = null!;

await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
{
agent = new SubscribedSaveLoadAgent(id, runtime, logger);
return ValueTask.FromResult(agent);
});

// Ensure the agent id is registered
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);

// Validate agent ID
agentId.Should().Be(agent.Id, "Agent ID should match the registered agent");

await runtime.RegisterImplicitAgentSubscriptionsAsync<SubscribedSaveLoadAgent>("MyAgent");

var topicType = "TestTopic";

await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true);

await runtime.RunUntilIdleAsync();

agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed.");

// Save the state
var savedState = await agent.SaveStateAsync();

// Ensure the state contains receivedMessages
savedState.Should().ContainKey("receivedMessages");
savedState["receivedMessages"].Should().BeOfType<Dictionary<string, object>>();

var state = await agent.SaveStateAsync();
// Create a new instance of the agent to simulate a restart
var newAgent = new SubscribedSaveLoadAgent(agent.Id, runtime, logger);

// Ensure state is a dictionary
state.Should().NotBeNull();
state.Should().BeOfType<Dictionary<string, object>>();
state.Should().BeEmpty("Default SaveStateAsync should return an empty dictionary.");
// Load the saved state into the new agent
await newAgent.LoadStateAsync(savedState);

// Add a sample value and verify it updates correctly
state["testKey"] = "testValue";
state.Should().ContainKey("testKey").WhoseValue.Should().Be("testValue");
// Verify that the loaded state contains the received message
newAgent.ReceivedMessages.Should().ContainKey(topicType).WhoseValue.Should().Be("test");
}
}
37 changes: 36 additions & 1 deletion dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public ValueTask<string> HandleAsync(RpcTextMessage item, MessageContext message
/// Key: source
/// Value: message
/// </summary>
private readonly Dictionary<string, object> _receivedMessages = new();
protected Dictionary<string, object> _receivedMessages = new();
public Dictionary<string, object> ReceivedMessages => _receivedMessages;
}

Expand All @@ -72,3 +72,38 @@ public SubscribedAgent(AgentId id,
{
}
}

[TypeSubscription("TestTopic")]
public class SubscribedSaveLoadAgent : TestAgent
{
private const string SavedStateKey = "receivedMessages";

public SubscribedSaveLoadAgent(AgentId id,
IAgentRuntime runtime,
Logger<BaseAgent>? logger = null) : base(id, runtime, logger)
{
}

public override ValueTask<IDictionary<string, object>> SaveStateAsync()
{
return ValueTask.FromResult<IDictionary<string, object>>(new Dictionary<string, object>
{
{ SavedStateKey, new Dictionary<string, object>(_receivedMessages) } // Save _receivedMessages
});
}

public override ValueTask LoadStateAsync(IDictionary<string, object> state)
{
if (state.TryGetValue(SavedStateKey, out var loadedMessagesObj) &&
loadedMessagesObj is Dictionary<string, object> loadedMessages)
{
_receivedMessages.Clear();
foreach (var kvp in loadedMessages)
{
_receivedMessages[kvp.Key] = kvp.Value;
}
}

return ValueTask.CompletedTask;
}
}
Loading