# mypy: disable-error-code="no-any-unimported,misc"
from pathlib import Path

import pandas as pd
import tiktoken
from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from graphrag.config.config_file_loader import load_config_from_file
from graphrag.query.indexer_adapters import (
    read_indexer_communities,
    read_indexer_entities,
    read_indexer_reports,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.get_client import get_llm
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from pydantic import BaseModel, Field

from ._config import GlobalContextConfig as ContextConfig
from ._config import GlobalDataConfig as DataConfig
from ._config import MapReduceConfig

_default_context_config = ContextConfig()
_default_mapreduce_config = MapReduceConfig()


class GlobalSearchToolArgs(BaseModel):
    query: str = Field(..., description="The user query to perform global search on.")


class GlobalSearchToolReturn(BaseModel):
    answer: str


class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
    """Enables running GraphRAG global search queries as an AutoGen tool.

    This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
    The search combines graph-based document relationships with semantic embeddings to find relevant information.

    .. note::
        This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.

        To install:

        .. code-block:: bash

            pip install -U "autogen-agentchat" "autogen-ext[graphrag]"

    Before using this tool, you must complete the GraphRAG setup and indexing process:

    1. Follow the GraphRAG documentation to initialize your project and settings
    2. Configure and tune your prompts for the specific use case
    3. Run the indexing process to generate the required data files
    4. Ensure you have the settings.yaml file from the setup process

    Please refer to the `GraphRAG documentation <https://microsoft.github.io/graphrag/>`_
    for detailed instructions on completing these prerequisite steps.
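
    For reference, a typical setup and indexing run looks like the following; the exact
    CLI commands are an illustrative sketch and may differ between graphrag versions:

    .. code-block:: bash

        # Initialize a new GraphRAG project (creates settings.yaml and default prompts)
        graphrag init --root ./ragtest

        # Build the index, producing the parquet tables this tool reads
        graphrag index --root ./ragtest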

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio
        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_agentchat.ui import Console
        from autogen_ext.tools.graphrag import GlobalSearchTool
        from autogen_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up global search tool
            global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml")

            # Create assistant agent with the global search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[global_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
                ),
            )

            # Run a sample query
            query = "What is the overall sentiment of the community reports?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        llm: BaseLLM,
        data_config: DataConfig,
        context_config: ContextConfig = _default_context_config,
        mapreduce_config: MapReduceConfig = _default_mapreduce_config,
    ):
        super().__init__(
            args_type=GlobalSearchToolArgs,
            return_type=GlobalSearchToolReturn,
            name="global_search_tool",
            description="Perform a global search with given parameters using graphrag.",
        )
        # Use the provided LLM
        self._llm = llm

        # Load parquet files
        community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet")  # type: ignore
        entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet")  # type: ignore
        report_df: pd.DataFrame = pd.read_parquet(  # type: ignore
            f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
        )
        entity_embedding_df: pd.DataFrame = pd.read_parquet(  # type: ignore
            f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
        )

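        # Convert the raw index tables into graphrag query objects (reports and
        # entities are filtered to data_config.community_level).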
        communities = read_indexer_communities(community_df, entity_df, report_df)
        reports = read_indexer_reports(report_df, entity_df, data_config.community_level)
        entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level)

        context_builder = GlobalCommunityContext(
            community_reports=reports,
            communities=communities,
            entities=entities,
            token_encoder=token_encoder,
        )

        context_builder_params = {
            "use_community_summary": context_config.use_community_summary,
            "shuffle_data": context_config.shuffle_data,
            "include_community_rank": context_config.include_community_rank,
            "min_community_rank": context_config.min_community_rank,
            "community_rank_name": context_config.community_rank_name,
            "include_community_weight": context_config.include_community_weight,
            "community_weight_name": context_config.community_weight_name,
            "normalize_community_weight": context_config.normalize_community_weight,
            "max_tokens": context_config.max_data_tokens,
            "context_name": "Reports",
        }

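        # Sampling parameters for the map and reduce stages of global search; the map
        # stage requests JSON output so intermediate answers can be parsed reliably.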
        map_llm_params = {
            "max_tokens": mapreduce_config.map_max_tokens,
            "temperature": mapreduce_config.map_temperature,
            "response_format": {"type": "json_object"},
        }

        reduce_llm_params = {
            "max_tokens": mapreduce_config.reduce_max_tokens,
            "temperature": mapreduce_config.reduce_temperature,
        }

        self._search_engine = GlobalSearch(
            llm=self._llm,
            context_builder=context_builder,
            token_encoder=token_encoder,
            max_data_tokens=context_config.max_data_tokens,
            map_llm_params=map_llm_params,
            reduce_llm_params=reduce_llm_params,
            allow_general_knowledge=mapreduce_config.allow_general_knowledge,
            json_mode=mapreduce_config.json_mode,
            context_builder_params=context_builder_params,
            concurrent_coroutines=32,  # number of map-stage LLM calls issued concurrently
            response_type=mapreduce_config.response_type,
        )

    async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
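        """Execute a global search for ``args.query`` and return the synthesized answer."""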
        result = await self._search_engine.asearch(args.query)
        assert isinstance(result.response, str), "Expected response to be a string"
        return GlobalSearchToolReturn(answer=result.response)

    @classmethod
    def from_settings(cls, settings_path: str | Path) -> "GlobalSearchTool":
        """Create a GlobalSearchTool instance from a GraphRAG settings file.

        Args:
            settings_path: Path to the GraphRAG settings.yaml file

        Returns:
            An initialized GlobalSearchTool instance
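
        Example:

        .. code-block:: python

            # Same call as in the class-level usage example above
            global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml")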
        """
        # Load GraphRAG config
        config = load_config_from_file(settings_path)

        # Initialize token encoder
        token_encoder = tiktoken.get_encoding(config.encoding_model)

        # Initialize LLM using graphrag's get_client
        llm = get_llm(config)

        # Create data config from storage paths
        data_config = DataConfig(
            input_dir=str(Path(config.storage.base_dir)),
        )

        return cls(
            token_encoder=token_encoder,
            llm=llm,
            data_config=data_config,
            context_config=_default_context_config,
            mapreduce_config=_default_mapreduce_config,
        )