Skip to content

Commit

Permalink
Update interaction payloads (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenConstable9 authored Jan 28, 2025
1 parent 40e21ed commit c416cc7
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 21 deletions.
6 changes: 5 additions & 1 deletion text_2_sql/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ Text2Sql__Tsql__ConnectionString=<Tsql databaseConnectionString if using Tsql Da
Text2Sql__Tsql__Database=<Tsql database if using Tsql Data Source>

# PostgreSQL Specific Connection Details
Text2Sql__Postgresql__ConnectionString=<Postgresql databaseConnectionString if using Postgresql Data Source>
Text2Sql__Postgresql__ConnectionString=<Postgresql databaseConnectionString if using Postgresql Data Source and a connection string>
Text2Sql__Postgresql__Database=<Postgresql database if using Postgresql Data Source>
Text2Sql__Postgresql__User=<Postgresql user if using Postgresql Data Source and not the connections string>
Text2Sql__Postgresql__Password=<Postgresql password if using Postgresql Data Source and not the connections string>
Text2Sql__Postgresql__ServerHostname=<Postgresql serverHostname if using Postgresql Data Source and not the connections string>
Text2Sql__Postgresql__Port=<Postgresql port if using Postgresql Data Source and not the connections string>

# Snowflake Specific Connection Details
Text2Sql__Snowflake__User=<snowflakeUser if using Snowflake Data Source>
Expand Down
48 changes: 39 additions & 9 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(self, **kwargs):

self.kwargs = {**DEFAULT_INJECTED_PARAMETERS, **kwargs}

self._agentic_flow = None

def get_all_agents(self):
"""Get all agents for the complete flow."""

Expand Down Expand Up @@ -97,14 +99,20 @@ def unified_selector(self, messages):
@property
def agentic_flow(self):
"""Create the unified flow for the complete process."""

if self._agentic_flow is not None:
return self._agentic_flow

flow = SelectorGroupChat(
self.get_all_agents(),
allow_repeated_speaker=False,
model_client=LLMModelCreator.get_model("4o-mini"),
termination_condition=self.termination_condition,
selector_func=self.unified_selector,
)
return flow

self._agentic_flow = flow
return self._agentic_flow

def parse_message_content(self, content):
"""Parse different message content formats into a dictionary."""
Expand Down Expand Up @@ -250,7 +258,7 @@ async def process_user_message(
Args:
----
task (str): The user message to process.
chat_history (list[str], optional): The chat history. Defaults to None.
chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message.
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
Returns:
Expand All @@ -262,17 +270,23 @@ async def process_user_message(

agent_input = {
"message": message_payload.body.user_message,
"chat_history": {},
"injected_parameters": message_payload.body.injected_parameters,
}

latest_state = None
if chat_history is not None:
# Update input
for idx, chat in enumerate(chat_history):
if chat.root.payload_type == PayloadType.USER_MESSAGE:
# For now only consider the user query
chat_history_key = f"chat_{idx}"
agent_input[chat_history_key] = chat.root.body.user_message
for chat in reversed(chat_history):
if chat.root.payload_type in [
PayloadType.ANSWER_WITH_SOURCES,
PayloadType.DISAMBIGUATION_REQUESTS,
]:
latest_state = chat.body.assistant_state
break

# TODO: Trim the chat history to the last message from the user
if latest_state is not None:
await self.agentic_flow.load_state(latest_state)

async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
logging.debug("Message: %s", message)
Expand Down Expand Up @@ -312,6 +326,22 @@ async def process_user_message(
logging.error("Unexpected TaskResult: %s", message)
raise ValueError("Unexpected TaskResult")

if payload is not None:
if (
payload is not None
and payload.payload_type is PayloadType.PROCESSING_UPDATE
):
logging.debug("Payload: %s", payload)
yield payload

# Return the final payload
if (
payload is not None
and payload.payload_type is not PayloadType.PROCESSING_UPDATE
):
# Get the state
assistant_state = await self.agentic_flow.save_state()
payload.body.assistant_state = assistant_state

logging.debug("Final Payload: %s", payload)

yield payload
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import logging
import json

from urllib.parse import urlparse
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields


Expand Down Expand Up @@ -66,10 +66,35 @@ async def query_execution(
"""
logging.info(f"Running query: {sql_query}")
results = []
connection_string = os.environ["Text2Sql__Postgresql__ConnectionString"]

if "Text2Sql__Postgresql__ConnectionString" in os.environ:
logging.info("Postgresql Connection string found in environment variables.")

p = urlparse(os.environ["Text2Sql__Postgresql__ConnectionString"])

postgres_connections = {
"dbname": p.path[1:],
"user": p.username,
"password": p.password,
"port": p.port,
"host": p.hostname,
}
else:
logging.warning(
"Postgresql Connection string not found in environment variables. Using individual variables."
)
postgres_connections = {
"dbname": os.environ["Text2Sql__Postgresql__Database"],
"user": os.environ["Text2Sql__Postgresql__User"],
"password": os.environ["Text2Sql__Postgresql__Password"],
"port": os.environ["Text2Sql__Postgresql__Port"],
"host": os.environ["Text2Sql__Postgresql__ServerHostname"],
}

# Establish an asynchronous connection to the PostgreSQL database
async with await psycopg.AsyncConnection.connect(connection_string) as conn:
async with await psycopg.AsyncConnection.connect(
**postgres_connections
) as conn:
# Create an asynchronous cursor
async with conn.cursor() as cursor:
await cursor.execute(sql_query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class PayloadSource(StrEnum):
USER = "user"
AGENT = "agent"
ASSISTANT = "assistant"


class PayloadType(StrEnum):
Expand All @@ -42,11 +42,13 @@ class PayloadBase(InteractionPayloadBase):
payload_type: PayloadType = Field(..., alias="payloadType")
payload_source: PayloadSource = Field(..., alias="payloadSource")

body: InteractionPayloadBase | None = Field(default=None)


class DismabiguationRequestsPayload(InteractionPayloadBase):
class Body(InteractionPayloadBase):
class DismabiguationRequest(InteractionPayloadBase):
agent_question: str | None = Field(..., alias="agentQuestion")
ASSISTANT_question: str | None = Field(..., alias="ASSISTANTQuestion")
user_choices: list[str] | None = Field(default=None, alias="userChoices")

disambiguation_requests: list[DismabiguationRequest] | None = Field(
Expand All @@ -55,12 +57,13 @@ class DismabiguationRequest(InteractionPayloadBase):
decomposed_user_messages: list[list[str]] = Field(
default_factory=list, alias="decomposedUserMessages"
)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
)
payload_source: Literal[PayloadSource.AGENT] = Field(
default=PayloadSource.AGENT, alias="payloadSource"
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
default=PayloadSource.ASSISTANT, alias="payloadSource"
)
body: Body | None = Field(default=None)

Expand All @@ -83,12 +86,13 @@ class Source(InteractionPayloadBase):
default_factory=list, alias="decomposedUserMessages"
)
sources: list[Source] = Field(default_factory=list)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"
)
payload_source: Literal[PayloadSource.AGENT] = Field(
PayloadSource.AGENT, alias="payloadSource"
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
PayloadSource.ASSISTANT, alias="payloadSource"
)
body: Body | None = Field(default=None)

Expand All @@ -108,8 +112,8 @@ class Body(InteractionPayloadBase):
payload_type: Literal[PayloadType.PROCESSING_UPDATE] = Field(
PayloadType.PROCESSING_UPDATE, alias="payloadType"
)
payload_source: Literal[PayloadSource.AGENT] = Field(
PayloadSource.AGENT, alias="payloadSource"
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
PayloadSource.ASSISTANT, alias="payloadSource"
)
body: Body | None = Field(default=None)

Expand Down

0 comments on commit c416cc7

Please sign in to comment.