Skip to content

Commit fffa88b

Browse files
authored
Merge pull request assafelovic#432 from norrisp90/master
Hotfix for assafelovic#422 & assafelovic#427: LLM provider = azureopenai when using detailed report mode
2 parents dc2e3a1 + 5642cce commit fffa88b

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

gpt_researcher/llm_provider/azureopenai/azureopenai.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,18 @@
44
from langchain_openai import AzureChatOpenAI
55

66
'''
7-
Please note: Needs additional env vars such as:
8-
AZURE_OPENAI_ENDPOINT e.g. https://xxxx.openai.azure.com/",
9-
OPENAI_API_VERSION,
10-
OPENAI_API_TYPE
7+
Please note:
8+
Needs additional env vars such as:
9+
AZURE_OPENAI_ENDPOINT e.g. https://xxxx.openai.azure.com/",
10+
AZURE_OPENAI_API_KEY e.g "xxxxxxxxxxxxxxxxxxxxx",
11+
OPENAI_API_VERSION, e.g. "2024-03-01-preview", but this needs to be updated over time as the API version updates,
12+
AZURE_EMBEDDING_MODEL e.g. "ada2" The Azure OpenAI embedding model deployment name.
1113
12-
Note new entry in config.py to specify the Azure OpenAI embedding model name:
13-
self.azure_embedding_model = os.getenv('AZURE_EMBEDDING_MODEL', "INSERT_EMBEDDIGN_MODEL_DEPLOYMENT_NAME")
14+
config.py settings for Azure OpenAI should look like:
15+
self.embedding_provider = os.getenv('EMBEDDING_PROVIDER', 'azureopenai')
16+
self.llm_provider = os.getenv('LLM_PROVIDER', "azureopenai")
17+
self.fast_llm_model = os.getenv('FAST_LLM_MODEL', "gpt-3.5-turbo-16k") #Deployment name of your GPT3.5T model as per azure OpenAI studio deployment section
18+
self.smart_llm_model = os.getenv('SMART_LLM_MODEL', "gpt4") #Deployment name of your GPT4 1106-Preview+ (GPT4T) model as per azure OpenAI studio deployment section
1419
'''
1520
class AzureOpenAIProvider:
1621

gpt_researcher/utils/llm.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,13 @@ async def construct_subtopics(task: str, data: str, config, subtopics: list = []
123123

124124
print(f"\n🤖 Calling {config.smart_llm_model}...\n")
125125

126-
model = ChatOpenAI(model=config.smart_llm_model)
126+
if config.llm_provider == "openai":
127+
model = ChatOpenAI(model=config.smart_llm_model)
128+
elif config.llm_provider == "azureopenai":
129+
from langchain_openai import AzureChatOpenAI
130+
model = AzureChatOpenAI(model=config.smart_llm_model)
131+
else:
132+
return []
127133

128134
chain = prompt | model | parser
129135

@@ -138,4 +144,4 @@ async def construct_subtopics(task: str, data: str, config, subtopics: list = []
138144

139145
except Exception as e:
140146
print("Exception in parsing subtopics : ", e)
141-
return subtopics
147+
return subtopics

0 commit comments

Comments
 (0)