-
Notifications
You must be signed in to change notification settings - Fork 5.7k
/
Copy pathtest_message_utils.py
199 lines (178 loc) · 6.86 KB
/
test_message_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from openhands.core.message_utils import (
get_token_usage_for_event,
get_token_usage_for_event_id,
)
from openhands.events.event import Event
from openhands.events.tool import ToolCallMetadata
from openhands.llm.metrics import Metrics, TokenUsage
def test_get_token_usage_for_event():
    """Resolve a usage record through the event's ``tool_call_metadata``.

    The lookup succeeds only while ``model_response.id`` equals the stored
    record's ``response_id``. Re-pointing that id, or removing the metadata
    altogether, makes ``get_token_usage_for_event`` return None.
    """
    token_fields = {
        'prompt_tokens': 10,
        'completion_tokens': 5,
        'cache_read_tokens': 2,
        'cache_write_tokens': 1,
        'response_id': 'test-response-id',
    }
    record = TokenUsage(model='test-model', **token_fields)

    metrics = Metrics(model_name='test-model')
    # Forward every field from the constructed record into the metrics store.
    metrics.add_token_usage(**{name: getattr(record, name) for name in token_fields})

    metadata = ToolCallMetadata(
        tool_call_id='test-tool-call',
        function_name='fake_function',
        model_response={'id': 'test-response-id'},
        total_calls_in_response=1,
    )
    event = Event()
    # Set the private attribute directly; outside tests this would normally be
    # assigned via ``event.tool_call_metadata = ...``.
    event._tool_call_metadata = metadata

    # Matching response id -> the stored record comes back.
    match = get_token_usage_for_event(event, metrics)
    assert match is not None
    assert match.prompt_tokens == 10
    assert match.response_id == 'test-response-id'

    # Changing the response id on the metadata breaks the match.
    metadata.model_response.id = 'some-other-id'
    assert get_token_usage_for_event(event, metrics) is None

    # With no tool-call metadata at all there is nothing to resolve.
    event._tool_call_metadata = None
    assert get_token_usage_for_event(event, metrics) is None
def test_get_token_usage_for_event_id():
    """
    Test that we search backward from the event with the given id,
    finding the first usage record that matches a response_id in that or previous events.

    Events 1 and 3 both carry usage, so starting from event 4 must yield the
    nearest earlier match (event 3), not an older one (event 1).
    """
    metrics = Metrics(model_name='test-model')
    # Seed usage through the public API (as the other tests in this module do)
    # rather than appending to the private ``_token_usages`` list.
    seeded_usage = (
        # (prompt, completion, cache_read, cache_write, response_id)
        (12, 3, 2, 5, 'resp-1'),
        (7, 2, 1, 3, 'resp-2'),
    )
    for prompt, completion, cache_read, cache_write, resp_id in seeded_usage:
        metrics.add_token_usage(
            prompt_tokens=prompt,
            completion_tokens=completion,
            cache_read_tokens=cache_read,
            cache_write_tokens=cache_write,
            response_id=resp_id,
        )

    # Event id -> (tool_call_id, function_name, model_response id).
    # Events 0, 2 and 4 get no tool-call metadata.
    tool_calls = {
        1: ('tid1', 'fn1', 'resp-1'),
        3: ('tid2', 'fn2', 'resp-2'),
    }
    events = []
    for i in range(5):
        e = Event()
        e._id = i
        if i in tool_calls:
            tool_call_id, function_name, response_id = tool_calls[i]
            e._tool_call_metadata = ToolCallMetadata(
                tool_call_id=tool_call_id,
                function_name=function_name,
                model_response={'id': response_id},
                total_calls_in_response=1,
            )
        events.append(e)

    # Event 3 carries its own usage, so it is found immediately.
    found_3 = get_token_usage_for_event_id(events, 3, metrics)
    assert found_3 is not None
    assert found_3.response_id == 'resp-2'

    # Event 1 likewise resolves directly.
    found_1 = get_token_usage_for_event_id(events, 1, metrics)
    assert found_1 is not None
    assert found_1.response_id == 'resp-1'

    # Event 2 has no usage, so the search falls back to event 1.
    found_2 = get_token_usage_for_event_id(events, 2, metrics)
    assert found_2 is not None
    assert found_2.response_id == 'resp-1'

    # Event 4 has no usage; the nearest earlier match is event 3, not event 1.
    found_4 = get_token_usage_for_event_id(events, 4, metrics)
    assert found_4 is not None
    assert found_4.response_id == 'resp-2'

    # Nothing at or before event 0 carries usage, so the result is None.
    found_0 = get_token_usage_for_event_id(events, 0, metrics)
    assert found_0 is None
def test_get_token_usage_for_event_fallback():
    """Fall back to ``event.response_id`` when the tool-call metadata misses.

    The event's ``tool_call_metadata.model_response.id`` names a response with
    no recorded usage, while its top-level response id names a real record.
    The lookup must still return that real record.
    """
    stored = TokenUsage(
        model='fallback-test',
        prompt_tokens=22,
        completion_tokens=8,
        cache_read_tokens=3,
        cache_write_tokens=2,
        response_id='fallback-response-id',
    )
    forwarded = (
        'prompt_tokens',
        'completion_tokens',
        'cache_read_tokens',
        'cache_write_tokens',
        'response_id',
    )
    metrics = Metrics(model_name='fallback-test')
    metrics.add_token_usage(**{key: getattr(stored, key) for key in forwarded})

    event = Event()
    # The tool-call metadata points at a response id with no recorded usage...
    event._tool_call_metadata = ToolCallMetadata(
        tool_call_id='irrelevant-tool-call',
        function_name='fake_function',
        model_response={'id': 'not-matching-any-usage'},
        total_calls_in_response=1,
    )
    # ...but the top-level response id names the real record.
    event._response_id = 'fallback-response-id'

    result = get_token_usage_for_event(event, metrics)
    assert result is not None
    assert (result.prompt_tokens, result.response_id) == (22, 'fallback-response-id')
def test_get_token_usage_for_event_id_fallback():
    """
    Verify that get_token_usage_for_event_id also falls back to event.response_id
    if tool_call_metadata.model_response.id is missing or mismatched.
    """
    # NOTE: this should never happen (tm), but there is a hint in message_utils.py
    # that it might ("(overwrites any previous message with the same response_id)"),
    # so we'll handle it gracefully.
    metrics = Metrics(model_name='fallback-test')
    # Register usage via the public API (consistent with the other tests) instead
    # of mutating the list returned by the ``token_usages`` property.
    metrics.add_token_usage(
        prompt_tokens=15,
        completion_tokens=4,
        cache_read_tokens=1,
        cache_write_tokens=0,
        response_id='resp-fallback',
    )

    events = []
    for i in range(3):
        e = Event()
        e._id = i
        if i == 1:
            # Mismatch in tool_call_metadata...
            e._tool_call_metadata = ToolCallMetadata(
                tool_call_id='tool-123',
                function_name='whatever',
                model_response={'id': 'no-such-response'},
                total_calls_in_response=1,
            )
            # ...but the event's top-level response_id is correct.
            e._response_id = 'resp-fallback'
        events.append(e)

    # Searching from event_id=2 goes back to event 1, which has the fallback response_id.
    found_usage = get_token_usage_for_event_id(events, 2, metrics)
    assert found_usage is not None
    assert found_usage.response_id == 'resp-fallback'
    assert found_usage.prompt_tokens == 15

    # Starting directly at event 1 resolves via the same fallback.
    found_direct = get_token_usage_for_event_id(events, 1, metrics)
    assert found_direct is not None
    assert found_direct.response_id == 'resp-fallback'

    # Event 0 has neither metadata nor a response_id, and nothing precedes it.
    assert get_token_usage_for_event_id(events, 0, metrics) is None