Commit bc293e7

add extension docs
committed May 30, 2024
1 parent 2f3c343

File tree: 3 files changed, +272 -0 lines changed

README.md (+4)
```diff
@@ -168,6 +168,10 @@ launching:
 | os | ~5s | < 500M |
 | kd | ~5s | < 500M |
 
+## Extending AgentBench
+
+If you wish to add new tasks to AgentBench, you may refer to the [Extension Guide](docs/Extension_en.md).
+
 ## References
 
 Avalon task is merged from [AvalonBench](https://github.com/jonathanmli/Avalon-LLM/), which implements a multi-agent framework.
```

docs/Extension_cn.md (+135, new file)
# Extending AgentBench

[🌏English](Extension_en.md)

## Task Introduction

The Task interface is defined as follows:

```python
class Task:
    def __init__(self, name: str, concurrency: int = 1, *args, **kwargs):
        self.name = name
        self.concurrency = concurrency

    def get_indices(self) -> List[SampleIndex]:
        raise NotImplementedError()

    async def start_sample(
        self, index: SampleIndex, session: Session
    ) -> TaskSampleExecutionResult:
        raise NotImplementedError()

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        raise NotImplementedError()

    def release(self):
        pass
```

To implement your own Task, simply inherit from Task and implement the corresponding interfaces:

- `name`: the task name, usually specified in the config
- `concurrency`: the maximum concurrency supported within a single worker
- `get_indices`: returns the indices of all test samples
- `start_sample`: the logic for a single sample, where `index` is the index of the sample under test and `session` is a proxy for the Agent
- `calculate_overall`: computes the score after all samples have been tested; the return format is arbitrary and is eventually saved to `overall.json`
- `release`: cleanup to run after the task_worker process ends. Note that this is after the entire worker process ends, not after a particular sample ends.

The relevant structures are defined as follows:

```python
SampleIndex = Union[int, str]
JSONSerializable = Union[None, bool, int, float, str, List[Any], Dict[str, Any]]

class TaskSampleExecutionResult(BaseModel):
    status: SampleStatus = SampleStatus.COMPLETED
    result: JSONSerializable = None

class TaskOutput(BaseModel):
    index: Union[None, SampleIndex] = None
    status: SampleStatus = SampleStatus.RUNNING  # directly from TaskSampleExecutionResult
    result: JSONSerializable = None  # directly from TaskSampleExecutionResult
    history: Union[None, List[ChatHistoryItem]] = None

class SampleStatus(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    AGENT_CONTEXT_LIMIT = "agent context limit"
    AGENT_VALIDATION_FAILED = "agent validation failed"
    AGENT_INVALID_ACTION = "agent invalid action"
    TASK_LIMIT_REACHED = "task limit reached"
    UNKNOWN = "unknown"
    TASK_ERROR = "task error"

class ChatHistoryItem(BaseModel):
    role: Literal["user", "agent"]
    content: str
```

Note that when `start_sample` returns a `TaskSampleExecutionResult`, it should carefully reflect the sample's completion status: mark it `COMPLETED` only if it finished normally. The framework automatically aggregates these completion statuses.

`Session` implements the following interfaces:

- `def inject(self, item: Union[ChatHistoryItem, List[ChatHistoryItem]])`: inserts one or more history items.
- `async def action(self, *injection) -> AgentOutput`: waits for the Agent's response; for convenience, it also supports inserting one or more history items at the same time.

`AgentOutput` is defined as follows:

```python
class AgentOutput(BaseModel):
    status: AgentOutputStatus = AgentOutputStatus.NORMAL
    content: Union[str, None] = None

class AgentOutputStatus(str, Enum):
    NORMAL = "normal"
    CANCELLED = "cancelled"
    AGENT_CONTEXT_LIMIT = "agent context limit"
```

Handle the `AgentOutput` you receive with care: check whether `AgentOutputStatus` is normal and, if not, handle it accordingly. If the status is `CANCELLED`, the client needs to cancel this sample's test for some reason; in that case, end the sample quickly in any way, as long as subsequent tests are not affected.

## Implementation Example

A simple implementation looks like this:

```python
class VirtualTask(Task):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(name="virtual-task", *args, **kwargs)

    def get_indices(self) -> List[Any]:
        return list(range(10))

    async def start_sample(self, index, session: Session):
        print("task start sample")
        for loop_times in range(3):
            await asyncio.sleep(1)
            res = await session.action(
                {"role": "user", "content": "Loop: %d" % loop_times}
            )
            print("TASK", res.content)
        return TaskSampleExecutionResult(
            status=SampleStatus.COMPLETED,
            result={"result": "ok"},
        )

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        return {"score": 0.4}
```

## Migrating from AgentBench v0.1

### Step 1: Migrate from `get_data` to `get_indices`

The data originally loaded in `get_data` can be bound to `self` in `__init__`, and `start_sample` can then fetch the corresponding entry from `self.data` by `index`. On top of this, implement `get_indices`: if the full set of test samples is a list, you can simply return `list(range(len(self.data)))`.

### Step 2: Change `predict_single` to `start_sample`

First, change the original `def` to `async def`, and change `session.action` to `await session.action`. Finally, the return value needs an additional `status` compared to before; this helps the framework automatically aggregate the causes of sample errors, which benefits further experimental analysis.

### Step 3: Change `metrics` to `calculate_overall`

This change was originally made to make aggregating samples easier. If you would rather not change your original `metrics`, you can create a new `calculate_overall` function and call `self.metrics` inside it.

### Additional Reminder

If you previously overrode the `predict_all` method, it cannot be used in the new framework.

docs/Extension_en.md (+133, new file)
# Extend AgentBench

[🌏中文版](Extension_cn.md)

## Task Introduction

The Task interface is defined as follows:

```python
class Task:
    def __init__(self, name: str, concurrency: int = 1, *args, **kwargs):
        self.name = name
        self.concurrency = concurrency

    def get_indices(self) -> List[SampleIndex]:
        raise NotImplementedError()

    async def start_sample(
        self, index: SampleIndex, session: Session
    ) -> TaskSampleExecutionResult:
        raise NotImplementedError()

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        raise NotImplementedError()

    def release(self):
        pass
```

To implement your own Task, you just need to inherit from Task and implement the corresponding interfaces:

- `name`: task name, usually specified in the config
- `concurrency`: the maximum concurrency supported within a worker
- `get_indices`: returns the indices of all samples
- `start_sample`: logic within a single sample, where `index` is the index of the sample to be tested and `session` is a proxy of the Agent
- `calculate_overall`: calculates the score after all samples have been tested; the return format is arbitrary and will eventually be saved to `overall.json`
- `release`: cleanup that needs to run after the task_worker process ends. Note that this is after the entire worker process ends, not after a particular sample ends.

The structures involved are defined as follows:

```python
SampleIndex = Union[int, str]
JSONSerializable = Union[None, bool, int, float, str, List[Any], Dict[str, Any]]

class TaskSampleExecutionResult(BaseModel):
    status: SampleStatus = SampleStatus.COMPLETED
    result: JSONSerializable = None

class TaskOutput(BaseModel):
    index: Union[None, SampleIndex] = None
    status: SampleStatus = SampleStatus.RUNNING  # directly from TaskSampleExecutionResult
    result: JSONSerializable = None  # directly from TaskSampleExecutionResult
    history: Union[None, List[ChatHistoryItem]] = None

class SampleStatus(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    AGENT_CONTEXT_LIMIT = "agent context limit"
    AGENT_VALIDATION_FAILED = "agent validation failed"
    AGENT_INVALID_ACTION = "agent invalid action"
    TASK_LIMIT_REACHED = "task limit reached"
    UNKNOWN = "unknown"
    TASK_ERROR = "task error"

class ChatHistoryItem(BaseModel):
    role: Literal["user", "agent"]
    content: str
```

Note that when returning a `TaskSampleExecutionResult` from `start_sample`, you should carefully set the sample's completion status: mark it `COMPLETED` only if it finished normally. The framework automatically aggregates these completion statuses.

`Session` implements the following interfaces:

- `def inject(self, item: Union[ChatHistoryItem, List[ChatHistoryItem]])`: inserts one or more history items.
- `async def action(self, *injection) -> AgentOutput`: waits for the Agent's response; for convenience, it also supports inserting one or more history items at the same time (see the sketch below).
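
To illustrate the two interfaces together, here is a minimal sketch of their use inside `start_sample`. The prompt strings are invented; passing a plain dict to `action` mirrors the VirtualTask example below, while the explicit `ChatHistoryItem` construction for `inject` is the form implied by its signature:

```python
async def start_sample(self, index: SampleIndex, session: Session) -> TaskSampleExecutionResult:
    # inject() only appends to the history; it does not trigger the agent.
    session.inject(ChatHistoryItem(role="user", content="You are taking a quiz."))
    session.inject([
        ChatHistoryItem(role="agent", content="OK."),
        ChatHistoryItem(role="user", content="Answer with a single letter."),
    ])
    # action() waits for the agent's reply; it can also inject items first.
    res = await session.action({"role": "user", "content": "Question 1: ..."})
    return TaskSampleExecutionResult(status=SampleStatus.COMPLETED, result={"reply": res.content})
```
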
The definition of `AgentOutput` is as follows:

```python
class AgentOutput(BaseModel):
    status: AgentOutputStatus = AgentOutputStatus.NORMAL
    content: Union[str, None] = None

class AgentOutputStatus(str, Enum):
    NORMAL = "normal"
    CANCELLED = "cancelled"
    AGENT_CONTEXT_LIMIT = "agent context limit"
```

After obtaining an `AgentOutput`, handle it carefully: check whether `AgentOutputStatus` is normal, and handle it accordingly if it is not. If the status is `CANCELLED`, the client needs to cancel the test of this sample for some reason; in that case, end the sample quickly in any way that does not affect subsequent tests.
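
For example, such a status check might look like the following sketch, assumed to sit inside `start_sample`. `prompt` is a hypothetical variable, and reporting `UNKNOWN` for a cancelled sample is just one reasonable choice, since the framework only asks that the sample end quickly:

```python
res = await session.action({"role": "user", "content": prompt})
if res.status == AgentOutputStatus.CANCELLED:
    # The client cancelled this sample: finish fast, without affecting later tests.
    return TaskSampleExecutionResult(status=SampleStatus.UNKNOWN)
if res.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT:
    return TaskSampleExecutionResult(status=SampleStatus.AGENT_CONTEXT_LIMIT)
answer = res.content  # only rely on content when the status is NORMAL
```
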
## Implementation Example

A simple implementation is as follows:

```python
class VirtualTask(Task):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(name="virtual-task", *args, **kwargs)

    def get_indices(self) -> List[Any]:
        return list(range(10))

    async def start_sample(self, index, session: Session):
        print("task start sample")
        for loop_times in range(3):
            await asyncio.sleep(1)
            res = await session.action(
                {"role": "user", "content": "Loop: %d" % loop_times}
            )
            print("TASK", res.content)
        return TaskSampleExecutionResult(
            status=SampleStatus.COMPLETED,
            result={"result": "ok"},
        )

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        return {"score": 0.4}
```

## Migrating from AgentBench v0.1

### Step 1: Migrate from `get_data` to `get_indices`

The data in the original `get_data` can be directly bound to `self` in the `__init__` method, and the corresponding entry can then be fetched from `self.data` in `start_sample` according to `index`. On this basis, implement `get_indices`: if the full set of test samples is a list, you can directly return `list(range(len(self.data)))`.
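
A sketch of what this step could look like; `load_data` and the `data_file` argument are hypothetical names, not part of the framework:

```python
class MyTask(Task):
    def __init__(self, data_file: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # What get_data used to do: load everything once and keep it on self.
        self.data = load_data(data_file)  # hypothetical loader returning a list

    def get_indices(self) -> List[SampleIndex]:
        # The full sample set is a list, so the indices are just positions.
        return list(range(len(self.data)))
```
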
### Step 2: Change `predict_single` to `start_sample`

First, change the original `def` to `async def`, and change the original `session.action` to `await session.action`. Finally, the return value needs an additional `status` compared to the original; this helps the framework automatically aggregate the causes of sample errors, which benefits further experimental analysis.
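
Concretely, a migrated method might look roughly like this sketch; the `"prompt"` field in `self.data` entries is invented for illustration:

```python
async def start_sample(
    self, index: SampleIndex, session: Session
) -> TaskSampleExecutionResult:
    item = self.data[index]  # fetched by index, as set up in Step 1
    res = await session.action({"role": "user", "content": item["prompt"]})
    return TaskSampleExecutionResult(
        status=SampleStatus.COMPLETED,  # the additional status field
        result={"prediction": res.content},
    )
```
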
### Step 3: Change `metrics` to `calculate_overall`

This change was originally made to make aggregating samples easier. If you would rather not change your original `metrics`, you can also create a new `calculate_overall` function and call `self.metrics` within it.
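
A wrapper along these lines might be enough, assuming your original `metrics` accepts a list of per-sample results (its real signature may differ):

```python
def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
    # Score only the samples that finished normally; an illustrative choice.
    completed = [r.result for r in results if r.status == SampleStatus.COMPLETED]
    return self.metrics(completed)  # delegate to the original v0.1 method
```
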
### Additional Reminder

If you originally overrode the `predict_all` method, it cannot be used in the new framework.
