Skip to content

Commit

Permalink
Make Memory and Team an ABC (microsoft#5149)
Browse files Browse the repository at this point in the history
* make memory and team an ABC

* update memory test

* update tests
  • Loading branch information
victordibia authored Jan 22, 2025
1 parent 74f411e commit 5e9b24c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from typing import Any, Mapping, Protocol

from typing import Any, Mapping
from abc import ABC, abstractmethod
from ._task import TaskRunner


class Team(TaskRunner, Protocol):
class Team(ABC, TaskRunner):
@abstractmethod
async def reset(self) -> None:
"""Reset the team and all its participants to its initial state."""
...

@abstractmethod
async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the team."""
...

@abstractmethod
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the team."""
...
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum
from typing import Any, Dict, List, Protocol, Union, runtime_checkable
from typing import Any, Dict, List, Union
from abc import ABC, abstractmethod

from pydantic import BaseModel, ConfigDict

Expand Down Expand Up @@ -48,8 +49,7 @@ class UpdateContextResult(BaseModel):
memories: MemoryQueryResult


@runtime_checkable
class Memory(Protocol):
class Memory(ABC):
"""Protocol defining the interface for memory implementations.
A memory is the storage for data that can be used to enrich or modify the model context.
Expand All @@ -64,6 +64,7 @@ class Memory(Protocol):
See :class:`~autogen_core.memory.ListMemory` for an example implementation.
"""

@abstractmethod
async def update_context(
self,
model_context: ChatCompletionContext,
Expand All @@ -79,6 +80,7 @@ async def update_context(
"""
...

@abstractmethod
async def query(
self,
query: str | MemoryContent,
Expand All @@ -98,6 +100,7 @@ async def query(
"""
...

@abstractmethod
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
"""
Add a new content to memory.
Expand All @@ -108,10 +111,12 @@ async def add(self, content: MemoryContent, cancellation_token: CancellationToke
"""
...

@abstractmethod
async def clear(self) -> None:
"""Clear all entries from memory."""
...

@abstractmethod
async def close(self) -> None:
"""Clean up any resources used by the memory implementation."""
...
14 changes: 9 additions & 5 deletions python/packages/autogen-core/tests/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
import pytest
from autogen_core import CancellationToken
from autogen_core.memory import (
Expand All @@ -21,19 +22,22 @@ def test_memory_protocol_attributes() -> None:
assert hasattr(Memory, "close")


def test_memory_protocol_runtime_checkable() -> None:
"""Test that Memory protocol is properly runtime-checkable."""
def test_memory_abc_implementation() -> None:
"""Test that Memory ABC is properly implemented."""

class ValidMemory:
class ValidMemory(Memory):
@property
def name(self) -> str:
return "test"

async def update_context(self, context: ChatCompletionContext) -> UpdateContextResult:
async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult:
return UpdateContextResult(memories=MemoryQueryResult(results=[]))

async def query(
self, query: MemoryContent, cancellation_token: CancellationToken | None = None
self,
query: str | MemoryContent,
cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> MemoryQueryResult:
return MemoryQueryResult(results=[])

Expand Down

0 comments on commit 5e9b24c

Please sign in to comment.