Merge branch 'main' into 4o-2024-11-20
rysweet authored Jan 15, 2025
2 parents 4faa9c4 + ae98c9d commit bb6c9fe
Showing 27 changed files with 2,581 additions and 282 deletions.
1 change: 0 additions & 1 deletion python/packages/autogen-agentchat/pyproject.toml
@@ -16,7 +16,6 @@ classifiers = [
 ]
 dependencies = [
     "autogen-core==0.4.1",
-    "aioconsole>=0.8.1"
 ]
 
 [tool.ruff]
@@ -5,7 +5,6 @@
 from inspect import iscoroutinefunction
 from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast
 
-from aioconsole import ainput  # type: ignore
 from autogen_core import CancellationToken
 
 from ..base import Response
@@ -17,10 +16,9 @@
 InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
 
 
-# TODO: ainput doesn't seem to play nicely with jupyter.
-# No input window appears in this case.
+# TODO: check if using to_thread fixes this in jupyter
 async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
-    task: asyncio.Task[str] = asyncio.create_task(ainput(prompt))  # type: ignore
+    task: asyncio.Task[str] = asyncio.create_task(asyncio.to_thread(input, prompt))
     if cancellation_token is not None:
         cancellation_token.link_future(task)
     return await task
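This change drops aioconsole's ainput in favor of running the blocking built-in input in a worker thread via asyncio.to_thread, while link_future keeps the prompt cancellable. A minimal standalone sketch of the pattern (the five-second timeout and messages are illustrative, not part of the change); cancelling the token cancels the awaiting task, though the worker thread itself still blocks on stdin until Enter is pressed.

import asyncio

from autogen_core import CancellationToken


async def cancellable_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
    # Run the blocking input() in a worker thread so the event loop stays free.
    task: asyncio.Task[str] = asyncio.create_task(asyncio.to_thread(input, prompt))
    if cancellation_token is not None:
        # Cancelling the token cancels this task.
        cancellation_token.link_future(task)
    return await task


async def main() -> None:
    token = CancellationToken()
    # Illustrative: give up on the prompt after five seconds.
    asyncio.get_running_loop().call_later(5.0, token.cancel)
    try:
        text = await cancellable_input("Say something: ", token)
        print(f"Got: {text}")
    except asyncio.CancelledError:
        print("Input cancelled.")


asyncio.run(main())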
@@ -5,7 +5,6 @@
 from inspect import iscoroutinefunction
 from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast
 
-from aioconsole import aprint  # type: ignore
 from autogen_core import CancellationToken, Image
 from autogen_core.models import RequestUsage
 
@@ -70,6 +69,10 @@ def notify_event_received(self, request_id: str) -> None:
         self.input_events[request_id] = event
 
 
+def aprint(output: str, end: str = "\n") -> Awaitable[None]:
+    return asyncio.to_thread(print, output, end=end)
+
+
 async def Console(
     stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
     *,
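The aprint helper added above applies the same to_thread pattern to output, so large streamed writes do not stall other coroutines. A small usage sketch, assuming the shim as defined in the diff (reproduced here for self-containment):

import asyncio
from typing import Awaitable


def aprint(output: str, end: str = "\n") -> Awaitable[None]:
    # Defer the potentially blocking print() to a worker thread.
    return asyncio.to_thread(print, output, end=end)


async def main() -> None:
    # Writes scheduled concurrently; none of them blocks the event loop.
    await asyncio.gather(*(aprint(f"line {i}") for i in range(3)))


asyncio.run(main())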
1 change: 1 addition & 0 deletions python/packages/autogen-core/docs/src/reference/index.md
@@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
 python/autogen_ext.models.openai
 python/autogen_ext.models.replay
 python/autogen_ext.tools.langchain
+python/autogen_ext.tools.graphrag
 python/autogen_ext.tools.code_execution
 python/autogen_ext.code_executors.local
 python/autogen_ext.code_executors.docker
@@ -0,0 +1,8 @@
autogen\_ext.tools.graphrag
===========================


.. automodule:: autogen_ext.tools.graphrag
   :members:
   :undoc-members:
   :show-inheritance:
2 changes: 2 additions & 0 deletions python/packages/autogen-ext/pyproject.toml
@@ -27,6 +27,7 @@ file-surfer = [
"autogen-agentchat==0.4.1",
"markitdown>=0.0.1a2",
]
graphrag = ["graphrag>=1.0.1"]
web-surfer = [
"autogen-agentchat==0.4.1",
"playwright>=1.48.0",
@@ -57,6 +58,7 @@ packages = ["src/autogen_ext"]
 dev = [
     "autogen_test_utils",
     "langchain-experimental",
+    "pandas-stubs>=2.2.3.241126",
 ]
 
 [tool.ruff]
@@ -30,13 +30,13 @@
     Image,
     MessageHandlerContext,
 )
-from autogen_core.models import FinishReasons
 from autogen_core.logging import LLMCallEvent
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
     ChatCompletionTokenLogprob,
     CreateResult,
+    FinishReasons,
     FunctionExecutionResultMessage,
     LLMMessage,
     ModelCapabilities,  # type: ignore
@@ -0,0 +1,25 @@
from ._config import (
    GlobalContextConfig,
    GlobalDataConfig,
    LocalContextConfig,
    LocalDataConfig,
    MapReduceConfig,
    SearchConfig,
)
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn

__all__ = [
    "GlobalSearchTool",
    "LocalSearchTool",
    "GlobalDataConfig",
    "LocalDataConfig",
    "GlobalContextConfig",
    "GlobalSearchToolArgs",
    "GlobalSearchToolReturn",
    "LocalContextConfig",
    "LocalSearchToolArgs",
    "LocalSearchToolReturn",
    "MapReduceConfig",
    "SearchConfig",
]
@@ -0,0 +1,59 @@
from pydantic import BaseModel


class DataConfig(BaseModel):
    input_dir: str
    entity_table: str = "create_final_nodes"
    entity_embedding_table: str = "create_final_entities"
    community_level: int = 2


class GlobalDataConfig(DataConfig):
    community_table: str = "create_final_communities"
    community_report_table: str = "create_final_community_reports"


class LocalDataConfig(DataConfig):
    relationship_table: str = "create_final_relationships"
    text_unit_table: str = "create_final_text_units"


class ContextConfig(BaseModel):
    max_data_tokens: int = 8000


class GlobalContextConfig(ContextConfig):
    use_community_summary: bool = False
    shuffle_data: bool = True
    include_community_rank: bool = True
    min_community_rank: int = 0
    community_rank_name: str = "rank"
    include_community_weight: bool = True
    community_weight_name: str = "occurrence weight"
    normalize_community_weight: bool = True
    max_data_tokens: int = 12000


class LocalContextConfig(ContextConfig):
    text_unit_prop: float = 0.5
    community_prop: float = 0.25
    include_entity_rank: bool = True
    rank_description: str = "number of relationships"
    include_relationship_weight: bool = True
    relationship_ranking_attribute: str = "rank"


class MapReduceConfig(BaseModel):
    map_max_tokens: int = 1000
    map_temperature: float = 0.0
    reduce_max_tokens: int = 2000
    reduce_temperature: float = 0.0
    allow_general_knowledge: bool = False
    json_mode: bool = False
    response_type: str = "multiple paragraphs"


class SearchConfig(BaseModel):
    max_tokens: int = 1500
    temperature: float = 0.0
    response_type: str = "multiple paragraphs"
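These are plain pydantic models, so callers override only the fields they care about and the rest keep the defaults above. A hedged sketch of building custom configs for a global search setup (the input_dir path is hypothetical):

from autogen_ext.tools.graphrag import GlobalContextConfig, GlobalDataConfig, MapReduceConfig

# Hypothetical path to the artifacts of a completed GraphRAG indexing run.
data_config = GlobalDataConfig(input_dir="./output/artifacts")

# Allow a larger report context and keep its ordering deterministic.
context_config = GlobalContextConfig(max_data_tokens=16000, shuffle_data=False)

# Tighter reduce step that answers in a single paragraph.
mapreduce_config = MapReduceConfig(reduce_max_tokens=1000, response_type="single paragraph")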
@@ -0,0 +1,214 @@
# mypy: disable-error-code="no-any-unimported,misc"
from pathlib import Path

import pandas as pd
import tiktoken
from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from graphrag.config.config_file_loader import load_config_from_file
from graphrag.query.indexer_adapters import (
    read_indexer_communities,
    read_indexer_entities,
    read_indexer_reports,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.get_client import get_llm
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from pydantic import BaseModel, Field

from ._config import GlobalContextConfig as ContextConfig
from ._config import GlobalDataConfig as DataConfig
from ._config import MapReduceConfig

_default_context_config = ContextConfig()
_default_mapreduce_config = MapReduceConfig()


class GlobalSearchToolArgs(BaseModel):
    query: str = Field(..., description="The user query to perform global search on.")


class GlobalSearchToolReturn(BaseModel):
    answer: str


class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
    """Enables running GraphRAG global search queries as an AutoGen tool.

    This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
    The search combines graph-based document relationships with semantic embeddings to find relevant information.

    .. note::

        This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.

        To install:

        .. code-block:: bash

            pip install -U "autogen-agentchat" "autogen-ext[graphrag]"

    Before using this tool, you must complete the GraphRAG setup and indexing process:

    1. Follow the GraphRAG documentation to initialize your project and settings
    2. Configure and tune your prompts for the specific use case
    3. Run the indexing process to generate the required data files
    4. Ensure you have the settings.yaml file from the setup process

    Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
    for detailed instructions on completing these prerequisite steps.

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio

        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_agentchat.ui import Console
        from autogen_ext.tools.graphrag import GlobalSearchTool
        from autogen_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up global search tool
            global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml")

            # Create assistant agent with the global search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[global_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
                ),
            )

            # Run a sample query
            query = "What is the overall sentiment of the community reports?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        llm: BaseLLM,
        data_config: DataConfig,
        context_config: ContextConfig = _default_context_config,
        mapreduce_config: MapReduceConfig = _default_mapreduce_config,
    ):
        super().__init__(
            args_type=GlobalSearchToolArgs,
            return_type=GlobalSearchToolReturn,
            name="global_search_tool",
            description="Perform a global search with given parameters using graphrag.",
        )
        # Use the provided LLM
        self._llm = llm

        # Load parquet files
        community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet")  # type: ignore
        entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet")  # type: ignore
        report_df: pd.DataFrame = pd.read_parquet(  # type: ignore
            f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
        )
        entity_embedding_df: pd.DataFrame = pd.read_parquet(  # type: ignore
            f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
        )

        communities = read_indexer_communities(community_df, entity_df, report_df)
        reports = read_indexer_reports(report_df, entity_df, data_config.community_level)
        entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level)

        context_builder = GlobalCommunityContext(
            community_reports=reports,
            communities=communities,
            entities=entities,
            token_encoder=token_encoder,
        )

        context_builder_params = {
            "use_community_summary": context_config.use_community_summary,
            "shuffle_data": context_config.shuffle_data,
            "include_community_rank": context_config.include_community_rank,
            "min_community_rank": context_config.min_community_rank,
            "community_rank_name": context_config.community_rank_name,
            "include_community_weight": context_config.include_community_weight,
            "community_weight_name": context_config.community_weight_name,
            "normalize_community_weight": context_config.normalize_community_weight,
            "max_tokens": context_config.max_data_tokens,
            "context_name": "Reports",
        }

        map_llm_params = {
            "max_tokens": mapreduce_config.map_max_tokens,
            "temperature": mapreduce_config.map_temperature,
            "response_format": {"type": "json_object"},
        }

        reduce_llm_params = {
            "max_tokens": mapreduce_config.reduce_max_tokens,
            "temperature": mapreduce_config.reduce_temperature,
        }

        self._search_engine = GlobalSearch(
            llm=self._llm,
            context_builder=context_builder,
            token_encoder=token_encoder,
            max_data_tokens=context_config.max_data_tokens,
            map_llm_params=map_llm_params,
            reduce_llm_params=reduce_llm_params,
            allow_general_knowledge=mapreduce_config.allow_general_knowledge,
            json_mode=mapreduce_config.json_mode,
            context_builder_params=context_builder_params,
            concurrent_coroutines=32,
            response_type=mapreduce_config.response_type,
        )

    async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
        result = await self._search_engine.asearch(args.query)
        assert isinstance(result.response, str), "Expected response to be a string"
        return GlobalSearchToolReturn(answer=result.response)

    @classmethod
    def from_settings(cls, settings_path: str | Path) -> "GlobalSearchTool":
        """Create a GlobalSearchTool instance from a GraphRAG settings file.

        Args:
            settings_path: Path to the GraphRAG settings.yaml file

        Returns:
            An initialized GlobalSearchTool instance
        """
        # Load GraphRAG config
        config = load_config_from_file(settings_path)

        # Initialize token encoder
        token_encoder = tiktoken.get_encoding(config.encoding_model)

        # Initialize LLM using graphrag's get_client
        llm = get_llm(config)

        # Create data config from storage paths
        data_config = DataConfig(
            input_dir=str(Path(config.storage.base_dir)),
        )

        return cls(
            token_encoder=token_encoder,
            llm=llm,
            data_config=data_config,
            context_config=_default_context_config,
            mapreduce_config=_default_mapreduce_config,
        )
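Besides the agent-driven example in the docstring, the tool can also be invoked directly through its run method. A minimal sketch, assuming a completed GraphRAG index and a settings.yaml in the working directory:

import asyncio

from autogen_core import CancellationToken
from autogen_ext.tools.graphrag import GlobalSearchTool, GlobalSearchToolArgs


async def main() -> None:
    tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml")
    result = await tool.run(
        GlobalSearchToolArgs(query="What are the main themes across the community reports?"),
        CancellationToken(),
    )
    print(result.answer)


asyncio.run(main())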