# Extend AgentBench

[🌏 Chinese Version](Extension_cn.md)

## Task Introduction

The `Task` interface is defined as follows:
```
class Task:
    def __init__(self, name: str, concurrency: int = 1, *args, **kwargs):
        self.name = name
        self.concurrency = concurrency

    def get_indices(self) -> List[SampleIndex]:
        raise NotImplementedError()

    async def start_sample(
        self, index: SampleIndex, session: Session
    ) -> TaskSampleExecutionResult:
        raise NotImplementedError()

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        raise NotImplementedError()

    def release(self):
        pass
```

To implement your own task, inherit from `Task` and implement the corresponding interfaces, described as follows:
- `name`: The task name, usually specified in the config.
- `concurrency`: The maximum concurrency supported within a worker.
- `get_indices`: Returns the indices of all samples.
- `start_sample`: The logic for a single sample, where `index` is the index of the sample to be tested and `session` is a proxy for the Agent.
- `calculate_overall`: Calculates the score after all samples have been tested; the return format is arbitrary and is eventually saved to `overall.json`.
- `release`: Cleanup logic executed after the task_worker process ends. Note that this runs after the entire worker process ends, not after each individual sample.

The structures used above are defined as follows (reordered here so that the snippet runs top to bottom, with the imports it needs):

```
from enum import Enum
from typing import Any, Dict, List, Literal, Union

from pydantic import BaseModel

SampleIndex = Union[int, str]
JSONSerializable = Union[None, bool, int, float, str, List[Any], Dict[str, Any]]

class SampleStatus(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    AGENT_CONTEXT_LIMIT = "agent context limit"
    AGENT_VALIDATION_FAILED = "agent validation failed"
    AGENT_INVALID_ACTION = "agent invalid action"
    TASK_LIMIT_REACHED = "task limit reached"
    UNKNOWN = "unknown"
    TASK_ERROR = "task error"

class ChatHistoryItem(BaseModel):
    role: Literal["user", "agent"]
    content: str

class TaskSampleExecutionResult(BaseModel):
    status: SampleStatus = SampleStatus.COMPLETED
    result: JSONSerializable = None

class TaskOutput(BaseModel):
    index: Union[None, SampleIndex] = None
    status: SampleStatus = SampleStatus.RUNNING  # copied directly from TaskSampleExecutionResult
    result: JSONSerializable = None  # copied directly from TaskSampleExecutionResult
    history: Union[None, List[ChatHistoryItem]] = None
```

Note that when returning a `TaskSampleExecutionResult` from `start_sample`, you should carefully determine the completion status of the sample. If it completed normally, it should be marked `COMPLETED`. The framework automatically aggregates statistics over these completion statuses.

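For instance, a task with a fixed round budget might end a sample like this (a minimal sketch: `self.max_round` and `is_finished` are hypothetical names used for illustration, not part of the framework):

```
# Inside an async start_sample(self, index, session) implementation.
for _ in range(self.max_round):  # hypothetical per-sample round budget
    response = await session.action()
    if is_finished(response.content):  # hypothetical task-specific check
        # The sample ended the way it was supposed to.
        return TaskSampleExecutionResult(
            status=SampleStatus.COMPLETED,
            result={"success": True},
        )
# The round budget ran out before the agent finished.
return TaskSampleExecutionResult(
    status=SampleStatus.TASK_LIMIT_REACHED,
    result={"success": False},
)
```
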
The `Session` implements the following interfaces:
- `def inject(self, item: Union[ChatHistoryItem, List[ChatHistoryItem]])`: Inserts one or more history items.
- `async def action(self, *injection) -> AgentOutput`: Waits for the Agent's response; for convenience, it also accepts one or more history items to insert at the same time, as sketched below.

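A minimal usage sketch (inside an async `start_sample`; the message contents are placeholders):

```
# Push the task instruction into the history without requesting a reply.
session.inject(
    ChatHistoryItem(role="user", content="You are playing a guessing game...")
)

# Ask the agent to respond; extra history items can be passed in directly.
response = await session.action({"role": "user", "content": "Make your first guess."})
print(response.content)
```
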
The definition of `AgentOutput` is as follows (again reordered so the snippet runs top to bottom):
```
class AgentOutputStatus(str, Enum):
    NORMAL = "normal"
    CANCELLED = "cancelled"
    AGENT_CONTEXT_LIMIT = "agent context limit"

class AgentOutput(BaseModel):
    status: AgentOutputStatus = AgentOutputStatus.NORMAL
    content: Union[str, None] = None
```

After obtaining an `AgentOutput`, handle it carefully: check whether the `AgentOutputStatus` is normal and, if it is not, react accordingly. In particular, a status of `CANCELLED` means the client needs to cancel the test of this sample for some reason; in that case, end the sample quickly, in whatever way is convenient, so that subsequent tests are not affected.

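A sketch of this check at the start of each round (the `SampleStatus` chosen for each branch is one reasonable mapping, not mandated by the framework):

```
response = await session.action()
if response.status == AgentOutputStatus.CANCELLED:
    # The client gave up on this sample; end it quickly so that
    # subsequent samples are not affected.
    return TaskSampleExecutionResult(status=SampleStatus.UNKNOWN)
if response.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT:
    # The agent ran out of context; record that as the outcome.
    return TaskSampleExecutionResult(status=SampleStatus.AGENT_CONTEXT_LIMIT)
# Status is NORMAL, so response.content is safe to use.
```
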
## Implementation Example

A simple implementation is as follows:

```
import asyncio
from typing import Any, Dict, List

class VirtualTask(Task):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(name="virtual-task", *args, **kwargs)

    def get_indices(self) -> List[Any]:
        # Ten dummy samples, indexed 0 through 9.
        return list(range(10))

    async def start_sample(self, index, session: Session):
        print("task start sample")
        for loop_times in range(3):
            await asyncio.sleep(1)
            # One agent interaction per loop round.
            res = await session.action(
                {"role": "user", "content": "Loop: %d" % loop_times}
            )
            print("TASK", res.content)
        return TaskSampleExecutionResult(
            status=SampleStatus.COMPLETED,
            result={"result": "ok"},
        )

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        # A fixed dummy score; a real task would aggregate `results`.
        return {"score": 0.4}
```

## Migrating from AgentBench v0.1

### Step 1: Migrate from `get_data` to `get_indices`

The data that the original `get_data` returned can be bound to `self` in the `__init__` method, and `start_sample` can then fetch the sample from `self.data` by `index`. On this basis, implement `get_indices`: if the full set of test samples is a list, you can simply return `list(range(len(self.data)))`, as in the sketch below.

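A sketch of this pattern, assuming the old `get_data` loaded a list of samples from a JSON file (`data_file` is a placeholder parameter):

```
import json

class MyTask(Task):
    def __init__(self, data_file: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # What get_data used to compute is now bound to self once, up front.
        with open(data_file) as f:
            self.data = json.load(f)

    def get_indices(self) -> List[SampleIndex]:
        # The full sample set is a list, so the indices are just 0..len-1.
        return list(range(len(self.data)))
```
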
### Step 2: Change `predict_single` to `start_sample`

First, change the original `def` to `async def`, and change each `session.action` call to `await session.action`. Finally, the return value needs an additional `status` field compared to the original; this lets the framework automatically count the reasons samples fail, which helps with further experimental analysis, as sketched below.

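Schematically, assuming the v0.1 method looked roughly like the commented version (the signature is approximate and the `"prompt"` field name is a placeholder):

```
# Before (v0.1, signature approximate):
# def predict_single(self, session, data_item):
#     response = session.action(...)
#     ...

# After:
async def start_sample(self, index, session: Session) -> TaskSampleExecutionResult:
    data_item = self.data[index]  # fetched by index now (see Step 1)
    response = await session.action(
        {"role": "user", "content": data_item["prompt"]}
    )
    return TaskSampleExecutionResult(
        status=SampleStatus.COMPLETED,  # the new, required status field
        result={"answer": response.content},
    )
```
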
### Step 3: Change `metrics` to `calculate_overall`

This change was made mainly to make counting samples easier. If you don't want to rewrite the original `metrics`, you can instead add a `calculate_overall` function that calls `self.metrics` internally, as sketched below.

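Such a wrapper can be as thin as this (a sketch; it assumes your old `metrics` expects a list of raw per-sample results):

```
def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
    # Only feed metrics the samples that actually completed.
    completed = [r.result for r in results if r.status == SampleStatus.COMPLETED]
    return self.metrics(completed)
```
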
### Additional Reminder

If you originally overrode the `predict_all` method, it cannot be used in the new framework; its logic should instead be redistributed across `get_indices` and `start_sample`, following the steps above.