import hashlib
import json
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import CacheStore, CancellationToken
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelInfo,
    RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema

# Values stored in the cache: a single CreateResult for `create`, or the full list of
# streamed chunks (str or CreateResult) for `create_stream`.
CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]


class ChatCompletionCache(ChatCompletionClient):
    """
    A wrapper around a :class:`~autogen_core.models.ChatCompletionClient` that caches
    creation results from an underlying client.
    Cache hits do not contribute to the token usage of the original client.

    Typical Usage:

    Let's use caching on disk with the `openai` client as an example.
    First install `autogen-ext` with the required packages:

    .. code-block:: bash

        pip install -U "autogen-ext[openai, diskcache]"

    And use it as:

    .. code-block:: python

        import asyncio
        import tempfile

        from autogen_core.models import UserMessage
        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_ext.models.cache import ChatCompletionCache, CHAT_CACHE_VALUE_TYPE
        from autogen_ext.cache_store.diskcache import DiskCacheStore
        from diskcache import Cache


        async def main():
            with tempfile.TemporaryDirectory() as tmpdirname:
                # Initialize the original client.
                openai_model_client = OpenAIChatCompletionClient(model="gpt-4o")

                # Then initialize the CacheStore, in this case with diskcache.Cache.
                # You can also use redis like:
                # from autogen_ext.cache_store.redis import RedisStore
                # import redis
                # redis_instance = redis.Redis()
                # cache_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)
                cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))
                cache_client = ChatCompletionCache(openai_model_client, cache_store)

                response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
                print(response)  # Should print a fresh response from OpenAI.
                response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
                print(response)  # Should print the cached response.


        asyncio.run(main())

    You can now use `cache_client` as you would the original client, but with caching enabled.

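    If you would rather keep the cache in Redis, only the store construction changes.
    The sketch below is based on the commented-out lines above and assumes a Redis server
    is reachable with the default connection settings and that the `redis` package is
    installed:

    .. code-block:: python

        import redis

        from autogen_ext.cache_store.redis import RedisStore
        from autogen_ext.models.cache import ChatCompletionCache, CHAT_CACHE_VALUE_TYPE

        # `openai_model_client` as constructed in the example above.
        redis_instance = redis.Redis()
        cache_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)
        cache_client = ChatCompletionCache(openai_model_client, cache_store)
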
    Args:
        client (ChatCompletionClient): The original ChatCompletionClient to wrap.
        store (CacheStore): A store object that implements get and set methods.
            The user is responsible for managing the store's lifecycle & clearing it (if needed).
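
    Note:
        The cache key is computed from the messages, tools, `json_output` and
        `extra_create_args` passed to `create` / `create_stream`, so two calls share a
        cache entry only when all of these match exactly. A sketch (`msgs` stands for
        any list of messages):

        .. code-block:: python

            await cache_client.create(msgs)                    # Miss: calls the model.
            await cache_client.create(msgs)                    # Hit: served from cache.
            await cache_client.create(msgs, json_output=True)  # Miss: different key.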
    """

    def __init__(self, client: ChatCompletionClient, store: CacheStore[CHAT_CACHE_VALUE_TYPE]):
        self.client = client
        self.store = store

    def _check_cache(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool],
        extra_create_args: Mapping[str, Any],
    ) -> tuple[Optional[CHAT_CACHE_VALUE_TYPE], str]:
        """
        Helper function to check the cache for a result.
        Returns a tuple of (cached_result, cache_key).
        """
        # The cache key is a SHA-256 hash of the serialized request: messages, tools,
        # json_output and any extra creation arguments.
        data = {
            "messages": [message.model_dump() for message in messages],
            "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
            "json_output": json_output,
            "extra_create_args": extra_create_args,
        }
        serialized_data = json.dumps(data, sort_keys=True)
        cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()

        cached_result = cast(Optional[CHAT_CACHE_VALUE_TYPE], self.store.get(cache_key))
        if cached_result is not None:
            return cached_result, cache_key

        return None, cache_key

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """
        Cached version of ChatCompletionClient.create.
        If the result of a call to create has been cached, it will be returned immediately
        without invoking the underlying client.

        NOTE: cancellation_token is ignored for cached results.
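
        For example (a sketch that reuses the `cache_client` and `UserMessage` from the
        class-level example above; the prompt is illustrative only):

        .. code-block:: python

            result = await cache_client.create([UserMessage(content="Hello!", source="user")])
            if result.cached:
                # Served from the cache; the underlying client was not called.
                print("cache hit")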
        """
        cached_result, cache_key = self._check_cache(messages, tools, json_output, extra_create_args)
        if cached_result:
            assert isinstance(cached_result, CreateResult)
            cached_result.cached = True
            return cached_result

        result = await self.client.create(
            messages,
            tools=tools,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )
        self.store.set(cache_key, result)
        return result

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """
        Cached version of ChatCompletionClient.create_stream.
        If the result of a call to create_stream has been cached, it will be returned
        without streaming from the underlying client.

        NOTE: cancellation_token is ignored for cached results.
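
        For example (again a sketch reusing `cache_client` and `UserMessage` from the
        class-level example; a cache hit replays the stored chunks instead of contacting
        the model):

        .. code-block:: python

            async for chunk in cache_client.create_stream(
                [UserMessage(content="Hello!", source="user")]
            ):
                if isinstance(chunk, str):
                    print(chunk, end="")  # Streamed text chunk.
                else:
                    print(chunk.usage)  # Final CreateResult with usage info.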
        """

        async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
            cached_result, cache_key = self._check_cache(
                messages,
                tools,
                json_output,
                extra_create_args,
            )
            if cached_result:
                assert isinstance(cached_result, list)
                # Replay the cached stream, marking final results as cached.
                for result in cached_result:
                    if isinstance(result, CreateResult):
                        result.cached = True
                    yield result
                return

            result_stream = self.client.create_stream(
                messages,
                tools=tools,
                json_output=json_output,
                extra_create_args=extra_create_args,
                cancellation_token=cancellation_token,
            )

            output_results: List[Union[str, CreateResult]] = []
            async for result in result_stream:
                output_results.append(result)
                yield result

            # Cache only after the stream has completed, so that stores that serialize
            # on set (e.g. disk- or redis-backed stores) persist the full list of chunks
            # rather than an empty list.
            self.store.set(cache_key, output_results)

        return _generator()

    def actual_usage(self) -> RequestUsage:
        return self.client.actual_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.count_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        return self.client.model_info

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.remaining_tokens(messages, tools=tools)

    def total_usage(self) -> RequestUsage:
        return self.client.total_usage()