Skip to content

Commit 95bd514

Browse files
lspinheirolpinheiromsekzhu
authoredJan 15, 2025··
Graphrag integration (#4612)
* add initial global search draft * add graphrag dep * fix local search embedding * linting * add from config constructor * remove draft notebook * update config factory and add docstrings * add graphrag sample * add sample prompts * update readme * update deps * Add API docs * Update python/samples/agentchat_graphrag/requirements.txt * Update python/samples/agentchat_graphrag/requirements.txt * update docstrings with snippet and doc ref * lint * improve set up instructions in docstring * lint * update lock * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py Co-authored-by: Eric Zhu <[email protected]> * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py Co-authored-by: Eric Zhu <[email protected]> * add unit tests * update lock * update uv lock * add docstring newlines * stubs and typing on graphrag tests * fix docstrings * fix mypy error * + linting and type fixes * type fix graphrag sample * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py Co-authored-by: Eric Zhu <[email protected]> * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py Co-authored-by: Eric Zhu <[email protected]> * Update python/samples/agentchat_graphrag/requirements.txt Co-authored-by: Eric Zhu <[email protected]> * update overrides * fix docstring client imports * additional docstring fix * add docstring missing import * use openai and fix db path * use console for displaying messages * add model config and gitignore * update readme * lint * Update python/samples/agentchat_graphrag/README.md * Update python/samples/agentchat_graphrag/README.md * Comment remaining azure config --------- Co-authored-by: Leonardo Pinheiro <[email protected]> Co-authored-by: Eric Zhu <[email protected]>
1 parent 8efe0c4 commit 95bd514

22 files changed

+2365
-53
lines changed
 

‎python/packages/autogen-core/docs/src/reference/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
5151
python/autogen_ext.models.openai
5252
python/autogen_ext.models.replay
5353
python/autogen_ext.tools.langchain
54+
python/autogen_ext.tools.graphrag
5455
python/autogen_ext.tools.code_execution
5556
python/autogen_ext.code_executors.local
5657
python/autogen_ext.code_executors.docker
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
autogen\_ext.tools.graphrag
2+
===========================
3+
4+
5+
.. automodule:: autogen_ext.tools.graphrag
6+
:members:
7+
:undoc-members:
8+
:show-inheritance:

‎python/packages/autogen-ext/pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ file-surfer = [
2727
"autogen-agentchat==0.4.1",
2828
"markitdown>=0.0.1a2",
2929
]
30+
graphrag = ["graphrag>=1.0.1"]
3031
web-surfer = [
3132
"autogen-agentchat==0.4.1",
3233
"playwright>=1.48.0",
@@ -57,6 +58,7 @@ packages = ["src/autogen_ext"]
5758
dev = [
5859
"autogen_test_utils",
5960
"langchain-experimental",
61+
"pandas-stubs>=2.2.3.241126",
6062
]
6163

6264
[tool.ruff]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from ._config import (
2+
GlobalContextConfig,
3+
GlobalDataConfig,
4+
LocalContextConfig,
5+
LocalDataConfig,
6+
MapReduceConfig,
7+
SearchConfig,
8+
)
9+
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
10+
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn
11+
12+
__all__ = [
13+
"GlobalSearchTool",
14+
"LocalSearchTool",
15+
"GlobalDataConfig",
16+
"LocalDataConfig",
17+
"GlobalContextConfig",
18+
"GlobalSearchToolArgs",
19+
"GlobalSearchToolReturn",
20+
"LocalContextConfig",
21+
"LocalSearchToolArgs",
22+
"LocalSearchToolReturn",
23+
"MapReduceConfig",
24+
"SearchConfig",
25+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from pydantic import BaseModel
2+
3+
4+
class DataConfig(BaseModel):
5+
input_dir: str
6+
entity_table: str = "create_final_nodes"
7+
entity_embedding_table: str = "create_final_entities"
8+
community_level: int = 2
9+
10+
11+
class GlobalDataConfig(DataConfig):
12+
community_table: str = "create_final_communities"
13+
community_report_table: str = "create_final_community_reports"
14+
15+
16+
class LocalDataConfig(DataConfig):
17+
relationship_table: str = "create_final_relationships"
18+
text_unit_table: str = "create_final_text_units"
19+
20+
21+
class ContextConfig(BaseModel):
22+
max_data_tokens: int = 8000
23+
24+
25+
class GlobalContextConfig(ContextConfig):
26+
use_community_summary: bool = False
27+
shuffle_data: bool = True
28+
include_community_rank: bool = True
29+
min_community_rank: int = 0
30+
community_rank_name: str = "rank"
31+
include_community_weight: bool = True
32+
community_weight_name: str = "occurrence weight"
33+
normalize_community_weight: bool = True
34+
max_data_tokens: int = 12000
35+
36+
37+
class LocalContextConfig(ContextConfig):
38+
text_unit_prop: float = 0.5
39+
community_prop: float = 0.25
40+
include_entity_rank: bool = True
41+
rank_description: str = "number of relationships"
42+
include_relationship_weight: bool = True
43+
relationship_ranking_attribute: str = "rank"
44+
45+
46+
class MapReduceConfig(BaseModel):
47+
map_max_tokens: int = 1000
48+
map_temperature: float = 0.0
49+
reduce_max_tokens: int = 2000
50+
reduce_temperature: float = 0.0
51+
allow_general_knowledge: bool = False
52+
json_mode: bool = False
53+
response_type: str = "multiple paragraphs"
54+
55+
56+
class SearchConfig(BaseModel):
57+
max_tokens: int = 1500
58+
temperature: float = 0.0
59+
response_type: str = "multiple paragraphs"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# mypy: disable-error-code="no-any-unimported,misc"
2+
from pathlib import Path
3+
4+
import pandas as pd
5+
import tiktoken
6+
from autogen_core import CancellationToken
7+
from autogen_core.tools import BaseTool
8+
from graphrag.config.config_file_loader import load_config_from_file
9+
from graphrag.query.indexer_adapters import (
10+
read_indexer_communities,
11+
read_indexer_entities,
12+
read_indexer_reports,
13+
)
14+
from graphrag.query.llm.base import BaseLLM
15+
from graphrag.query.llm.get_client import get_llm
16+
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
17+
from graphrag.query.structured_search.global_search.search import GlobalSearch
18+
from pydantic import BaseModel, Field
19+
20+
from ._config import GlobalContextConfig as ContextConfig
21+
from ._config import GlobalDataConfig as DataConfig
22+
from ._config import MapReduceConfig
23+
24+
_default_context_config = ContextConfig()
25+
_default_mapreduce_config = MapReduceConfig()
26+
27+
28+
class GlobalSearchToolArgs(BaseModel):
29+
query: str = Field(..., description="The user query to perform global search on.")
30+
31+
32+
class GlobalSearchToolReturn(BaseModel):
33+
answer: str
34+
35+
36+
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
37+
"""Enables running GraphRAG global search queries as an AutoGen tool.
38+
39+
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
40+
The search combines graph-based document relationships with semantic embeddings to find relevant information.
41+
42+
.. note::
43+
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.
44+
45+
To install:
46+
47+
.. code-block:: bash
48+
49+
pip install -U "autogen-agentchat" "autogen-ext[graphrag]"
50+
51+
Before using this tool, you must complete the GraphRAG setup and indexing process:
52+
53+
1. Follow the GraphRAG documentation to initialize your project and settings
54+
2. Configure and tune your prompts for the specific use case
55+
3. Run the indexing process to generate the required data files
56+
4. Ensure you have the settings.yaml file from the setup process
57+
58+
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
59+
for detailed instructions on completing these prerequisite steps.
60+
61+
Example usage with AssistantAgent:
62+
63+
.. code-block:: python
64+
65+
import asyncio
66+
from autogen_ext.models.openai import OpenAIChatCompletionClient
67+
from autogen_agentchat.ui import Console
68+
from autogen_ext.tools.graphrag import GlobalSearchTool
69+
from autogen_agentchat.agents import AssistantAgent
70+
71+
72+
async def main():
73+
# Initialize the OpenAI client
74+
openai_client = OpenAIChatCompletionClient(
75+
model="gpt-4o-mini",
76+
api_key="<api-key>",
77+
)
78+
79+
# Set up global search tool
80+
global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml")
81+
82+
# Create assistant agent with the global search tool
83+
assistant_agent = AssistantAgent(
84+
name="search_assistant",
85+
tools=[global_tool],
86+
model_client=openai_client,
87+
system_message=(
88+
"You are a tool selector AI assistant using the GraphRAG framework. "
89+
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
90+
"For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
91+
),
92+
)
93+
94+
# Run a sample query
95+
query = "What is the overall sentiment of the community reports?"
96+
await Console(assistant_agent.run_stream(task=query))
97+
98+
99+
if __name__ == "__main__":
100+
asyncio.run(main())
101+
"""
102+
103+
def __init__(
104+
self,
105+
token_encoder: tiktoken.Encoding,
106+
llm: BaseLLM,
107+
data_config: DataConfig,
108+
context_config: ContextConfig = _default_context_config,
109+
mapreduce_config: MapReduceConfig = _default_mapreduce_config,
110+
):
111+
super().__init__(
112+
args_type=GlobalSearchToolArgs,
113+
return_type=GlobalSearchToolReturn,
114+
name="global_search_tool",
115+
description="Perform a global search with given parameters using graphrag.",
116+
)
117+
# Use the provided LLM
118+
self._llm = llm
119+
120+
# Load parquet files
121+
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore
122+
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
123+
report_df: pd.DataFrame = pd.read_parquet( # type: ignore
124+
f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
125+
)
126+
entity_embedding_df: pd.DataFrame = pd.read_parquet( # type: ignore
127+
f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
128+
)
129+
130+
communities = read_indexer_communities(community_df, entity_df, report_df)
131+
reports = read_indexer_reports(report_df, entity_df, data_config.community_level)
132+
entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level)
133+
134+
context_builder = GlobalCommunityContext(
135+
community_reports=reports,
136+
communities=communities,
137+
entities=entities,
138+
token_encoder=token_encoder,
139+
)
140+
141+
context_builder_params = {
142+
"use_community_summary": context_config.use_community_summary,
143+
"shuffle_data": context_config.shuffle_data,
144+
"include_community_rank": context_config.include_community_rank,
145+
"min_community_rank": context_config.min_community_rank,
146+
"community_rank_name": context_config.community_rank_name,
147+
"include_community_weight": context_config.include_community_weight,
148+
"community_weight_name": context_config.community_weight_name,
149+
"normalize_community_weight": context_config.normalize_community_weight,
150+
"max_tokens": context_config.max_data_tokens,
151+
"context_name": "Reports",
152+
}
153+
154+
map_llm_params = {
155+
"max_tokens": mapreduce_config.map_max_tokens,
156+
"temperature": mapreduce_config.map_temperature,
157+
"response_format": {"type": "json_object"},
158+
}
159+
160+
reduce_llm_params = {
161+
"max_tokens": mapreduce_config.reduce_max_tokens,
162+
"temperature": mapreduce_config.reduce_temperature,
163+
}
164+
165+
self._search_engine = GlobalSearch(
166+
llm=self._llm,
167+
context_builder=context_builder,
168+
token_encoder=token_encoder,
169+
max_data_tokens=context_config.max_data_tokens,
170+
map_llm_params=map_llm_params,
171+
reduce_llm_params=reduce_llm_params,
172+
allow_general_knowledge=mapreduce_config.allow_general_knowledge,
173+
json_mode=mapreduce_config.json_mode,
174+
context_builder_params=context_builder_params,
175+
concurrent_coroutines=32,
176+
response_type=mapreduce_config.response_type,
177+
)
178+
179+
async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
180+
result = await self._search_engine.asearch(args.query)
181+
assert isinstance(result.response, str), "Expected response to be a string"
182+
return GlobalSearchToolReturn(answer=result.response)
183+
184+
@classmethod
185+
def from_settings(cls, settings_path: str | Path) -> "GlobalSearchTool":
186+
"""Create a GlobalSearchTool instance from GraphRAG settings file.
187+
188+
Args:
189+
settings_path: Path to the GraphRAG settings.yaml file
190+
191+
Returns:
192+
An initialized GlobalSearchTool instance
193+
"""
194+
# Load GraphRAG config
195+
config = load_config_from_file(settings_path)
196+
197+
# Initialize token encoder
198+
token_encoder = tiktoken.get_encoding(config.encoding_model)
199+
200+
# Initialize LLM using graphrag's get_client
201+
llm = get_llm(config)
202+
203+
# Create data config from storage paths
204+
data_config = DataConfig(
205+
input_dir=str(Path(config.storage.base_dir)),
206+
)
207+
208+
return cls(
209+
token_encoder=token_encoder,
210+
llm=llm,
211+
data_config=data_config,
212+
context_config=_default_context_config,
213+
mapreduce_config=_default_mapreduce_config,
214+
)

0 commit comments

Comments
 (0)
Please sign in to comment.