Skip to content

Commit

Permalink
Add Google Generative AI models
Browse files Browse the repository at this point in the history
  • Loading branch information
feiskyer committed Jul 21, 2024
1 parent 8972ac4 commit 3c2ee8c
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 142 deletions.
28 changes: 10 additions & 18 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -244,24 +244,6 @@
"configuration": {
"title": "ChatGPT",
"properties": {
"chatgpt.method": {
"type": "string",
"enum": [
"GPT3 OpenAI API Key",
"Claude 3"
],
"default": "GPT3 OpenAI API Key",
"markdownDescription": "Choose your integration preference.",
"order": 1,
"enumItemLabels": [
"Use OpenAI API key integration",
"Use Claude 3 integration"
],
"markdownEnumDescriptions": [
"Various chat & text completion models are supported including OpenAI and OpenAI API compatible local models.",
"Claude 3 models from Anthropic"
]
},
"chatgpt.systemPrompt": {
"type": "string",
"default": "",
Expand Down Expand Up @@ -424,6 +406,9 @@
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-latest",
"gemini-1.0-pro-latest",
"custom"
],
"default": "gpt-3.5-turbo",
Expand Down Expand Up @@ -459,6 +444,9 @@
"Claude 3 - claude-3-sonnet-20240229",
"Claude 3 - claude-3-haiku-20240307",
"Claude 3 - claude-3-5-sonnet-20240620",
"Google Gemini - gemini-1.5-flash-latest",
"Google Gemini - gemini-1.5-pro-latest",
"Google Gemini - gemini-1.0-pro-latest",
"Custom Model"
],
"markdownEnumDescriptions": [
Expand Down Expand Up @@ -491,6 +479,9 @@
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-latest",
"gemini-1.0-pro-latest",
"Custom model name set by `chatgpt.gpt3.customModel`"
]
},
Expand Down Expand Up @@ -572,6 +563,7 @@
"dependencies": {
"@ai-sdk/anthropic": "^0.0.30",
"@ai-sdk/azure": "^0.0.14",
"@ai-sdk/google": "^0.0.27",
"@ai-sdk/openai": "^0.0.37",
"@types/minimatch": "^5.1.2",
"ai": "^3.2.32",
Expand Down
29 changes: 19 additions & 10 deletions src/chatgpt-view-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { CoreMessage } from "ai";
import delay from "delay";
import * as vscode from "vscode";
import { initClaudeModel } from "./anthropic";
import { initGeminiModel } from "./gemini";
import { ModelConfig } from "./model-config";
import { chatGpt, initGptModel } from "./openai";
import { chatCompletion, initGptLegacyModel } from "./openai-legacy";
Expand Down Expand Up @@ -216,13 +217,17 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider {
}

private get isGpt35Model(): boolean {
return !this.isCodexModel && !this.isClaude;
return !this.isCodexModel && !this.isClaude && !this.isGemini;
}

private get isClaude(): boolean {
return !!this.model?.startsWith("claude-");
}

private get isGemini(): boolean {
return !!this.model?.startsWith("gemini-");
}

public async prepareConversation(modelChanged = false): Promise<boolean> {
this.conversationId = this.conversationId || this.getRandomId();
const state = this.context.globalState;
Expand All @@ -235,7 +240,8 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider {
if (
(this.isGpt35Model && !this.apiChat) ||
(this.isClaude && !this.apiChat) ||
(!this.isGpt35Model && !this.isClaude && !this.apiCompletion) ||
(this.isGemini && !this.apiChat) ||
(!this.isGpt35Model && !this.isClaude && !this.isGemini && !this.apiCompletion) ||
modelChanged
) {
let apiKey =
Expand All @@ -252,11 +258,14 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider {
}

let apiBaseUrl = configuration.get("gpt3.apiBaseUrl") as string;
if (!apiBaseUrl) {
if (this.isGpt35Model) {
apiBaseUrl = "https://api.openai.com/v1";
} else if (this.isClaude) {
apiBaseUrl = "https://api.anthropic.com";
if (!apiBaseUrl && this.isGpt35Model) {
apiBaseUrl = "https://api.openai.com/v1";
}
if (!apiBaseUrl || apiBaseUrl == "https://api.openai.com/v1") {
if (this.isClaude) {
apiBaseUrl = "https://api.anthropic.com/v1";
} else if (this.isGemini) {
apiBaseUrl = "https://generativelanguage.googleapis.com/v1beta";
}
}

Expand Down Expand Up @@ -314,6 +323,8 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider {
await initGptModel(this, this.modelConfig);
} else if (this.isClaude) {
await initClaudeModel(this, this.modelConfig);
} else if (this.isGemini) {
await initGeminiModel(this, this.modelConfig);
} else {
initGptLegacyModel(this, this.modelConfig);
}
Expand Down Expand Up @@ -400,9 +411,7 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider {
});
};
try {
if (this.isGpt35Model) {
await chatGpt(this, question, updateResponse);
} else if (this.isClaude) {
if (this.isGpt35Model || this.isClaude || this.isGemini) {
await chatGpt(this, question, updateResponse);
} else {
await chatCompletion(this, question, updateResponse);
Expand Down
7 changes: 1 addition & 6 deletions src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ export async function activate(context: vscode.ExtensionContext) {
if (
e.affectsConfiguration("chatgpt.promptPrefix") ||
e.affectsConfiguration("chatgpt.gpt3.generateCode-enabled") ||
e.affectsConfiguration("chatgpt.gpt3.model") ||
e.affectsConfiguration("chatgpt.method")
e.affectsConfiguration("chatgpt.gpt3.model")
) {
setContext();
}
Expand Down Expand Up @@ -243,12 +242,8 @@ export async function activate(context: vscode.ExtensionContext) {
const modelName = vscode.workspace
.getConfiguration("chatgpt")
.get("gpt3.model") as string;
const method = vscode.workspace
.getConfiguration("chatgpt")
.get("method") as string;
generateCodeEnabled =
generateCodeEnabled &&
method === "GPT3 OpenAI API Key" &&
modelName.startsWith("code-");
vscode.commands.executeCommand(
"setContext",
Expand Down
26 changes: 26 additions & 0 deletions src/gemini.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* eslint-disable eqeqeq */
/* eslint-disable @typescript-eslint/naming-convention */
/**
* @author Pengfei Ni
*
* @license
* Copyright (c) 2024 - Present, Pengfei Ni
*
* All rights reserved. Code licensed under the ISC license
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*/
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import ChatGptViewProvider from "./chatgpt-view-provider";
import { ModelConfig } from "./model-config";

// initGeminiModel initializes the Gemini model with the given parameters.
export async function initGeminiModel(viewProvider: ChatGptViewProvider, config: ModelConfig) {
const gemini = createGoogleGenerativeAI({
baseURL: config.apiBaseUrl,
apiKey: config.apiKey,
});
const model = viewProvider.model ? viewProvider.model : "gemini-1.5-flash-latest";
viewProvider.apiChat = gemini("models/" + model);
}
133 changes: 25 additions & 108 deletions src/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,6 @@ import { ModelConfig } from "./model-config";

// initGptModel initializes the GPT model.
export async function initGptModel(viewProvider: ChatGptViewProvider, config: ModelConfig) {
// let tools: Tool[] = [
// new WikipediaQueryRun({
// topKResults: 3,
// maxDocContentLength: 4000,
// })
// ];
// if (config.googleCSEApiKey != "" && config.googleCSEId != "") {
// tools.push(new GoogleCustomSearch({
// apiKey: config.googleCSEApiKey,
// googleCSEId: config.googleCSEId,
// }));
// }
// if (config.serperKey != "") {
// tools.push(new Serper(config.serperKey));
// }
// if (config.bingKey != "") {
// tools.push(new BingSerpAPI(config.bingKey));
// }

// let embeddings = new OpenAIEmbeddings({
// modelName: "text-embedding-ada-002",
// openAIApiKey: config.apiKey,
// });

// AzureOpenAI
if (config.apiBaseUrl?.includes("azure")) {
const instanceName = config.apiBaseUrl.split(".")[0].split("//")[1];
Expand All @@ -54,27 +30,6 @@ export async function initGptModel(viewProvider: ChatGptViewProvider, config: Mo
apiKey: config.apiKey,
});
viewProvider.apiChat = azure.chat(deployName);

// embeddings = new OpenAIEmbeddings({
// azureOpenAIApiEmbeddingsDeploymentName: "text-embedding-ada-002",
// azureOpenAIApiKey: config.apiKey,
// azureOpenAIApiInstanceName: instanceName,
// azureOpenAIApiDeploymentName: deployName,
// azureOpenAIApiCompletionsDeploymentName: deployName,
// azureOpenAIApiVersion: "2024-02-01",
// });
// viewProvider.apiChat = new ChatOpenAI({
// modelName: viewProvider.model,
// azureOpenAIApiKey: config.apiKey,
// azureOpenAIApiInstanceName: instanceName,
// azureOpenAIApiDeploymentName: deployName,
// azureOpenAIApiCompletionsDeploymentName: deployName,
// azureOpenAIApiVersion: "2024-02-01",
// maxTokens: config.maxTokens,
// streaming: true,
// temperature: config.temperature,
// topP: config.topP,
// });
} else {
// OpenAI
const openai = createOpenAI({
Expand All @@ -83,50 +38,7 @@ export async function initGptModel(viewProvider: ChatGptViewProvider, config: Mo
organization: config.organization,
});
viewProvider.apiChat = openai.chat(viewProvider.model ? viewProvider.model : "gpt-4o");

// viewProvider.apiChat = new ChatOpenAI({
// openAIApiKey: config.apiKey,
// modelName: viewProvider.model,
// maxTokens: config.maxTokens,
// streaming: true,
// temperature: config.temperature,
// topP: config.topP,
// configuration: {
// apiKey: config.apiKey,
// baseURL: config.apiBaseUrl,
// organization: config.organization,
// },
// });
}

// if (config.apiBaseUrl == "https://api.openai.com/v1") {
// tools.push(new WebBrowser({
// model: viewProvider.apiChat,
// embeddings: embeddings,
// }));
// }


// const chatPrompt = ChatPromptTemplatePackage.fromMessages([
// SystemMessagePromptTemplate.fromTemplate(systemContext),
// new MessagesPlaceholder("chat_history"),
// HumanMessagePromptTemplate.fromTemplate("{input}"),
// new MessagesPlaceholder("agent_scratchpad"),
// ]);
// const agent = await createOpenAIFunctionsAgent({
// llm: viewProvider.apiChat,
// tools: tools,
// prompt: chatPrompt,
// });

// const agentExecutor = new AgentExecutor({ agent, tools });
// viewProvider.tools = tools;
// viewProvider.chain = new RunnableWithMessageHistory({
// runnable: agentExecutor,
// getMessageHistory: (_sessionId) => config.messageHistory,
// inputMessagesKey: "input",
// historyMessagesKey: "chat_history",
// });
}

// chatGpt is a function that completes the chat.
Expand All @@ -135,26 +47,31 @@ export async function chatGpt(provider: ChatGptViewProvider, question: string, u
throw new Error("apiChat is undefined");
}

logger.appendLine(`INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question}`);
provider.chatHistory.push({ role: "user", content: question });
try {
logger.appendLine(`INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question}`);
provider.chatHistory.push({ role: "user", content: question });

const chunks = [];
const result = await streamText({
system: provider.modelConfig.systemPrompt,
model: provider.apiChat,
messages: provider.chatHistory,
maxTokens: provider.modelConfig.maxTokens,
topP: provider.modelConfig.topP,
temperature: provider.modelConfig.temperature,
});
for await (const textPart of result.textStream) {
// logger.appendLine(
// `INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question} response: ${JSON.stringify(textPart, null, 2)}`
// );
updateResponse(textPart);
chunks.push(textPart);
const chunks = [];
const result = await streamText({
system: provider.modelConfig.systemPrompt,
model: provider.apiChat,
messages: provider.chatHistory,
maxTokens: provider.modelConfig.maxTokens,
topP: provider.modelConfig.topP,
temperature: provider.modelConfig.temperature,
});
for await (const textPart of result.textStream) {
// logger.appendLine(
// `INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question} response: ${JSON.stringify(textPart, null, 2)}`
// );
updateResponse(textPart);
chunks.push(textPart);
}
provider.response = chunks.join("");
provider.chatHistory.push({ role: "assistant", content: chunks.join("") });
logger.appendLine(`INFO: chatgpt.response: ${provider.response}`);
} catch (error) {
logger.appendLine(`ERROR: chatgpt.model: ${provider.model} response: ${error}`);
throw error;
}
provider.response = chunks.join("");
provider.chatHistory.push({ role: "assistant", content: chunks.join("") });
logger.appendLine(`INFO: chatgpt.response: ${provider.response}`);
}
8 changes: 8 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
"@ai-sdk/provider" "0.0.12"
"@ai-sdk/provider-utils" "1.0.2"

"@ai-sdk/google@^0.0.27":
version "0.0.27"
resolved "https://registry.yarnpkg.com/@ai-sdk/google/-/google-0.0.27.tgz#f30c41f2ace51f149edd2a1e7daa6907fb3a6775"
integrity sha512-HVkTrHq0TA0ynPZk3UBdhIoiAASQSILNS+qIpHbKISlEgxihlbcjO9qneaM8shl4LQL69eH6cM6YTgzAsx22mA==
dependencies:
"@ai-sdk/provider" "0.0.12"
"@ai-sdk/provider-utils" "1.0.2"

"@ai-sdk/[email protected]", "@ai-sdk/openai@^0.0.37":
version "0.0.37"
resolved "https://registry.npmjs.org/@ai-sdk/openai/-/openai-0.0.37.tgz"
Expand Down

0 comments on commit 3c2ee8c

Please sign in to comment.