-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into 4o-2024-11-20
- Loading branch information
Showing
27 changed files
with
2,581 additions
and
282 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,6 @@ classifiers = [ | |
] | ||
dependencies = [ | ||
"autogen-core==0.4.1", | ||
"aioconsole>=0.8.1" | ||
] | ||
|
||
[tool.ruff] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 8 additions & 0 deletions
8
.../packages/autogen-core/docs/src/reference/python/autogen_ext.tools.graphrag.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
autogen\_ext.tools.graphrag | ||
=========================== | ||
|
||
|
||
.. automodule:: autogen_ext.tools.graphrag | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
python/packages/autogen-ext/src/autogen_ext/tools/graphrag/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from ._config import ( | ||
GlobalContextConfig, | ||
GlobalDataConfig, | ||
LocalContextConfig, | ||
LocalDataConfig, | ||
MapReduceConfig, | ||
SearchConfig, | ||
) | ||
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn | ||
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn | ||
|
||
__all__ = [ | ||
"GlobalSearchTool", | ||
"LocalSearchTool", | ||
"GlobalDataConfig", | ||
"LocalDataConfig", | ||
"GlobalContextConfig", | ||
"GlobalSearchToolArgs", | ||
"GlobalSearchToolReturn", | ||
"LocalContextConfig", | ||
"LocalSearchToolArgs", | ||
"LocalSearchToolReturn", | ||
"MapReduceConfig", | ||
"SearchConfig", | ||
] |
59 changes: 59 additions & 0 deletions
59
python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class DataConfig(BaseModel): | ||
input_dir: str | ||
entity_table: str = "create_final_nodes" | ||
entity_embedding_table: str = "create_final_entities" | ||
community_level: int = 2 | ||
|
||
|
||
class GlobalDataConfig(DataConfig): | ||
community_table: str = "create_final_communities" | ||
community_report_table: str = "create_final_community_reports" | ||
|
||
|
||
class LocalDataConfig(DataConfig): | ||
relationship_table: str = "create_final_relationships" | ||
text_unit_table: str = "create_final_text_units" | ||
|
||
|
||
class ContextConfig(BaseModel): | ||
max_data_tokens: int = 8000 | ||
|
||
|
||
class GlobalContextConfig(ContextConfig): | ||
use_community_summary: bool = False | ||
shuffle_data: bool = True | ||
include_community_rank: bool = True | ||
min_community_rank: int = 0 | ||
community_rank_name: str = "rank" | ||
include_community_weight: bool = True | ||
community_weight_name: str = "occurrence weight" | ||
normalize_community_weight: bool = True | ||
max_data_tokens: int = 12000 | ||
|
||
|
||
class LocalContextConfig(ContextConfig): | ||
text_unit_prop: float = 0.5 | ||
community_prop: float = 0.25 | ||
include_entity_rank: bool = True | ||
rank_description: str = "number of relationships" | ||
include_relationship_weight: bool = True | ||
relationship_ranking_attribute: str = "rank" | ||
|
||
|
||
class MapReduceConfig(BaseModel): | ||
map_max_tokens: int = 1000 | ||
map_temperature: float = 0.0 | ||
reduce_max_tokens: int = 2000 | ||
reduce_temperature: float = 0.0 | ||
allow_general_knowledge: bool = False | ||
json_mode: bool = False | ||
response_type: str = "multiple paragraphs" | ||
|
||
|
||
class SearchConfig(BaseModel): | ||
max_tokens: int = 1500 | ||
temperature: float = 0.0 | ||
response_type: str = "multiple paragraphs" |
214 changes: 214 additions & 0 deletions
214
python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
# mypy: disable-error-code="no-any-unimported,misc" | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
import tiktoken | ||
from autogen_core import CancellationToken | ||
from autogen_core.tools import BaseTool | ||
from graphrag.config.config_file_loader import load_config_from_file | ||
from graphrag.query.indexer_adapters import ( | ||
read_indexer_communities, | ||
read_indexer_entities, | ||
read_indexer_reports, | ||
) | ||
from graphrag.query.llm.base import BaseLLM | ||
from graphrag.query.llm.get_client import get_llm | ||
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext | ||
from graphrag.query.structured_search.global_search.search import GlobalSearch | ||
from pydantic import BaseModel, Field | ||
|
||
from ._config import GlobalContextConfig as ContextConfig | ||
from ._config import GlobalDataConfig as DataConfig | ||
from ._config import MapReduceConfig | ||
|
||
_default_context_config = ContextConfig() | ||
_default_mapreduce_config = MapReduceConfig() | ||
|
||
|
||
class GlobalSearchToolArgs(BaseModel): | ||
query: str = Field(..., description="The user query to perform global search on.") | ||
|
||
|
||
class GlobalSearchToolReturn(BaseModel): | ||
answer: str | ||
|
||
|
||
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]): | ||
"""Enables running GraphRAG global search queries as an AutoGen tool. | ||
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework. | ||
The search combines graph-based document relationships with semantic embeddings to find relevant information. | ||
.. note:: | ||
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package. | ||
To install: | ||
.. code-block:: bash | ||
pip install -U "autogen-agentchat" "autogen-ext[graphrag]" | ||
Before using this tool, you must complete the GraphRAG setup and indexing process: | ||
1. Follow the GraphRAG documentation to initialize your project and settings | ||
2. Configure and tune your prompts for the specific use case | ||
3. Run the indexing process to generate the required data files | ||
4. Ensure you have the settings.yaml file from the setup process | ||
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/) | ||
for detailed instructions on completing these prerequisite steps. | ||
Example usage with AssistantAgent: | ||
.. code-block:: python | ||
import asyncio | ||
from autogen_ext.models.openai import OpenAIChatCompletionClient | ||
from autogen_agentchat.ui import Console | ||
from autogen_ext.tools.graphrag import GlobalSearchTool | ||
from autogen_agentchat.agents import AssistantAgent | ||
async def main(): | ||
# Initialize the OpenAI client | ||
openai_client = OpenAIChatCompletionClient( | ||
model="gpt-4o-mini", | ||
api_key="<api-key>", | ||
) | ||
# Set up global search tool | ||
global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml") | ||
# Create assistant agent with the global search tool | ||
assistant_agent = AssistantAgent( | ||
name="search_assistant", | ||
tools=[global_tool], | ||
model_client=openai_client, | ||
system_message=( | ||
"You are a tool selector AI assistant using the GraphRAG framework. " | ||
"Your primary task is to determine the appropriate search tool to call based on the user's query. " | ||
"For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function." | ||
), | ||
) | ||
# Run a sample query | ||
query = "What is the overall sentiment of the community reports?" | ||
await Console(assistant_agent.run_stream(task=query)) | ||
if __name__ == "__main__": | ||
asyncio.run(main()) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
token_encoder: tiktoken.Encoding, | ||
llm: BaseLLM, | ||
data_config: DataConfig, | ||
context_config: ContextConfig = _default_context_config, | ||
mapreduce_config: MapReduceConfig = _default_mapreduce_config, | ||
): | ||
super().__init__( | ||
args_type=GlobalSearchToolArgs, | ||
return_type=GlobalSearchToolReturn, | ||
name="global_search_tool", | ||
description="Perform a global search with given parameters using graphrag.", | ||
) | ||
# Use the provided LLM | ||
self._llm = llm | ||
|
||
# Load parquet files | ||
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore | ||
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore | ||
report_df: pd.DataFrame = pd.read_parquet( # type: ignore | ||
f"{data_config.input_dir}/{data_config.community_report_table}.parquet" | ||
) | ||
entity_embedding_df: pd.DataFrame = pd.read_parquet( # type: ignore | ||
f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet" | ||
) | ||
|
||
communities = read_indexer_communities(community_df, entity_df, report_df) | ||
reports = read_indexer_reports(report_df, entity_df, data_config.community_level) | ||
entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level) | ||
|
||
context_builder = GlobalCommunityContext( | ||
community_reports=reports, | ||
communities=communities, | ||
entities=entities, | ||
token_encoder=token_encoder, | ||
) | ||
|
||
context_builder_params = { | ||
"use_community_summary": context_config.use_community_summary, | ||
"shuffle_data": context_config.shuffle_data, | ||
"include_community_rank": context_config.include_community_rank, | ||
"min_community_rank": context_config.min_community_rank, | ||
"community_rank_name": context_config.community_rank_name, | ||
"include_community_weight": context_config.include_community_weight, | ||
"community_weight_name": context_config.community_weight_name, | ||
"normalize_community_weight": context_config.normalize_community_weight, | ||
"max_tokens": context_config.max_data_tokens, | ||
"context_name": "Reports", | ||
} | ||
|
||
map_llm_params = { | ||
"max_tokens": mapreduce_config.map_max_tokens, | ||
"temperature": mapreduce_config.map_temperature, | ||
"response_format": {"type": "json_object"}, | ||
} | ||
|
||
reduce_llm_params = { | ||
"max_tokens": mapreduce_config.reduce_max_tokens, | ||
"temperature": mapreduce_config.reduce_temperature, | ||
} | ||
|
||
self._search_engine = GlobalSearch( | ||
llm=self._llm, | ||
context_builder=context_builder, | ||
token_encoder=token_encoder, | ||
max_data_tokens=context_config.max_data_tokens, | ||
map_llm_params=map_llm_params, | ||
reduce_llm_params=reduce_llm_params, | ||
allow_general_knowledge=mapreduce_config.allow_general_knowledge, | ||
json_mode=mapreduce_config.json_mode, | ||
context_builder_params=context_builder_params, | ||
concurrent_coroutines=32, | ||
response_type=mapreduce_config.response_type, | ||
) | ||
|
||
async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn: | ||
result = await self._search_engine.asearch(args.query) | ||
assert isinstance(result.response, str), "Expected response to be a string" | ||
return GlobalSearchToolReturn(answer=result.response) | ||
|
||
@classmethod | ||
def from_settings(cls, settings_path: str | Path) -> "GlobalSearchTool": | ||
"""Create a GlobalSearchTool instance from GraphRAG settings file. | ||
Args: | ||
settings_path: Path to the GraphRAG settings.yaml file | ||
Returns: | ||
An initialized GlobalSearchTool instance | ||
""" | ||
# Load GraphRAG config | ||
config = load_config_from_file(settings_path) | ||
|
||
# Initialize token encoder | ||
token_encoder = tiktoken.get_encoding(config.encoding_model) | ||
|
||
# Initialize LLM using graphrag's get_client | ||
llm = get_llm(config) | ||
|
||
# Create data config from storage paths | ||
data_config = DataConfig( | ||
input_dir=str(Path(config.storage.base_dir)), | ||
) | ||
|
||
return cls( | ||
token_encoder=token_encoder, | ||
llm=llm, | ||
data_config=data_config, | ||
context_config=_default_context_config, | ||
mapreduce_config=_default_mapreduce_config, | ||
) |
Oops, something went wrong.