Separate doc filtration and reasoning LLMs #2

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions python/README.md
@@ -37,6 +37,7 @@ Initialize the client by providing required configuration options:
 ```typescript
 client = new ReagClient(
   model: "o3-mini", // LiteLLM model name
+  filtration_model: "minimax", // Filtration model name
   system: Optional[str] // Optional system prompt
   batchSize: Optional[Number] // Optional batch size
   schema: Optional[BaseModel] // Optional Pydantic schema
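
For orientation: the hunk above keeps the TypeScript-style pseudocode already used in python/README.md. Written as plain Python against the constructor this PR adds in client.py, initialization would look roughly like the sketch below; the model names are just the defaults from the diff, and the commented query call is illustrative rather than the exact SDK API.

```python
from reag.client import ReagClient

# Rough sketch only: keyword arguments mirror the __init__ signature
# added to python/src/reag/client.py in this PR.
client = ReagClient(
    model="gpt-4o-mini",          # reasoning model (LiteLLM model name)
    filtration_model="minimax",   # cheaper model used only to screen documents
)

# Illustrative usage; see the SDK README for the exact Document/query API.
# results = await client.query("What is Superagent?", documents=docs)
```
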
50 changes: 36 additions & 14 deletions python/src/reag/client.py
@@ -39,11 +39,13 @@ class ReagClient:
     def __init__(
         self,
         model: str = "gpt-4o-mini",
+        filtration_model: str = "minimax",
         system: str = None,
         batch_size: int = DEFAULT_BATCH_SIZE,
         schema: Optional[BaseModel] = None,
     ):
         self.model = model
+        self.filtration_model = filtration_model
         self.system = system or REAG_SYSTEM_PROMPT
         self.batch_size = batch_size
         self.schema = schema or ResponseSchemaMessage
@@ -161,46 +163,66 @@ def format_doc(doc: Document) -> str:
f"{self.system}\n\n# Available source\n\n{format_doc(document)}"
)

# Use litellm for model completion with the Pydantic schema
response = await acompletion(
model=self.model,
# Use litellm for model completion with the filtration model
filtration_response = await acompletion(
model=self.filtration_model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": prompt},
],
response_format=self.schema,
)

message_content = response.choices[
filtration_message_content = filtration_response.choices[
0
].message.content # Might be a JSON string

try:
# Ensure it's parsed as a dict
data = (
json.loads(message_content)
if isinstance(message_content, str)
else message_content
filtration_data = (
json.loads(filtration_message_content)
if isinstance(filtration_message_content, str)
else filtration_message_content
)

if data["source"].get("is_irrelevant", True):
if filtration_data["source"].get("is_irrelevant", True):
continue

# Use litellm for model completion with the reasoning model
reasoning_response = await acompletion(
model=self.model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": prompt},
],
response_format=self.schema,
)

reasoning_message_content = reasoning_response.choices[
0
].message.content # Might be a JSON string

reasoning_data = (
json.loads(reasoning_message_content)
if isinstance(reasoning_message_content, str)
else reasoning_message_content
)

results.append(
QueryResult(
content=data["source"].get("content", ""),
reasoning=data["source"].get("reasoning", ""),
is_irrelevant=data["source"].get(
content=reasoning_data["source"].get("content", ""),
reasoning=reasoning_data["source"].get("reasoning", ""),
is_irrelevant=reasoning_data["source"].get(
"is_irrelevant", False
),
document=document,
)
)
except json.JSONDecodeError:
print("Error: Could not parse response:", message_content)
print("Error: Could not parse response:", filtration_message_content)
continue # Skip this iteration if parsing fails

return results
return results

except Exception as e:
raise Exception(f"Query failed: {str(e)}")
2 changes: 2 additions & 0 deletions typescript/README.md
@@ -21,6 +21,7 @@ import { openai } from "@ai-sdk/openai";
 // Initialize the SDK with required options
 const client = new ReagClient({
   model: openai("o3-mini", { structuredOutputs: true }),
+  filtrationModel: openai("minimax", { structuredOutputs: true }),
   // system: optional system prompt here or use the default
 });

@@ -62,6 +63,7 @@ Initialize the client by providing required configuration options:
 ```typescript
 const client = new ReagClient({
   model: openai("o3-mini", { structuredOutputs: true }),
+  filtrationModel: openai("minimax", { structuredOutputs: true }),
   system?: string // Optional system prompt
   batchSize?: number // Optional batch size
   schema?: z.ZodSchema // Optional schema
30 changes: 27 additions & 3 deletions typescript/src/client.ts
@@ -16,6 +16,11 @@ export interface ClientOptions {
    * See: https://sdk.vercel.ai/docs/foundations/providers-and-models
    */
   model: LanguageModel;
+  /**
+   * The filtration model instance to use for document filtration.
+   * This should be an instance of a model that implements the Vercel AI SDK's LanguageModel interface.
+   */
+  filtrationModel: LanguageModel;
   /**
    * The system prompt that provides context and instructions to the model.
    * This string sets the behavior and capabilities of the model for all queries.
@@ -69,6 +74,7 @@ const DEFAULT_BATCH_SIZE = 20;
  */
 export class ReagClient {
   private readonly model: LanguageModel;
+  private readonly filtrationModel: LanguageModel;
   private readonly system: string;
   private readonly batchSize: number;
   private readonly schema: z.ZodSchema;
@@ -79,6 +85,7 @@ export class ReagClient {
    */
   constructor(options: ClientOptions) {
     this.model = options.model;
+    this.filtrationModel = options.filtrationModel;
     this.system = options.system || REAG_SYSTEM_PROMPT;
     this.batchSize = options.batchSize || DEFAULT_BATCH_SIZE;
     this.schema = options.schema || RESPONSE_SCHEMA;
@@ -171,20 +178,37 @@ export class ReagClient {
             const system = `${
               this.system
             }\n\n# Available source\n\n${formatDoc(document)}`;
-            const response = await generateObject({
+
+            // Use the filtration model for document filtration
+            const filtrationResponse = await generateObject({
+              model: this.filtrationModel,
+              system,
+              prompt,
+              schema: this.schema,
+            });
+
+            const filtrationData = filtrationResponse.object;
+
+            if (filtrationData.isIrrelevant) {
+              return null;
+            }
+
+            // Use the reasoning model for generating the final answer
+            const reasoningResponse = await generateObject({
               model: this.model,
               system,
               prompt,
               schema: this.schema,
             });

             return {
-              response,
+              response: reasoningResponse,
               document,
             };
           })
         );
-        return batchResponses;

+        return batchResponses.filter((response) => response !== null);
       })
     );