
Commit 8bd65c6

Authored Jan 16, 2025
Add ChatCompletionCache along with AbstractStore for caching completions (#4924)
* Add ChatCompletionCache along with AbstractStore for caching completions
* Addressing comments
* Improve interface for cachestore
* Improve documentation & revert protocol
* Make cache store typed, and improve docs
* remove unnecessary casts
1 parent 2e1a9c7 commit 8bd65c6

22 files changed: +802 −18 lines
 

‎python/packages/autogen-core/docs/src/reference/index.md

+3

@@ -48,6 +48,7 @@ python/autogen_ext.agents.video_surfer
 python/autogen_ext.agents.video_surfer.tools
 python/autogen_ext.auth.azure
 python/autogen_ext.teams.magentic_one
+python/autogen_ext.models.cache
 python/autogen_ext.models.openai
 python/autogen_ext.models.replay
 python/autogen_ext.tools.langchain
@@ -56,5 +57,7 @@ python/autogen_ext.tools.code_execution
 python/autogen_ext.code_executors.local
 python/autogen_ext.code_executors.docker
 python/autogen_ext.code_executors.azure
+python/autogen_ext.cache_store.diskcache
+python/autogen_ext.cache_store.redis
 python/autogen_ext.runtimes.grpc
 ```
New reference page: autogen_ext.cache_store.diskcache (+8)

@@ -0,0 +1,8 @@
+autogen\_ext.cache_store.diskcache
+==================================
+
+
+.. automodule:: autogen_ext.cache_store.diskcache
+   :members:
+   :undoc-members:
+   :show-inheritance:
New reference page: autogen_ext.cache_store.redis (+8)

@@ -0,0 +1,8 @@
+autogen\_ext.cache_store.redis
+==============================
+
+
+.. automodule:: autogen_ext.cache_store.redis
+   :members:
+   :undoc-members:
+   :show-inheritance:
New reference page: autogen_ext.models.cache (+8)

@@ -0,0 +1,8 @@
+autogen\_ext.models.cache
+=========================
+
+
+.. automodule:: autogen_ext.models.cache
+   :members:
+   :undoc-members:
+   :show-inheritance:
Reference page autogen_ext.models.replay (all eight lines removed and re-added with identical text; likely a whitespace or line-ending change)

@@ -1,8 +1,8 @@
-autogen\_ext.models.replay
-==========================
-
-
-.. automodule:: autogen_ext.models.replay
-   :members:
-   :undoc-members:
-   :show-inheritance:
+autogen\_ext.models.replay
+==========================
+
+
+.. automodule:: autogen_ext.models.replay
+   :members:
+   :undoc-members:
+   :show-inheritance:

‎python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb

+5 −1

@@ -6,7 +6,11 @@
   "source": [
    "# Models\n",
    "\n",
-   "In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. AgentChat can use these model clients to interact with model services. "
+   "In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. AgentChat can use these model clients to interact with model services. \n",
+   "\n",
+   "```{note}\n",
+   "See {py:class}`~autogen_ext.models.cache.ChatCompletionCache` for a caching wrapper to use with the following clients.\n",
+   "```"
   ]
  },
  {

‎python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb

+86 −2

@@ -96,7 +96,13 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n",
+   "Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
    "\n",
    "\n",
    "### Streaming Response\n",
@@ -315,6 +321,84 @@
    "```"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## Caching Wrapper\n",
+   "\n",
+   "`autogen_ext` implements {py:class}`~autogen_ext.models.cache.ChatCompletionCache` that can wrap any {py:class}`~autogen_core.models.ChatCompletionClient`. Using this wrapper avoids incurring token usage when querying the underlying client with the same prompt multiple times.\n",
+   "\n",
+   "{py:class}`~autogen_ext.models.cache.ChatCompletionCache` uses a {py:class}`~autogen_core.CacheStore` protocol. We have implemented some useful variants of {py:class}`~autogen_core.CacheStore` including {py:class}`~autogen_ext.cache_store.diskcache.DiskCacheStore` and {py:class}`~autogen_ext.cache_store.redis.RedisStore`.\n",
+   "\n",
+   "Here's an example of using `diskcache` for local caching:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# pip install -U \"autogen-ext[openai, diskcache]\""
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "True\n"
+    ]
+   }
+  ],
+  "source": [
+   "import asyncio\n",
+   "import tempfile\n",
+   "\n",
+   "from autogen_core.models import UserMessage\n",
+   "from autogen_ext.cache_store.diskcache import DiskCacheStore\n",
+   "from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache\n",
+   "from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
+   "from diskcache import Cache\n",
+   "\n",
+   "\n",
+   "async def main() -> None:\n",
+   "    with tempfile.TemporaryDirectory() as tmpdirname:\n",
+   "        # Initialize the original client\n",
+   "        openai_model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
+   "\n",
+   "        # Then initialize the CacheStore, in this case with diskcache.Cache.\n",
+   "        # You can also use redis like:\n",
+   "        # from autogen_ext.cache_store.redis import RedisStore\n",
+   "        # import redis\n",
+   "        # redis_instance = redis.Redis()\n",
+   "        # cache_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)\n",
+   "        cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))\n",
+   "        cache_client = ChatCompletionCache(openai_model_client, cache_store)\n",
+   "\n",
+   "        response = await cache_client.create([UserMessage(content=\"Hello, how are you?\", source=\"user\")])\n",
+   "        print(response)  # Should print response from OpenAI\n",
+   "        response = await cache_client.create([UserMessage(content=\"Hello, how are you?\", source=\"user\")])\n",
+   "        print(response)  # Should print cached response\n",
+   "\n",
+   "\n",
+   "asyncio.run(main())"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "Inspecting `cache_client.total_usage()` (or `openai_model_client.total_usage()`) before and after a cached response should yield identical counts.\n",
+   "\n",
+   "Note that caching is sensitive to the exact arguments provided to `cache_client.create` or `cache_client.create_stream`, so changing the `tools` or `json_output` arguments might lead to a cache miss."
+  ]
+ },
 {
  "cell_type": "markdown",
  "metadata": {},
@@ -615,7 +699,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.12.7"
+   "version": "3.12.1"
  }
 },
 "nbformat": 4,

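A small self-contained sketch (not part of the notebook above) of the usage check it describes: it pairs the caching wrapper with ReplayChatCompletionClient and InMemoryStore from this same commit, so it runs without an API key, and copies the usage structs before comparing, since total_usage() returns a mutable object.

import asyncio
import copy

from autogen_core import InMemoryStore
from autogen_core.models import UserMessage
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    # Replay client stands in for a real model client; no API key or network needed.
    client = ReplayChatCompletionClient(["dummy reply"])
    cache_client = ChatCompletionCache(client, InMemoryStore[CHAT_CACHE_VALUE_TYPE]())

    message = [UserMessage(content="Hello, how are you?", source="user")]
    response1 = await cache_client.create(message)
    usage_after_first = copy.copy(cache_client.total_usage())

    response2 = await cache_client.create(message)  # identical arguments -> cache hit
    usage_after_second = copy.copy(cache_client.total_usage())

    print(response2.cached)                         # True
    print(response1.content == response2.content)   # True
    print(usage_after_first == usage_after_second)  # True: the cached call cost no tokens


asyncio.run(main())
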
‎python/packages/autogen-core/pyproject.toml

+2

@@ -72,6 +72,8 @@ dev = [
     "autogen_ext==0.4.3",

     # Documentation tooling
+    "diskcache",
+    "redis",
     "sphinx-autobuild",
 ]

‎python/packages/autogen-core/src/autogen_core/__init__.py

+3

@@ -10,6 +10,7 @@
 from ._agent_runtime import AgentRuntime
 from ._agent_type import AgentType
 from ._base_agent import BaseAgent
+from ._cache_store import CacheStore, InMemoryStore
 from ._cancellation_token import CancellationToken
 from ._closure_agent import ClosureAgent, ClosureContext
 from ._component_config import (
@@ -85,6 +86,8 @@
     "AgentMetadata",
     "AgentRuntime",
     "BaseAgent",
+    "CacheStore",
+    "InMemoryStore",
     "CancellationToken",
     "AgentInstantiationContext",
     "TopicId",
New file: autogen_core/_cache_store.py (+46)

@@ -0,0 +1,46 @@
from typing import Dict, Generic, Optional, Protocol, TypeVar

T = TypeVar("T")


class CacheStore(Protocol, Generic[T]):
    """
    This protocol defines the basic interface for store/cache operations.

    Sub-classes should handle the lifecycle of underlying storage.
    """

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        """
        Retrieve an item from the store.

        Args:
            key: The key identifying the item in the store.
            default (optional): The default value to return if the key is not found.
                Defaults to None.

        Returns:
            The value associated with the key if found, else the default value.
        """
        ...

    def set(self, key: str, value: T) -> None:
        """
        Set an item in the store.

        Args:
            key: The key under which the item is to be stored.
            value: The value to be stored in the store.
        """
        ...


class InMemoryStore(CacheStore[T]):
    def __init__(self) -> None:
        self.store: Dict[str, T] = {}

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        return self.store.get(key, default)

    def set(self, key: str, value: T) -> None:
        self.store[key] = value
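
A brief usage sketch (not part of the commit) of the protocol above: InMemoryStore is the simplest dict-backed implementation, and because CacheStore is a typing Protocol, any class with structurally matching get/set methods also satisfies it. The UpperCaseKeyStore class below is purely hypothetical, for illustration.

from typing import Dict, Optional

from autogen_core import CacheStore, InMemoryStore

# InMemoryStore: a plain dict-backed store.
store = InMemoryStore[int]()
store.set("answer", 42)
assert store.get("answer") == 42
assert store.get("missing", 0) == 0  # the default is returned on a miss


# Hypothetical custom store: no inheritance is required, matching get/set
# methods are enough to satisfy the CacheStore protocol for type checkers.
class UpperCaseKeyStore:
    def __init__(self) -> None:
        self._data: Dict[str, str] = {}

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        return self._data.get(key.upper(), default)

    def set(self, key: str, value: str) -> None:
        self._data[key.upper()] = value


custom: CacheStore[str] = UpperCaseKeyStore()
custom.set("greeting", "hello")
assert custom.get("GREETING") == "hello"
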
New test file (autogen-core) covering CacheStore and InMemoryStore (+48)

@@ -0,0 +1,48 @@
from unittest.mock import Mock

from autogen_core import CacheStore, InMemoryStore


def test_set_and_get_object_key_value() -> None:
    mock_store = Mock(spec=CacheStore)
    test_key = "test_key"
    test_value = object()
    mock_store.set(test_key, test_value)
    mock_store.get.return_value = test_value
    mock_store.set.assert_called_with(test_key, test_value)
    assert mock_store.get(test_key) == test_value


def test_get_non_existent_key() -> None:
    mock_store = Mock(spec=CacheStore)
    key = "non_existent_key"
    mock_store.get.return_value = None
    assert mock_store.get(key) is None


def test_set_overwrite_existing_key() -> None:
    mock_store = Mock(spec=CacheStore)
    key = "test_key"
    initial_value = "initial_value"
    new_value = "new_value"
    mock_store.set(key, initial_value)
    mock_store.set(key, new_value)
    mock_store.get.return_value = new_value
    mock_store.set.assert_called_with(key, new_value)
    assert mock_store.get(key) == new_value


def test_inmemory_store() -> None:
    store = InMemoryStore[int]()
    test_key = "test_key"
    test_value = 42
    store.set(test_key, test_value)
    assert store.get(test_key) == test_value

    new_value = 2
    store.set(test_key, new_value)
    assert store.get(test_key) == new_value

    key = "non_existent_key"
    default_value = 99
    assert store.get(key, default_value) == default_value

‎python/packages/autogen-ext/pyproject.toml

+6

@@ -46,6 +46,12 @@ video-surfer = [
     "ffmpeg-python",
     "openai-whisper",
 ]
+diskcache = [
+    "diskcache>=5.6.3"
+]
+redis = [
+    "redis>=5.2.1"
+]

 grpc = [
     "grpcio~=1.62.0", # TODO: update this once we have a stable version.

‎python/packages/autogen-ext/src/autogen_ext/cache_store/__init__.py

Whitespace-only changes.
New file: autogen_ext/cache_store/diskcache.py (+26)

@@ -0,0 +1,26 @@
from typing import Any, Optional, TypeVar, cast

import diskcache
from autogen_core import CacheStore

T = TypeVar("T")


class DiskCacheStore(CacheStore[T]):
    """
    A typed CacheStore implementation that uses diskcache as the underlying storage.
    See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.

    Args:
        cache_instance: An instance of diskcache.Cache.
            The user is responsible for managing the DiskCache instance's lifetime.
    """

    def __init__(self, cache_instance: diskcache.Cache):  # type: ignore[no-any-unimported]
        self.cache = cache_instance

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        return cast(Optional[T], self.cache.get(key, default))  # type: ignore[reportUnknownMemberType]

    def set(self, key: str, value: T) -> None:
        self.cache.set(key, cast(Any, value))  # type: ignore[reportUnknownMemberType]
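
A minimal standalone sketch (not part of the commit) of DiskCacheStore on its own, assuming the `diskcache` extra is installed; since the caller owns the Cache instance's lifetime, it is closed explicitly here.

import tempfile

from autogen_ext.cache_store.diskcache import DiskCacheStore
from diskcache import Cache

with tempfile.TemporaryDirectory() as tmpdir:
    cache = Cache(tmpdir)               # the caller creates and owns this Cache
    store = DiskCacheStore[str](cache)
    store.set("greeting", "hello")
    print(store.get("greeting"))             # -> "hello"
    print(store.get("missing", "default"))   # -> "default" on a miss
    cache.close()                       # the caller is responsible for cleanup
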
New file: autogen_ext/cache_store/redis.py (+29)

@@ -0,0 +1,29 @@
from typing import Any, Optional, TypeVar, cast

import redis
from autogen_core import CacheStore

T = TypeVar("T")


class RedisStore(CacheStore[T]):
    """
    A typed CacheStore implementation that uses redis as the underlying storage.
    See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.

    Args:
        cache_instance: An instance of `redis.Redis`.
            The user is responsible for managing the Redis instance's lifetime.
    """

    def __init__(self, redis_instance: redis.Redis):
        self.cache = redis_instance

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        value = cast(Optional[T], self.cache.get(key))
        if value is None:
            return default
        return value

    def set(self, key: str, value: T) -> None:
        self.cache.set(key, cast(Any, value))
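
A short standalone sketch (not part of the commit) of RedisStore. It assumes a Redis server is reachable at localhost:6379, and stores bytes because redis-py returns raw bytes for values written this way.

import redis

from autogen_ext.cache_store.redis import RedisStore

# Assumes a running Redis server at localhost:6379.
redis_instance = redis.Redis(host="localhost", port=6379)
store = RedisStore[bytes](redis_instance)

store.set("greeting", b"hello")
print(store.get("greeting"))             # -> b"hello" (redis-py returns bytes)
print(store.get("missing", b"default"))  # -> b"default" on a miss

redis_instance.close()                   # the caller owns the connection lifetime
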
New file: autogen_ext/models/cache/__init__.py (+6)

@@ -0,0 +1,6 @@
from ._chat_completion_cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache

__all__ = [
    "CHAT_CACHE_VALUE_TYPE",
    "ChatCompletionCache",
]
New file: autogen_ext/models/cache/_chat_completion_cache.py (+210)

@@ -0,0 +1,210 @@
import hashlib
import json
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import CacheStore, CancellationToken
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelInfo,
    RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema

CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]


class ChatCompletionCache(ChatCompletionClient):
    """
    A wrapper around a :class:`~autogen_core.models.ChatCompletionClient` that caches
    creation results from an underlying client.
    Cache hits do not contribute to token usage of the original client.

    Typical Usage:

    Let's use caching on disk with the `openai` client as an example.
    First install `autogen-ext` with the required packages:

    .. code-block:: bash

        pip install -U "autogen-ext[openai, diskcache]"

    And use it as:

    .. code-block:: python

        import asyncio
        import tempfile

        from autogen_core.models import UserMessage
        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_ext.models.cache import ChatCompletionCache, CHAT_CACHE_VALUE_TYPE
        from autogen_ext.cache_store.diskcache import DiskCacheStore
        from diskcache import Cache


        async def main():
            with tempfile.TemporaryDirectory() as tmpdirname:
                # Initialize the original client
                openai_model_client = OpenAIChatCompletionClient(model="gpt-4o")

                # Then initialize the CacheStore, in this case with diskcache.Cache.
                # You can also use redis like:
                # from autogen_ext.cache_store.redis import RedisStore
                # import redis
                # redis_instance = redis.Redis()
                # cache_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)
                cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))
                cache_client = ChatCompletionCache(openai_model_client, cache_store)

                response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
                print(response)  # Should print response from OpenAI
                response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
                print(response)  # Should print cached response


        asyncio.run(main())

    You can now use `cache_client` as you would the original client, but with caching enabled.

    Args:
        client (ChatCompletionClient): The original ChatCompletionClient to wrap.
        store (CacheStore): A store object that implements get and set methods.
            The user is responsible for managing the store's lifecycle & clearing it (if needed).
    """

    def __init__(self, client: ChatCompletionClient, store: CacheStore[CHAT_CACHE_VALUE_TYPE]):
        self.client = client
        self.store = store

    def _check_cache(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool],
        extra_create_args: Mapping[str, Any],
    ) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
        """
        Helper function to check the cache for a result.
        Returns a tuple of (cached_result, cache_key).
        """

        data = {
            "messages": [message.model_dump() for message in messages],
            "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
            "json_output": json_output,
            "extra_create_args": extra_create_args,
        }
        serialized_data = json.dumps(data, sort_keys=True)
        cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()

        cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
        if cached_result is not None:
            return cached_result, cache_key

        return None, cache_key

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """
        Cached version of ChatCompletionClient.create.
        If the result of a call to create has been cached, it will be returned immediately
        without invoking the underlying client.

        NOTE: cancellation_token is ignored for cached results.
        """
        cached_result, cache_key = self._check_cache(messages, tools, json_output, extra_create_args)
        if cached_result:
            assert isinstance(cached_result, CreateResult)
            cached_result.cached = True
            return cached_result

        result = await self.client.create(
            messages,
            tools=tools,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )
        self.store.set(cache_key, result)
        return result

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """
        Cached version of ChatCompletionClient.create_stream.
        If the result of a call to create_stream has been cached, it will be returned
        without streaming from the underlying client.

        NOTE: cancellation_token is ignored for cached results.
        """

        async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
            cached_result, cache_key = self._check_cache(
                messages,
                tools,
                json_output,
                extra_create_args,
            )
            if cached_result:
                assert isinstance(cached_result, list)
                for result in cached_result:
                    if isinstance(result, CreateResult):
                        result.cached = True
                    yield result
                return

            result_stream = self.client.create_stream(
                messages,
                tools=tools,
                json_output=json_output,
                extra_create_args=extra_create_args,
                cancellation_token=cancellation_token,
            )

            output_results: List[Union[str, CreateResult]] = []
            self.store.set(cache_key, output_results)

            async for result in result_stream:
                output_results.append(result)
                yield result

        return _generator()

    def actual_usage(self) -> RequestUsage:
        return self.client.actual_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.count_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        return self.client.model_info

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.remaining_tokens(messages, tools=tools)

    def total_usage(self) -> RequestUsage:
        return self.client.total_usage()
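
To make the cache-miss behaviour concrete, here is a small sketch (not part of the commit) that recomputes a cache key the same way `_check_cache` above does; it shows why changing `json_output` (or `tools` / `extra_create_args`) yields a different key and therefore a miss. The helper name `compute_cache_key` is hypothetical and slightly simplified.

import hashlib
import json

from autogen_core.models import UserMessage


# Hypothetical helper mirroring the keying logic in _check_cache above.
def compute_cache_key(messages, tools=(), json_output=None, extra_create_args=None):
    data = {
        "messages": [message.model_dump() for message in messages],
        "tools": list(tools),
        "json_output": json_output,
        "extra_create_args": extra_create_args or {},
    }
    serialized = json.dumps(data, sort_keys=True)
    return hashlib.sha256(serialized.encode()).hexdigest()


messages = [UserMessage(content="Hello, how are you?", source="user")]
key_default = compute_cache_key(messages)
key_json = compute_cache_key(messages, json_output=True)
print(key_default == compute_cache_key(messages))  # True: identical arguments, same key
print(key_default == key_json)                     # False: json_output changed, so a cache miss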

‎python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py

+10 −4

@@ -40,8 +40,8 @@ class ReplayChatCompletionClient(ChatCompletionClient):

     .. code-block:: python

-        from autogen_ext.models.replay import ReplayChatCompletionClient
         from autogen_core.models import UserMessage
+        from autogen_ext.models.replay import ReplayChatCompletionClient


         async def example():
@@ -60,8 +60,8 @@ async def example():
     .. code-block:: python

         import asyncio
-        from autogen_ext.models.replay import ReplayChatCompletionClient
         from autogen_core.models import UserMessage
+        from autogen_ext.models.replay import ReplayChatCompletionClient


         async def example():
@@ -86,8 +86,8 @@ async def example():
     .. code-block:: python

         import asyncio
-        from autogen_ext.models.replay import ReplayChatCompletionClient
         from autogen_core.models import UserMessage
+        from autogen_ext.models.replay import ReplayChatCompletionClient


         async def example():
@@ -129,6 +129,7 @@ def __init__(
         self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
         self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
         self._current_index = 0
+        self._cached_bool_value = True

     async def create(
         self,
@@ -148,7 +149,9 @@ async def create(
         if isinstance(response, str):
             _, output_token_count = self._tokenize(response)
             self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)
-            response = CreateResult(finish_reason="stop", content=response, usage=self._cur_usage, cached=True)
+            response = CreateResult(
+                finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
+            )
         else:
             self._cur_usage = RequestUsage(
                 prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens
@@ -207,6 +210,9 @@ def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[To
             0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
         )

+    def set_cached_bool_value(self, value: bool) -> None:
+        self._cached_bool_value = value
+
     def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> tuple[list[str], int]:
         total_tokens = 0
         all_tokens: List[str] = []
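
A short sketch (not part of the commit) of what the new `set_cached_bool_value` hook enables: tests can force the replay client to report its responses as non-cached, so the `cached` flag set by a wrapping ChatCompletionCache can be observed unambiguously (this mirrors how `get_test_data` uses it in the cache tests below).

import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    replay_client = ReplayChatCompletionClient(["dummy response"])
    # Report replayed responses as non-cached so a wrapping cache's own
    # `cached` flag can be distinguished from the replay client's default.
    replay_client.set_cached_bool_value(False)
    result = await replay_client.create([UserMessage(content="hi", source="user")])
    print(result.cached)  # False, because of set_cached_bool_value(False)


asyncio.run(main())
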
New test file (autogen-ext) covering DiskCacheStore (+48)

@@ -0,0 +1,48 @@
import tempfile

import pytest

diskcache = pytest.importorskip("diskcache")


def test_diskcache_store_basic() -> None:
    from autogen_ext.cache_store.diskcache import DiskCacheStore
    from diskcache import Cache

    with tempfile.TemporaryDirectory() as temp_dir:
        cache = Cache(temp_dir)
        store = DiskCacheStore[int](cache)
        test_key = "test_key"
        test_value = 42
        store.set(test_key, test_value)
        assert store.get(test_key) == test_value

        new_value = 2
        store.set(test_key, new_value)
        assert store.get(test_key) == new_value

        key = "non_existent_key"
        default_value = 99
        assert store.get(key, default_value) == default_value


def test_diskcache_with_different_instances() -> None:
    from autogen_ext.cache_store.diskcache import DiskCacheStore
    from diskcache import Cache

    with tempfile.TemporaryDirectory() as temp_dir_1, tempfile.TemporaryDirectory() as temp_dir_2:
        cache_1 = Cache(temp_dir_1)
        cache_2 = Cache(temp_dir_2)

        store_1 = DiskCacheStore[int](cache_1)
        store_2 = DiskCacheStore[int](cache_2)

        test_key = "test_key"
        test_value_1 = 5
        test_value_2 = 6

        store_1.set(test_key, test_value_1)
        assert store_1.get(test_key) == test_value_1

        store_2.set(test_key, test_value_2)
        assert store_2.get(test_key) == test_value_2
New test file (autogen-ext) covering RedisStore (+53)

@@ -0,0 +1,53 @@
from unittest.mock import MagicMock

import pytest

redis = pytest.importorskip("redis")


def test_redis_store_basic() -> None:
    from autogen_ext.cache_store.redis import RedisStore

    redis_instance = MagicMock()
    store = RedisStore[int](redis_instance)
    test_key = "test_key"
    test_value = 42
    store.set(test_key, test_value)
    redis_instance.set.assert_called_with(test_key, test_value)
    redis_instance.get.return_value = test_value
    assert store.get(test_key) == test_value

    new_value = 2
    store.set(test_key, new_value)
    redis_instance.set.assert_called_with(test_key, new_value)
    redis_instance.get.return_value = new_value
    assert store.get(test_key) == new_value

    key = "non_existent_key"
    default_value = 99
    redis_instance.get.return_value = None
    assert store.get(key, default_value) == default_value


def test_redis_with_different_instances() -> None:
    from autogen_ext.cache_store.redis import RedisStore

    redis_instance_1 = MagicMock()
    redis_instance_2 = MagicMock()

    store_1 = RedisStore[int](redis_instance_1)
    store_2 = RedisStore[int](redis_instance_2)

    test_key = "test_key"
    test_value_1 = 5
    test_value_2 = 6

    store_1.set(test_key, test_value_1)
    redis_instance_1.set.assert_called_with(test_key, test_value_1)
    redis_instance_1.get.return_value = test_value_1
    assert store_1.get(test_key) == test_value_1

    store_2.set(test_key, test_value_2)
    redis_instance_2.set.assert_called_with(test_key, test_value_2)
    redis_instance_2.get.return_value = test_value_2
    assert store_2.get(test_key) == test_value_2
New test file (autogen-ext) covering ChatCompletionCache (+133)

@@ -0,0 +1,133 @@
import copy
from typing import List, Tuple, Union

import pytest
from autogen_core import InMemoryStore
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    SystemMessage,
    UserMessage,
)
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient


def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
    num_messages = 3
    responses = [f"This is dummy message number {i}" for i in range(num_messages)]
    prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)]
    system_prompt = SystemMessage(content="This is a system prompt")
    replay_client = ReplayChatCompletionClient(responses)
    replay_client.set_cached_bool_value(False)
    store = InMemoryStore[CHAT_CACHE_VALUE_TYPE]()
    cached_client = ChatCompletionCache(replay_client, store)

    return responses, prompts, system_prompt, replay_client, cached_client


@pytest.mark.asyncio
async def test_cache_basic_with_args() -> None:
    responses, prompts, system_prompt, _, cached_client = get_test_data()

    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]

    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
    assert not response1.cached
    assert response1.content == responses[1]

    # Cached output.
    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]

    # Cache miss if args change.
    response2 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=True
    )
    assert isinstance(response2, CreateResult)
    assert not response2.cached
    assert response2.content == responses[2]


@pytest.mark.asyncio
async def test_cache_model_and_count_api() -> None:
    _, prompts, system_prompt, replay_client, cached_client = get_test_data()

    assert replay_client.model_info == cached_client.model_info
    assert replay_client.capabilities == cached_client.capabilities

    messages: List[LLMMessage] = [system_prompt, UserMessage(content=prompts[0], source="user")]
    assert replay_client.count_tokens(messages) == cached_client.count_tokens(messages)
    assert replay_client.remaining_tokens(messages) == cached_client.remaining_tokens(messages)


@pytest.mark.asyncio
async def test_cache_token_usage() -> None:
    responses, prompts, system_prompt, replay_client, cached_client = get_test_data()

    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]
    actual_usage0 = copy.copy(cached_client.actual_usage())
    total_usage0 = copy.copy(cached_client.total_usage())

    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
    assert not response1.cached
    assert response1.content == responses[1]
    actual_usage1 = copy.copy(cached_client.actual_usage())
    total_usage1 = copy.copy(cached_client.total_usage())
    assert total_usage1.prompt_tokens > total_usage0.prompt_tokens
    assert total_usage1.completion_tokens > total_usage0.completion_tokens
    assert actual_usage1.prompt_tokens == actual_usage0.prompt_tokens
    assert actual_usage1.completion_tokens == actual_usage0.completion_tokens

    # Cached output.
    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]
    total_usage2 = copy.copy(cached_client.total_usage())
    assert total_usage2.prompt_tokens == total_usage1.prompt_tokens
    assert total_usage2.completion_tokens == total_usage1.completion_tokens

    assert cached_client.actual_usage() == replay_client.actual_usage()
    assert cached_client.total_usage() == replay_client.total_usage()


@pytest.mark.asyncio
async def test_cache_create_stream() -> None:
    _, prompts, system_prompt, _, cached_client = get_test_data()

    original_streamed_results: List[Union[str, CreateResult]] = []
    async for completion in cached_client.create_stream(
        [system_prompt, UserMessage(content=prompts[0], source="user")]
    ):
        original_streamed_results.append(completion)
    total_usage0 = copy.copy(cached_client.total_usage())

    cached_completion_results: List[Union[str, CreateResult]] = []
    async for completion in cached_client.create_stream(
        [system_prompt, UserMessage(content=prompts[0], source="user")]
    ):
        cached_completion_results.append(completion)
    total_usage1 = copy.copy(cached_client.total_usage())

    assert total_usage1.prompt_tokens == total_usage0.prompt_tokens
    assert total_usage1.completion_tokens == total_usage0.completion_tokens

    for original, cached in zip(original_streamed_results, cached_completion_results, strict=False):
        if isinstance(original, str):
            assert original == cached
        elif isinstance(original, CreateResult) and isinstance(cached, CreateResult):
            assert original.content == cached.content
            assert cached.cached
            assert not original.cached
        else:
            raise ValueError(f"Unexpected types : {type(original)} and {type(cached)}")

‎python/uv.lock

+56 −3
Some generated files are not rendered by default.

0 commit comments
