
Commit 9587418 (parent 12df54e)

Refactored providers into their own folder. Also added support for Google Gemini models.

12 files changed: +519 −64 lines

config.json (+1 −1)

@@ -5,7 +5,7 @@
     "selenium_web_browser": "chrome",
     "search_api": "tavily",
     "embedding_provider": "openai",
-    "llm_provider": "ChatOpenAI",
+    "llm_provider": "openai",
     "fast_llm_model": "gpt-3.5-turbo-16k",
     "smart_llm_model": "gpt-4",
     "fast_token_limit": 2000,

docs/docs/gpt-researcher/config.md (+2 −2)

@@ -22,7 +22,7 @@ Here is an example of the default config.py file found in `/gpt_researcher/confi
     def __init__(self, config_file: str = None):
         self.config_file = config_file
         self.retriever = "tavily"
-        self.llm_provider = "ChatOpenAI"
+        self.llm_provider = "openai"
         self.fast_llm_model = "gpt-3.5-turbo-16k"
         self.smart_llm_model = "gpt-4-1106-preview"
         self.fast_token_limit = 2000
@@ -42,7 +42,7 @@ def __init__(self, config_file: str = None):
 
 Please note that you can also include your own external JSON file by adding the path in the `config_file` param.
 
-To learn more about additional LLM support you can check out the [Langchain Adapter](https://python.langchain.com/docs/guides/adapters/openai) and [Langchain supported LLMs](https://python.langchain.com/docs/integrations/llms/) documentation. Simply pass different model names in the `llm_provider` config param.
+To learn more about additional LLM support you can check out the [Langchain Adapter](https://python.langchain.com/docs/guides/adapters/openai) and [Langchain supported LLMs](https://python.langchain.com/docs/integrations/llms/) documentation. Simply pass different provider names in the `llm_provider` config param.
 
 You can also change the search engine by modifying the `retriever` param to others such as `duckduckgo`, `googleAPI`, `googleSerp`, `searx` and more.
 

gpt_researcher/config/config.py (+1 −1)

@@ -11,7 +11,7 @@ def __init__(self, config_file: str = None):
         self.config_file = config_file if config_file else os.getenv('CONFIG_FILE')
         self.retriever = os.getenv('SEARCH_RETRIEVER', "tavily")
         self.embedding_provider = os.getenv('EMBEDDING_PROVIDER', 'openai')
-        self.llm_provider = os.getenv('LLM_PROVIDER', "ChatOpenAI")
+        self.llm_provider = os.getenv('LLM_PROVIDER', "openai")
         self.fast_llm_model = os.getenv('FAST_LLM_MODEL', "gpt-3.5-turbo-16k")
         self.smart_llm_model = os.getenv('SMART_LLM_MODEL', "gpt-4-1106-preview")
         self.fast_token_limit = int(os.getenv('FAST_TOKEN_LIMIT', 2000))
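
With this change the provider is selected by a plain name ("openai", "google") instead of a Langchain class name. Below is a minimal sketch of switching to the new Gemini support through the environment variables read above. It is not part of the commit; it assumes the config class in config.py is named Config, that the package is importable, and that the key value is replaced with a real one:

    import os

    # Illustrative values; "google" is the name matched by get_provider() in gpt_researcher/utils/llm.py
    os.environ["LLM_PROVIDER"] = "google"
    os.environ["GEMINI_API_KEY"] = "<your-gemini-api-key>"  # read by GoogleProvider.get_api_key()

    from gpt_researcher.config.config import Config  # assumed class name

    config = Config()
    print(config.llm_provider)  # -> "google"
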
gpt_researcher/llm_provider/__init__.py (+7)

@@ -0,0 +1,7 @@
+from .google.google import GoogleProvider
+from .openai.openai import OpenAIProvider
+
+__all__ = [
+    "GoogleProvider",
+    "OpenAIProvider"
+]

gpt_researcher/llm_provider/google/__init__.py

Whitespace-only changes.
gpt_researcher/llm_provider/google/google.py (+103)

@@ -0,0 +1,103 @@
+import os
+
+from colorama import Fore, Style
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+
+class GoogleProvider:
+
+    def __init__(
+        self,
+        model,
+        temperature,
+        max_tokens
+    ):
+        # May be extended to support more google models in the future
+        self.model = "gemini-pro"
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.api_key = self.get_api_key()
+        self.llm = self.get_llm_model()
+
+    def get_api_key(self):
+        """
+        Gets the GEMINI_API_KEY
+        Returns:
+
+        """
+        try:
+            api_key = os.environ["GEMINI_API_KEY"]
+        except:
+            raise Exception(
+                "GEMINI API key not found. Please set the GEMINI_API_KEY environment variable.")
+        return api_key
+
+    def get_llm_model(self):
+        # Initializing the chat model
+        llm = ChatGoogleGenerativeAI(
+            convert_system_message_to_human=True,
+            model=self.model,
+            temperature=self.temperature,
+            max_output_tokens=self.max_tokens,
+            google_api_key=self.api_key
+        )
+
+        return llm
+
+    def convert_messages(self, messages):
+        """
+        The function `convert_messages` converts messages based on their role into either SystemMessage
+        or HumanMessage objects.
+
+        Args:
+            messages: It looks like the code snippet you provided is a function called `convert_messages`
+            that takes a list of messages as input and converts each message based on its role into either a
+            `SystemMessage` or a `HumanMessage`.
+
+        Returns:
+            The `convert_messages` function is returning a list of converted messages based on the input
+            `messages`. The function checks the role of each message in the input list and creates a new
+            `SystemMessage` object if the role is "system" or a new `HumanMessage` object if the role is
+            "user". The function then returns a list of these converted messages.
+        """
+        converted_messages = []
+        for message in messages:
+            if message["role"] == "system":
+                converted_messages.append(
+                    SystemMessage(content=message["content"]))
+            elif message["role"] == "user":
+                converted_messages.append(
+                    HumanMessage(content=message["content"]))
+
+        return converted_messages
+
+    async def get_chat_response(self, messages, stream, websocket=None):
+        if not stream:
+            # Getting output from the model chain using ainvoke for asynchronous invoking
+            converted_messages = self.convert_messages(messages)
+            output = await self.llm.ainvoke(converted_messages)
+
+            return output.content
+
+        else:
+            return await self.stream_response(messages, websocket)
+
+    async def stream_response(self, messages, websocket=None):
+        paragraph = ""
+        response = ""
+
+        # Streaming the response using the chain astream method from langchain
+        async for chunk in self.llm.astream(messages):
+            content = chunk.content
+            if content is not None:
+                response += content
+                paragraph += content
+                if "\n" in paragraph:
+                    if websocket is not None:
+                        await websocket.send_json({"type": "report", "output": paragraph})
+                    else:
+                        print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
+                    paragraph = ""
+
+        return response
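
Taken on its own, the new provider has a small surface: construct it with a model name, temperature and token limit, then await get_chat_response. Here is a minimal standalone sketch, not part of this commit; it assumes the package and its dependencies are installed, GEMINI_API_KEY is set, and the message content is purely illustrative:

    import asyncio

    from gpt_researcher.llm_provider import GoogleProvider

    # Note: GoogleProvider currently pins self.model to "gemini-pro" regardless of the argument.
    provider = GoogleProvider(model="gemini-pro", temperature=0.4, max_tokens=1024)

    messages = [
        {"role": "system", "content": "You are a concise research assistant."},
        {"role": "user", "content": "List two benefits of hiding LLM vendors behind a provider class."},
    ]

    # stream=False: the dict messages are converted to Langchain messages and sent via ainvoke().
    answer = asyncio.run(provider.get_chat_response(messages, stream=False))
    print(answer)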

gpt_researcher/llm_provider/openai/__init__.py

Whitespace-only changes.
gpt_researcher/llm_provider/openai/openai.py (+72)

@@ -0,0 +1,72 @@
+import os
+
+from colorama import Fore, Style
+from langchain_openai import ChatOpenAI
+
+
+class OpenAIProvider:
+
+    def __init__(
+        self,
+        model,
+        temperature,
+        max_tokens
+    ):
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.api_key = self.get_api_key()
+        self.llm = self.get_llm_model()
+
+    def get_api_key(self):
+        """
+        Gets the OpenAI API key
+        Returns:
+
+        """
+        try:
+            api_key = os.environ["OPENAI_API_KEY"]
+        except:
+            raise Exception(
+                "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
+        return api_key
+
+    def get_llm_model(self):
+        # Initializing the chat model
+        llm = ChatOpenAI(
+            model=self.model,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            api_key=self.api_key
+        )
+
+        return llm
+
+    async def get_chat_response(self, messages, stream, websocket=None):
+        if not stream:
+            # Getting output from the model chain using ainvoke for asynchronous invoking
+            output = await self.llm.ainvoke(messages)
+
+            return output.content
+
+        else:
+            return await self.stream_response(messages, websocket)
+
+    async def stream_response(self, messages, websocket=None):
+        paragraph = ""
+        response = ""
+
+        # Streaming the response using the chain astream method from langchain
+        async for chunk in self.llm.astream(messages):
+            content = chunk.content
+            if content is not None:
+                response += content
+                paragraph += content
+                if "\n" in paragraph:
+                    if websocket is not None:
+                        await websocket.send_json({"type": "report", "output": paragraph})
+                    else:
+                        print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
+                    paragraph = ""
+
+        return response

gpt_researcher/utils/llm.py (+33 −58)

@@ -1,13 +1,31 @@
 # libraries
 from __future__ import annotations
+import logging
+
 import json
-from fastapi import WebSocket
-from colorama import Fore, Style
 from typing import Optional
-from langchain_openai import ChatOpenAI
+
+from colorama import Fore, Style
+from fastapi import WebSocket
+
 from gpt_researcher.master.prompts import auto_agent_instructions
 
 
+def get_provider(llm_provider):
+    match llm_provider:
+        case "openai":
+            from ..llm_provider import OpenAIProvider
+            llm_provider = OpenAIProvider
+        case "google":
+            from ..llm_provider import GoogleProvider
+            llm_provider = GoogleProvider
+
+        case _:
+            raise Exception("LLM provider not found.")
+
+    return llm_provider
+
+
 async def create_chat_completion(
     messages: list, # type: ignore
     model: Optional[str] = None,
@@ -34,71 +52,28 @@ async def create_chat_completion(
     if model is None:
         raise ValueError("Model cannot be None")
     if max_tokens is not None and max_tokens > 8001:
-        raise ValueError(f"Max tokens cannot be more than 8001, but got {max_tokens}")
+        raise ValueError(
+            f"Max tokens cannot be more than 8001, but got {max_tokens}")
+
+    # Get the provider from supported providers
+    ProviderClass = get_provider(llm_provider)
+    provider = ProviderClass(
+        model,
+        temperature,
+        max_tokens
+    )
 
     # create response
     for _ in range(10): # maximum of 10 attempts
-        response = await send_chat_completion_request(
-            messages, model, temperature, max_tokens, stream, llm_provider, websocket
+        response = await provider.get_chat_response(
+            messages, stream, websocket
         )
         return response
 
     logging.error("Failed to get response from OpenAI API")
     raise RuntimeError("Failed to get response from OpenAI API")
 
 
-import logging
-
-
-async def send_chat_completion_request(
-    messages, model, temperature, max_tokens, stream, llm_provider, websocket=None
-):
-    if not stream:
-        # Initializing the chat model
-        chat = ChatOpenAI(
-            model=model,
-            temperature=temperature,
-            max_tokens=max_tokens
-        )
-
-        # Getting output from the model chain using ainvoke for asynchronous invoking
-        output = await chat.ainvoke(messages)
-
-        return output.content
-
-    else:
-        return await stream_response(
-            model, messages, temperature, max_tokens, llm_provider, websocket
-        )
-
-
-async def stream_response(model, messages, temperature, max_tokens, llm_provider, websocket=None):
-    # Initializing the model
-    chat = ChatOpenAI(
-        model=model,
-        temperature=temperature,
-        max_tokens=max_tokens
-    )
-
-    paragraph = ""
-    response = ""
-
-    # Streaming the response using the chain astream method from langchain
-    async for chunk in chat.astream(messages):
-        content = chunk.content
-        if content is not None:
-            response += content
-            paragraph += content
-            if "\n" in paragraph:
-                if websocket is not None:
-                    await websocket.send_json({"type": "report", "output": paragraph})
-                else:
-                    print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
-                paragraph = ""
-
-    return response
-
-
 def choose_agent(smart_llm_model: str, llm_provider: str, task: str) -> dict:
     """Determines what server should be used
     Args:
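
The old module-level send_chat_completion_request and stream_response helpers were hard-wired to ChatOpenAI; create_chat_completion now looks the provider class up by name and delegates to it. A minimal sketch of that dispatch path (not part of the commit; instantiating the provider still requires GEMINI_API_KEY to be set, and the argument values are illustrative):

    from gpt_researcher.utils.llm import get_provider

    # "openai" and "google" are the only names matched in get_provider(); anything else raises.
    ProviderClass = get_provider("google")      # resolves to GoogleProvider
    provider = ProviderClass("gemini-pro", 0.55, 2000)

    # create_chat_completion() then awaits provider.get_chat_response(messages, stream, websocket)
    # in place of the removed ChatOpenAI-specific helpers.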
