Skip to content

Commit 83b7c32

Browse files
hallerite, GitHoobar, Wendong-Fan
authored
refactor: Implement extraction as strategy pattern (#1742)
Co-authored-by: Rishabh <[email protected]>
Co-authored-by: Wendong-Fan <[email protected]>
1 parent ff6d719 commit 83b7c32

File tree

8 files changed

+690
-80
lines changed

8 files changed

+690
-80
lines changed

camel/environments/base.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,26 @@ def __init__(
151151
r"""Initialize the environment.
152152
153153
Args:
154-
dataset: Dataset to sample questions from.
155-
verifier: Verifier to check responses.
156-
extractor: Extractor to process LLM responses.
157-
max_steps: Maximum steps per episode.
158-
teacher_agent: Optional agent for reward shaping and hints
159-
curriculum_config: Configuration for curriculum learning including:
154+
dataset (BaseDataset): Dataset to sample questions from.
155+
verifier (BaseVerifier): Verifier to check responses.
156+
extractor (BaseExtractor): Extractor to process LLM responses.
157+
max_steps (Optional[int]): Maximum steps per episode. (default:
158+
:obj:`None`)
159+
teacher_agent (Optional[ChatAgent]): Optional agent for reward
160+
shaping and hints. (default: :obj:`None`)
161+
curriculum_config (Optional[Dict[str, Any]]): Configuration for
162+
curriculum learning including:
160163
- difficulty_levels: List of available difficulty levels
161164
- promotion_threshold: Score needed to advance
162165
- demotion_threshold: Score triggering level decrease
163166
- min_questions_per_level: Questions before promotion
164-
practice_env_config: Configuration for practice environments:
167+
(default: :obj:`None`)
168+
practice_env_config (Optional[Dict[str, Any]]): Configuration for
169+
practice environments:
165170
- max_practice_envs: Maximum concurrent environments
166171
- difficulty_range: Allowed difficulty variation
167172
- focus_areas: Specific skills to practice
173+
(default: :obj:`None`)
168174
**kwargs: Additional environment parameters.
169175
"""
170176
self.dataset = dataset
@@ -289,7 +295,9 @@ async def step(self, action: Action) -> StepResult:
289295
# extract verifiable part from llm response
290296
extraction_result = await self.extractor.extract(action.llm_response)
291297

292-
# TODO: extract executable llm response specifically
298+
# Ensure extraction_result is a string
299+
if extraction_result is None:
300+
extraction_result = ""
293301

294302
# verify the extracted
295303
verification_result = await self.verifier.verify(

camel/extractors/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14-
from .base import BaseExtractor
14+
from .base import BaseExtractor, BaseExtractorStrategy
1515

16-
__all__ = ["BaseExtractor"]
16+
__all__ = ["BaseExtractor", "BaseExtractorStrategy"]

camel/extractors/base.py

+86-64
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,47 @@
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
1414

15+
import asyncio
1516
from abc import ABC, abstractmethod
1617
from types import TracebackType
17-
from typing import Any, Dict, Optional, Type
18-
19-
from typing_extensions import Self
18+
from typing import Any, Dict, List, Optional, Type
2019

2120
from camel.logger import get_logger
2221
from camel.utils import BatchProcessor
2322

2423
logger = get_logger(__name__)
2524

2625

27-
class BaseExtractor(ABC):
28-
r"""Base class for all response extractors.
26+
class BaseExtractorStrategy(ABC):
27+
r"""Abstract base class for extraction strategies."""
28+
29+
@abstractmethod
30+
async def extract(self, text: str) -> Optional[str]:
31+
r"""Asynchronously extracts relevant parts from text.
32+
33+
Args:
34+
text (str): The input text to process.
35+
36+
Returns:
37+
Optional[str]: Extracted str if successful, otherwise None.
38+
"""
39+
pass
40+
41+
42+
class BaseExtractor:
43+
r"""Base class for response extractors with a fixed strategy pipeline.
2944
30-
An extractor takes the response and extracts the relevant parts,
31-
converting them into a format that the verifier can handle.
32-
Implements async context manager protocol for proper resource management.
45+
This extractor:
46+
- Uses a **fixed multi-stage pipeline** of extraction strategies.
47+
- Tries **each strategy in order** within a stage until one succeeds.
48+
- Feeds the **output of one stage into the next** for processing.
49+
- Supports **async execution** for efficient processing.
50+
- Provides **batch processing and resource monitoring** options.
3351
"""
3452

3553
def __init__(
3654
self,
55+
pipeline: List[List[BaseExtractorStrategy]],
3756
cache_templates: bool = True,
3857
max_cache_size: int = 1000,
3958
extraction_timeout: float = 30.0,
@@ -43,9 +62,12 @@ def __init__(
4362
memory_threshold: float = 85.0,
4463
**kwargs,
4564
):
46-
r"""Initialize the extractor.
65+
r"""Initialize the extractor with a multi-stage strategy pipeline.
4766
4867
Args:
68+
pipeline (List[List[BaseExtractorStrategy]]):
69+
A fixed list of lists where each list represents a stage
70+
containing extractor strategies executed in order.
4971
cache_templates (bool): Whether to cache extraction templates.
5072
(default: :obj:`True`)
5173
max_cache_size (int): Maximum number of templates to cache.
@@ -61,11 +83,8 @@ def __init__(
6183
memory_threshold (float): Memory usage percentage threshold for
6284
scaling down. (default: :obj:`85.0`)
6385
**kwargs: Additional extractor parameters.
64-
65-
Raises:
66-
ValueError: If invalid parameter values are provided
6786
"""
68-
# Store all parameters in metadata dict for compatibility
87+
6988
self._metadata = {
7089
'cache_templates': cache_templates,
7190
'max_cache_size': max_cache_size,
@@ -81,14 +100,7 @@ def __init__(
81100
self._cache: Dict[str, Any] = {}
82101
self._batch_processor: Optional[BatchProcessor] = None
83102

84-
# Store configuration parameters
85-
self._cache_templates = cache_templates
86-
self._max_cache_size = max_cache_size
87-
self._extraction_timeout = extraction_timeout
88-
self._batch_size = batch_size
89-
self._monitoring_interval = monitoring_interval
90-
self._cpu_threshold = cpu_threshold
91-
self._memory_threshold = memory_threshold
103+
self._pipeline = pipeline
92104

93105
async def setup(self) -> None:
94106
r"""Set up the extractor with necessary resources.
@@ -106,17 +118,15 @@ async def setup(self) -> None:
106118
return
107119

108120
try:
109-
# Initialize template cache if enabled
110-
if self._cache_templates:
121+
if self._metadata["cache_templates"]:
111122
self._template_cache: Dict[str, Any] = {}
112123

113-
# Set up batch processing if needed
114-
if self._batch_size > 1:
124+
if self._metadata["batch_size"] > 1:
115125
self._batch_processor = BatchProcessor(
116-
initial_batch_size=self._batch_size,
117-
monitoring_interval=self._monitoring_interval,
118-
cpu_threshold=self._cpu_threshold,
119-
memory_threshold=self._memory_threshold,
126+
initial_batch_size=self._metadata["batch_size"],
127+
monitoring_interval=self._metadata["monitoring_interval"],
128+
cpu_threshold=self._metadata["cpu_threshold"],
129+
memory_threshold=self._metadata["memory_threshold"],
120130
)
121131

122132
self._is_setup = True
@@ -171,13 +181,6 @@ async def cleanup(self) -> None:
171181
)
172182

173183
# Preserve init config in metadata
174-
self._metadata = {
175-
'cache_templates': self._cache_templates,
176-
'max_cache_size': self._max_cache_size,
177-
'extraction_timeout': self._extraction_timeout,
178-
'batch_size': self._batch_size,
179-
}
180-
181184
if not errors:
182185
logger.info(
183186
f"{self.__class__.__name__} cleaned up successfully"
@@ -187,23 +190,19 @@ async def cleanup(self) -> None:
187190
errors.append(f"Unexpected error during cleanup: {e}")
188191

189192
finally:
190-
# Always mark as uninitialized, even if cleanup fails
191193
self._is_setup = False
192194
self._batch_processor = None
193195

194196
if errors:
195-
error_msg = (
196-
f"Errors during {self.__class__.__name__} cleanup: "
197-
f"{'; '.join(errors)}"
198-
)
197+
error_msg = f"Errors during cleanup: {'; '.join(errors)}"
199198
logger.error(error_msg)
200199
raise RuntimeError(error_msg)
201200

202-
async def __aenter__(self) -> Self:
201+
async def __aenter__(self) -> "BaseExtractor":
203202
r"""Async context manager entry.
204203
205204
Returns:
206-
Self reference for context manager usage.
205+
BaseExtractor: The initialized extractor instance.
207206
"""
208207
await self.setup()
209208
return self
@@ -226,38 +225,61 @@ async def __aexit__(
226225
"""
227226
await self.cleanup()
228227

229-
@abstractmethod
230-
async def extract(
231-
self, response: str, context: Optional[Dict[str, Any]] = None
232-
) -> str:
233-
r"""Extract relevant parts from a response.
234-
235-
Extracts:
236-
1. Final answer or output
237-
2. Chain of thought reasoning steps
238-
3. Difficulty assessment
228+
async def extract(self, response: str) -> Optional[str]:
229+
r"""Extracts a normalized, comparable part of the LLM response
230+
using the fixed multi-stage strategy pipeline.
239231
240232
Args:
241-
response (str): Raw response from agent generation.
242-
context (Optional[Dict[str, Any]]): Optional context for
243-
extraction like:
244-
- final_answer
245-
- rationale
246-
- complexity
233+
response (str): The raw response text.
247234
248235
Returns:
249-
str: Extracted content string.
236+
Optional[str]: Extracted data if successful, otherwise None.
250237
251238
Raises:
252239
ValueError: If response is empty or invalid.
253-
NotImplementedError: If no implementation is provided.
254240
RuntimeError: If extractor is not initialized.
255241
"""
256242
if not self._is_setup:
257243
raise RuntimeError(
258-
f"{self.__class__.__name__} must be initialized "
259-
"before extraction"
244+
"Extractor must be initialized before extraction"
260245
)
261246
if not response or not response.strip():
262247
raise ValueError("Empty or whitespace-only response")
263-
raise NotImplementedError("Subclasses must implement extract()")
248+
249+
current_input = response # Initial input
250+
251+
for stage in self._pipeline:
252+
stage_success = (
253+
False # Track if any strategy in the stage succeeds
254+
)
255+
256+
for strategy in stage:
257+
try:
258+
# Apply the extraction timeout
259+
result = await asyncio.wait_for(
260+
strategy.extract(current_input),
261+
timeout=self._metadata["extraction_timeout"],
262+
)
263+
264+
if result is not None:
265+
current_input = result # Feed into next stage
266+
stage_success = True
267+
break # Move to next stage if valid extraction occurs
268+
269+
except asyncio.TimeoutError:
270+
logger.warning(
271+
f"Strategy {strategy.__class__.__name__} timed out "
272+
f"after {self._metadata['extraction_timeout']} seconds"
273+
)
274+
except Exception as e:
275+
logger.warning(
276+
f"Strategy {strategy.__class__.__name__} failed: {e}"
277+
)
278+
279+
if not stage_success:
280+
logger.debug(
281+
"No strategy in stage succeeded, stopping extraction."
282+
)
283+
return None # Stop processing if the stage fails
284+
285+
return current_input # Final processed output

0 commit comments

Comments
 (0)