Skip to content

Commit

Permalink
Add state store for Autogen (#150)
Browse files Browse the repository at this point in the history
* added state store

* added expiry for in-memory state store

* remove test.py

* type annotations

* -

* Update notebook sample

* Update README

* Run formatter

---------

Co-authored-by: Kristian Nylund <[email protected]>
Co-authored-by: Ben Constable <[email protected]>
  • Loading branch information
3 people authored Jan 29, 2025
1 parent c416cc7 commit aa36e62
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
"source": [
"import dotenv\n",
"import logging\n",
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload"
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n",
"from autogen_text_2_sql.state_store import InMemoryStateStore"
]
},
{
Expand Down Expand Up @@ -86,16 +87,10 @@
"metadata": {},
"outputs": [],
"source": [
"agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")"
"# The state store allows AutoGen to store the states in memory across invocation. Whilst not neccessary, you can replace it with your own implementation that is backed by a database or file system. \n",
"agentic_text_2_sql = AutoGenText2Sql(state_store=InMemoryStateStore(), use_case=\"Analysing sales data\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -109,7 +104,7 @@
"metadata": {},
"outputs": [],
"source": [
"async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"what are the total sales\")):\n",
"async for message in agentic_text_2_sql.process_user_message(thread_id=\"1\", message_payload=UserMessagePayload(user_message=\"what are the total sales\")):\n",
" logging.info(\"Received %s Message from Text2SQL System\", message)"
]
},
Expand Down Expand Up @@ -137,7 +132,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.11.2"
}
},
"nbformat": 4,
Expand Down
4 changes: 4 additions & 0 deletions text_2_sql/autogen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ Contains specialized agent implementations:
- **sql_schema_selection_agent.py:** Handles schema selection and management
- **answer_and_sources_agent.py:** Formats and standardizes final outputs

## State Store

To enable the [AutoGen State](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.state.html) to be tracked across invocations, a state store implementation must be provided. A basic `InMemoryStateStore` is provided, but this can be replaced with an implementation for a database or file system for when the Agentic System is running behind an API. This enables the AutoGen state to be saved behind the scenes and recalled later when the message is part of the same thread. A `thread_id` must be provided to the entrypoint.

## Configuration

The system behavior can be controlled through environment variables:
Expand Down
1 change: 1 addition & 0 deletions text_2_sql/autogen/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"text_2_sql_core",
"sqlparse>=0.4.4",
"nltk>=3.8.1",
"cachetools>=5.5.1",
]

[dependency-groups]
Expand Down
3 changes: 3 additions & 0 deletions text_2_sql/autogen/src/autogen_text_2_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
from autogen_text_2_sql.state_store import InMemoryStateStore

from text_2_sql_core.payloads.interaction_payloads import (
UserMessagePayload,
DismabiguationRequestsPayload,
Expand All @@ -16,4 +18,5 @@
"AnswerWithSourcesPayload",
"ProcessingUpdatePayload",
"InteractionPayload",
"InMemoryStateStore",
]
31 changes: 12 additions & 19 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
ParallelQuerySolvingAgent,
)
from autogen_text_2_sql.state_store import StateStore
from autogen_agentchat.messages import TextMessage
import json
import os
Expand All @@ -31,9 +32,13 @@


class AutoGenText2Sql:
def __init__(self, **kwargs):
def __init__(self, state_store: StateStore, **kwargs):
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()

if not state_store:
raise ValueError("State store must be provided")
self.state_store = state_store

if "use_case" not in kwargs:
logging.warning(
"No use case provided. It is advised to provide a use case to help the LLM reason."
Expand Down Expand Up @@ -250,43 +255,31 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:

async def process_user_message(
self,
thread_id: str,
message_payload: UserMessagePayload,
chat_history: list[InteractionPayload] = None,
) -> AsyncGenerator[InteractionPayload, None]:
"""Process the complete message through the unified system.
Args:
----
thread_id (str): The ID of the thread the message belongs to.
task (str): The user message to process.
chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message.
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
Returns:
-------
dict: The response from the system.
"""
logging.info("Processing message: %s", message_payload.body.user_message)
logging.info("Chat history: %s", chat_history)

agent_input = {
"message": message_payload.body.user_message,
"injected_parameters": message_payload.body.injected_parameters,
}

latest_state = None
if chat_history is not None:
# Update input
for chat in reversed(chat_history):
if chat.root.payload_type in [
PayloadType.ANSWER_WITH_SOURCES,
PayloadType.DISAMBIGUATION_REQUESTS,
]:
latest_state = chat.body.assistant_state
break

# TODO: Trim the chat history to the last message from the user
if latest_state is not None:
await self.agentic_flow.load_state(latest_state)
state = self.state_store.get_state(thread_id)
if state is not None:
await self.agentic_flow.load_state(state)

async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
logging.debug("Message: %s", message)
Expand Down Expand Up @@ -340,7 +333,7 @@ async def process_user_message(
):
# Get the state
assistant_state = await self.agentic_flow.save_state()
payload.body.assistant_state = assistant_state
self.state_store.save_state(thread_id, assistant_state)

logging.debug("Final Payload: %s", payload)

Expand Down
23 changes: 23 additions & 0 deletions text_2_sql/autogen/src/autogen_text_2_sql/state_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from cachetools import TTLCache


class StateStore(ABC):
@abstractmethod
def get_state(self, thread_id):
pass

@abstractmethod
def save_state(self, thread_id, state):
pass


class InMemoryStateStore(StateStore):
def __init__(self):
self.cache = TTLCache(maxsize=1000, ttl=4 * 60 * 60) # 4 hours

def get_state(self, thread_id: str) -> dict:
return self.cache.get(thread_id)

def save_state(self, thread_id: str, state: dict) -> None:
self.cache[thread_id] = state
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class DismabiguationRequest(InteractionPayloadBase):
decomposed_user_messages: list[list[str]] = Field(
default_factory=list, alias="decomposedUserMessages"
)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
Expand Down Expand Up @@ -86,7 +85,6 @@ class Source(InteractionPayloadBase):
default_factory=list, alias="decomposedUserMessages"
)
sources: list[Source] = Field(default_factory=list)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"
Expand Down
21 changes: 16 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit aa36e62

Please sign in to comment.