llm_optimization_tester.py
import time
import os
import anthropic
import logging
from typing import Dict, List, Optional, TypedDict
from dataclasses import dataclass
import json
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
class TestCase(TypedDict):
name: str
description: str
standard_code: str
optimized_code: str
tasks: List[str]
expected_results: List[Dict]
@dataclass
class TestResult:
completion_time: float
accuracy_score: float
context_questions: int
token_usage: int
errors: List[str]
output: str
class LLMOptimizationTester:
"""
FILE_PURPOSE: Test framework for comparing LLM performance on standard vs. optimized code
SYSTEM_CONTEXT: Part of the LLM codebase optimization research toolkit
TECHNICAL_DEPENDENCIES:
- Python 3.9+
- LLM API access
USAGE_EXAMPLE:
tester = LLMOptimizationTester()
results = tester.run_comparison_test(test_case)
tester.generate_report(results)
"""
def __init__(self):
"""
CONTEXT: Initializes the testing framework
CONFIGURATION: Loads from config.json if present
"""
self.results: Dict[str, List[TestResult]] = {
"standard": [],
"optimized": []
}
# Load configuration
try:
with open('config.json', 'r') as f:
self.config = json.load(f)
logger.info("Loaded configuration from config.json")
except FileNotFoundError:
logger.warning("config.json not found, using default settings")
self.config = {
"api": {
"model": "claude-3-5-sonnet-latest",
"max_tokens": 4000,
"temperature": 0.0
}
}
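        # NOTE: A config.json matching the keys read above could look like the
        # following (illustrative sketch; only the "api" section is consumed here):
        #
        # {
        #     "api": {
        #         "model": "claude-3-5-sonnet-latest",
        #         "max_tokens": 4000,
        #         "temperature": 0.0
        #     }
        # }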
def run_comparison_test(self, test_case: TestCase, standard_result: Optional[TestResult] = None) -> Dict[str, TestResult]:
"""
CONTEXT: Runs comparative analysis between standard and optimized code
PROCESS:
1. Use provided standard results or run standard implementation once
2. Run optimized implementation and compare against standard results
3. Measure and compare performance metrics
Args:
test_case: The test case to run
standard_result: Optional pre-computed standard implementation results
"""
results = {}
logger.info(f"Starting comparison test: {test_case['name']}")
logger.info(f"Number of tasks to evaluate: {len(test_case['tasks'])}")
# Use provided standard result or run standard implementation
if standard_result is None:
logger.info("Testing standard implementation...")
start_time = time.time()
standard_result = self._run_single_test(
"standard",
test_case["standard_code"],
test_case["tasks"]
)
standard_time = time.time() - start_time
logger.info(f"Standard implementation test completed in {standard_time:.2f}s")
else:
logger.info("Using provided standard implementation results")
# Store standard results
results["standard"] = standard_result
# Test optimized version
logger.info("Testing LLM-optimized implementation...")
start_time = time.time()
results["optimized"] = self._run_single_test(
"optimized",
test_case["optimized_code"],
test_case["tasks"]
)
optimized_time = time.time() - start_time
logger.info(f"Optimized implementation test completed in {optimized_time:.2f}s")
return results
def _run_single_test(self, version: str, code: str, tasks: List[str]) -> TestResult:
"""
CONTEXT: Executes a single test run using Anthropic's Sonnet model
METRICS_TRACKED:
- Completion time
- Accuracy
- Context questions needed
- Token usage
"""
        start_time = time.time()
        errors = []
        context_questions = 0
        total_tokens = 0
        accuracy_scores = []
        response = ""  # most recent model response; returned as `output` (avoids NameError if no task completes)
# Initialize Anthropic client
try:
client = anthropic.Anthropic(
api_key=os.getenv('ANTHROPIC_API_KEY')
)
logger.debug("Anthropic client initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize Anthropic client: {e}")
raise
try:
for i, task in enumerate(tasks, 1):
logger.info(f"[{version}] Processing task {i}/{len(tasks)}: {task[:100]}...")
# Construct the prompt
prompt = f"""
You are a software engineer analyzing and modifying code.
Here is the code to work with:
{code}
Your task is: {task}
Please complete this task. If you need any clarification,
ask a question starting with 'QUESTION:'. Otherwise,
provide your solution starting with 'SOLUTION:'.
"""
# Make API call with retries
max_retries = 3
base_delay = 1 # Base delay in seconds
success = False
message = None
retry_count = 0
while not success and retry_count < max_retries:
try:
if retry_count > 0:
delay = base_delay * (2 ** (retry_count - 1)) # Exponential backoff
logger.info(f"[{version}] Retrying task {i} after {delay}s (attempt {retry_count + 1}/{max_retries})")
time.sleep(delay)
message = client.messages.create(
model=self.config["api"]["model"],
max_tokens=self.config["api"]["max_tokens"],
temperature=self.config["api"]["temperature"],
system="You are a skilled software engineer focused on code analysis and modification.",
messages=[{
"role": "user",
"content": prompt
}]
)
success = True
logger.debug(f"[{version}] API call successful for task {i}")
except Exception as e:
retry_count += 1
                            if retry_count >= max_retries:
                                logger.error(f"[{version}] API call failed for task {i} after {max_retries} attempts: {e}")
                                errors.append(f"Task {i} API error after {max_retries} retries: {str(e)}")
                                break
logger.warning(f"[{version}] API call failed for task {i} (attempt {retry_count}/{max_retries}): {e}")
if not success:
continue
# Process response
response = message.content[0].text
total_tokens += message.usage.input_tokens + message.usage.output_tokens
# Track context questions
if response.strip().startswith("QUESTION:"):
context_questions += 1
logger.info(f"[{version}] Task {i} required clarification: {response[:100]}...")
continue
                # Calculate accuracy score (consistent with the QUESTION check above)
                accuracy = 1.0 if response.strip().startswith("SOLUTION:") else 0.5
                accuracy_scores.append(accuracy)
                logger.info(f"[{version}] Task {i} completed with accuracy score: {accuracy}")
except Exception as e:
logger.error(f"[{version}] Test execution failed: {e}")
errors.append(f"Test execution error: {str(e)}")
completion_time = time.time() - start_time
avg_accuracy = sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0.0
logger.info(f"[{version}] Test summary:")
logger.info(f"- Completion time: {completion_time:.2f}s")
logger.info(f"- Average accuracy: {avg_accuracy:.2%}")
logger.info(f"- Context questions needed: {context_questions}")
logger.info(f"- Total tokens used: {total_tokens}")
if errors:
logger.warning(f"- Errors encountered: {len(errors)}")
return TestResult(
completion_time=completion_time,
accuracy_score=avg_accuracy,
context_questions=context_questions,
token_usage=total_tokens,
errors=errors,
            output=response
)
def analyze_code_outputs(self, prompt: str) -> str:
"""
CONTEXT: Analyzes code outputs using LLM
PURPOSE: Compare different implementations and provide insights
"""
try:
client = anthropic.Anthropic(
api_key=os.getenv('ANTHROPIC_API_KEY')
)
logger.debug("Anthropic client initialized successfully for analysis")
message = client.messages.create(
model=self.config["api"]["model"],
max_tokens=self.config["api"]["max_tokens"],
temperature=self.config["api"]["temperature"],
system="You are a skilled software engineer focused on code analysis and comparison.",
messages=[{
"role": "user",
"content": prompt
}]
)
return message.content[0].text
except Exception as e:
logger.error(f"Failed to analyze code outputs: {e}")
return f"Analysis failed: {str(e)}"
    def generate_report(self, results: Dict[str, TestResult]) -> Dict:
        """
        CONTEXT: Generates detailed comparison report
        OUTPUT: Performance metrics and analysis
        """
        logger.info("Generating comparison report...")
return {
"performance_comparison": {
"time_difference": results["optimized"].completion_time - results["standard"].completion_time,
"accuracy_improvement": results["optimized"].accuracy_score - results["standard"].accuracy_score,
"context_questions_reduced": results["standard"].context_questions - results["optimized"].context_questions,
"token_usage_difference": results["optimized"].token_usage - results["standard"].token_usage
},
"conclusion": self._generate_conclusion(results)
}
def _generate_conclusion(self, results: Dict[str, TestResult]) -> str:
"""
CONTEXT: Analyzes results to generate meaningful conclusions
ANALYSIS_POINTS:
- Performance improvements
- Resource efficiency
- Error reduction
"""
# Placeholder for actual analysis logic
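        # A fuller implementation might summarise the measured deltas directly,
        # e.g. (illustrative sketch using the TestResult fields above):
        #
        #     delta_acc = results["optimized"].accuracy_score - results["standard"].accuracy_score
        #     delta_tok = results["optimized"].token_usage - results["standard"].token_usage
        #     return f"Accuracy changed by {delta_acc:+.2%} with a token-usage difference of {delta_tok:+d}."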
return "Analysis of optimization effectiveness"
def create_sample_test_case() -> TestCase:
"""
CONTEXT: Creates a sample test case for demonstration
USAGE: For testing framework validation
"""
return {
"name": "User Authentication Module",
"description": "Testing optimization effects on authentication code understanding",
"standard_code": """
def authenticate_user(credentials):
# Validate credentials
if not credentials.get('username') or not credentials.get('password'):
return None
# Check database
user = db.find_user(credentials['username'])
if not user or not check_password(user, credentials['password']):
return None
return user
""",
"optimized_code": """
# CONTEXT: Handles user authentication against database
# DEPENDENCIES: database.py, password_utils.py
# SIDE_EFFECTS: Logs authentication attempts
# RETURNS: User object or None if authentication fails
def authenticate_user(credentials: Dict[str, str]) -> Optional[User]:
'''
FILE_PURPOSE: User authentication and validation
BUSINESS_RULES:
- Username and password required
- Password must match stored hash
TECHNICAL_DEPENDENCIES:
- PostgreSQL database
- Bcrypt password hashing
'''
# === INPUT VALIDATION ===
if not credentials.get('username') or not credentials.get('password'):
return None
# === DATABASE OPERATIONS ===
user = db.find_user(credentials['username'])
if not user or not check_password(user, credentials['password']):
return None
return user
""",
"tasks": [
"Explain the authentication flow",
"Add MFA support",
"Fix password validation bug"
],
"expected_results": []
}
if __name__ == "__main__":
"""
CONTEXT: Example usage of the testing framework
PURPOSE: Demonstrate framework capabilities
"""
# Create test case
test_case = create_sample_test_case()
# Initialize tester
tester = LLMOptimizationTester()
# Run comparison
results = tester.run_comparison_test(test_case)
# Generate and print report
report = tester.generate_report(results)
print(json.dumps(report, indent=2))
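    # Expected report shape (keys come from generate_report above; values here
    # are illustrative only):
    #
    # {
    #   "performance_comparison": {
    #     "time_difference": 0.0,
    #     "accuracy_improvement": 0.0,
    #     "context_questions_reduced": 0,
    #     "token_usage_difference": 0
    #   },
    #   "conclusion": "Analysis of optimization effectiveness"
    # }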