From 3c2ee8cae4474b35c46ee00216aab1aa561c90d5 Mon Sep 17 00:00:00 2001 From: Pengfei Ni Date: Sun, 21 Jul 2024 23:38:19 +0800 Subject: [PATCH] Add Google Generative AI models --- package.json | 28 +++----- src/chatgpt-view-provider.ts | 29 +++++--- src/extension.ts | 7 +- src/gemini.ts | 26 +++++++ src/openai.ts | 133 +++++++---------------------------- yarn.lock | 8 +++ 6 files changed, 89 insertions(+), 142 deletions(-) create mode 100644 src/gemini.ts diff --git a/package.json b/package.json index e1c7d1a..d635721 100644 --- a/package.json +++ b/package.json @@ -244,24 +244,6 @@ "configuration": { "title": "ChatGPT", "properties": { - "chatgpt.method": { - "type": "string", - "enum": [ - "GPT3 OpenAI API Key", - "Claude 3" - ], - "default": "GPT3 OpenAI API Key", - "markdownDescription": "Choose your integration preference.", - "order": 1, - "enumItemLabels": [ - "Use OpenAI API key integration", - "Use Claude 3 integration" - ], - "markdownEnumDescriptions": [ - "Various chat & text completion models are supported including OpenAI and OpenAI API compatible local models.", - "Claude 3 models from Anthropic" - ] - }, "chatgpt.systemPrompt": { "type": "string", "default": "", @@ -424,6 +406,9 @@ "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620", + "gemini-1.5-flash-latest", + "gemini-1.5-pro-latest", + "gemini-1.0-pro-latest", "custom" ], "default": "gpt-3.5-turbo", @@ -459,6 +444,9 @@ "Claude 3 - claude-3-sonnet-20240229", "Claude 3 - claude-3-haiku-20240307", "Claude 3 - claude-3-5-sonnet-20240620", + "Google Gemini - gemini-1.5-flash-latest", + "Google Gemini - gemini-1.5-pro-latest", + "Google Gemini - gemini-1.0-pro-latest", "Custom Model" ], "markdownEnumDescriptions": [ @@ -491,6 +479,9 @@ "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620", + "gemini-1.5-flash-latest", + "gemini-1.5-pro-latest", + "gemini-1.0-pro-latest", "Custom model name set by `chatgpt.gpt3.customModel`" ] }, @@ -572,6 +563,7 @@ "dependencies": { "@ai-sdk/anthropic": "^0.0.30", "@ai-sdk/azure": "^0.0.14", + "@ai-sdk/google": "^0.0.27", "@ai-sdk/openai": "^0.0.37", "@types/minimatch": "^5.1.2", "ai": "^3.2.32", diff --git a/src/chatgpt-view-provider.ts b/src/chatgpt-view-provider.ts index 471544e..1c9a0cc 100644 --- a/src/chatgpt-view-provider.ts +++ b/src/chatgpt-view-provider.ts @@ -19,6 +19,7 @@ import { CoreMessage } from "ai"; import delay from "delay"; import * as vscode from "vscode"; import { initClaudeModel } from "./anthropic"; +import { initGeminiModel } from "./gemini"; import { ModelConfig } from "./model-config"; import { chatGpt, initGptModel } from "./openai"; import { chatCompletion, initGptLegacyModel } from "./openai-legacy"; @@ -216,13 +217,17 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider { } private get isGpt35Model(): boolean { - return !this.isCodexModel && !this.isClaude; + return !this.isCodexModel && !this.isClaude && !this.isGemini; } private get isClaude(): boolean { return !!this.model?.startsWith("claude-"); } + private get isGemini(): boolean { + return !!this.model?.startsWith("gemini-"); + } + public async prepareConversation(modelChanged = false): Promise { this.conversationId = this.conversationId || this.getRandomId(); const state = this.context.globalState; @@ -235,7 +240,8 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider { if ( (this.isGpt35Model && !this.apiChat) || (this.isClaude && !this.apiChat) || - (!this.isGpt35Model && !this.isClaude && !this.apiCompletion) || + (this.isGemini && !this.apiChat) || + (!this.isGpt35Model && !this.isClaude && !this.isGemini && !this.apiCompletion) || modelChanged ) { let apiKey = @@ -252,11 +258,14 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider { } let apiBaseUrl = configuration.get("gpt3.apiBaseUrl") as string; - if (!apiBaseUrl) { - if (this.isGpt35Model) { - apiBaseUrl = "https://api.openai.com/v1"; - } else if (this.isClaude) { - apiBaseUrl = "https://api.anthropic.com"; + if (!apiBaseUrl && this.isGpt35Model) { + apiBaseUrl = "https://api.openai.com/v1"; + } + if (!apiBaseUrl || apiBaseUrl == "https://api.openai.com/v1") { + if (this.isClaude) { + apiBaseUrl = "https://api.anthropic.com/v1"; + } else if (this.isGemini) { + apiBaseUrl = "https://generativelanguage.googleapis.com/v1beta"; } } @@ -314,6 +323,8 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider { await initGptModel(this, this.modelConfig); } else if (this.isClaude) { await initClaudeModel(this, this.modelConfig); + } else if (this.isGemini) { + await initGeminiModel(this, this.modelConfig); } else { initGptLegacyModel(this, this.modelConfig); } @@ -400,9 +411,7 @@ export default class ChatGptViewProvider implements vscode.WebviewViewProvider { }); }; try { - if (this.isGpt35Model) { - await chatGpt(this, question, updateResponse); - } else if (this.isClaude) { + if (this.isGpt35Model || this.isClaude || this.isGemini) { await chatGpt(this, question, updateResponse); } else { await chatCompletion(this, question, updateResponse); diff --git a/src/extension.ts b/src/extension.ts index 7093278..9d23a8b 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -127,8 +127,7 @@ export async function activate(context: vscode.ExtensionContext) { if ( e.affectsConfiguration("chatgpt.promptPrefix") || e.affectsConfiguration("chatgpt.gpt3.generateCode-enabled") || - e.affectsConfiguration("chatgpt.gpt3.model") || - e.affectsConfiguration("chatgpt.method") + e.affectsConfiguration("chatgpt.gpt3.model") ) { setContext(); } @@ -243,12 +242,8 @@ export async function activate(context: vscode.ExtensionContext) { const modelName = vscode.workspace .getConfiguration("chatgpt") .get("gpt3.model") as string; - const method = vscode.workspace - .getConfiguration("chatgpt") - .get("method") as string; generateCodeEnabled = generateCodeEnabled && - method === "GPT3 OpenAI API Key" && modelName.startsWith("code-"); vscode.commands.executeCommand( "setContext", diff --git a/src/gemini.ts b/src/gemini.ts new file mode 100644 index 0000000..c029cb2 --- /dev/null +++ b/src/gemini.ts @@ -0,0 +1,26 @@ +/* eslint-disable eqeqeq */ +/* eslint-disable @typescript-eslint/naming-convention */ +/** + * @author Pengfei Ni + * + * @license + * Copyright (c) 2024 - Present, Pengfei Ni + * + * All rights reserved. Code licensed under the ISC license + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. +*/ +import { createGoogleGenerativeAI } from '@ai-sdk/google'; +import ChatGptViewProvider from "./chatgpt-view-provider"; +import { ModelConfig } from "./model-config"; + +// initGeminiModel initializes the Gemini model with the given parameters. +export async function initGeminiModel(viewProvider: ChatGptViewProvider, config: ModelConfig) { + const gemini = createGoogleGenerativeAI({ + baseURL: config.apiBaseUrl, + apiKey: config.apiKey, + }); + const model = viewProvider.model ? viewProvider.model : "gemini-1.5-flash-latest"; + viewProvider.apiChat = gemini("models/" + model); +} diff --git a/src/openai.ts b/src/openai.ts index d6dfc47..927907f 100644 --- a/src/openai.ts +++ b/src/openai.ts @@ -19,30 +19,6 @@ import { ModelConfig } from "./model-config"; // initGptModel initializes the GPT model. export async function initGptModel(viewProvider: ChatGptViewProvider, config: ModelConfig) { - // let tools: Tool[] = [ - // new WikipediaQueryRun({ - // topKResults: 3, - // maxDocContentLength: 4000, - // }) - // ]; - // if (config.googleCSEApiKey != "" && config.googleCSEId != "") { - // tools.push(new GoogleCustomSearch({ - // apiKey: config.googleCSEApiKey, - // googleCSEId: config.googleCSEId, - // })); - // } - // if (config.serperKey != "") { - // tools.push(new Serper(config.serperKey)); - // } - // if (config.bingKey != "") { - // tools.push(new BingSerpAPI(config.bingKey)); - // } - - // let embeddings = new OpenAIEmbeddings({ - // modelName: "text-embedding-ada-002", - // openAIApiKey: config.apiKey, - // }); - // AzureOpenAI if (config.apiBaseUrl?.includes("azure")) { const instanceName = config.apiBaseUrl.split(".")[0].split("//")[1]; @@ -54,27 +30,6 @@ export async function initGptModel(viewProvider: ChatGptViewProvider, config: Mo apiKey: config.apiKey, }); viewProvider.apiChat = azure.chat(deployName); - - // embeddings = new OpenAIEmbeddings({ - // azureOpenAIApiEmbeddingsDeploymentName: "text-embedding-ada-002", - // azureOpenAIApiKey: config.apiKey, - // azureOpenAIApiInstanceName: instanceName, - // azureOpenAIApiDeploymentName: deployName, - // azureOpenAIApiCompletionsDeploymentName: deployName, - // azureOpenAIApiVersion: "2024-02-01", - // }); - // viewProvider.apiChat = new ChatOpenAI({ - // modelName: viewProvider.model, - // azureOpenAIApiKey: config.apiKey, - // azureOpenAIApiInstanceName: instanceName, - // azureOpenAIApiDeploymentName: deployName, - // azureOpenAIApiCompletionsDeploymentName: deployName, - // azureOpenAIApiVersion: "2024-02-01", - // maxTokens: config.maxTokens, - // streaming: true, - // temperature: config.temperature, - // topP: config.topP, - // }); } else { // OpenAI const openai = createOpenAI({ @@ -83,50 +38,7 @@ export async function initGptModel(viewProvider: ChatGptViewProvider, config: Mo organization: config.organization, }); viewProvider.apiChat = openai.chat(viewProvider.model ? viewProvider.model : "gpt-4o"); - - // viewProvider.apiChat = new ChatOpenAI({ - // openAIApiKey: config.apiKey, - // modelName: viewProvider.model, - // maxTokens: config.maxTokens, - // streaming: true, - // temperature: config.temperature, - // topP: config.topP, - // configuration: { - // apiKey: config.apiKey, - // baseURL: config.apiBaseUrl, - // organization: config.organization, - // }, - // }); } - - // if (config.apiBaseUrl == "https://api.openai.com/v1") { - // tools.push(new WebBrowser({ - // model: viewProvider.apiChat, - // embeddings: embeddings, - // })); - // } - - - // const chatPrompt = ChatPromptTemplatePackage.fromMessages([ - // SystemMessagePromptTemplate.fromTemplate(systemContext), - // new MessagesPlaceholder("chat_history"), - // HumanMessagePromptTemplate.fromTemplate("{input}"), - // new MessagesPlaceholder("agent_scratchpad"), - // ]); - // const agent = await createOpenAIFunctionsAgent({ - // llm: viewProvider.apiChat, - // tools: tools, - // prompt: chatPrompt, - // }); - - // const agentExecutor = new AgentExecutor({ agent, tools }); - // viewProvider.tools = tools; - // viewProvider.chain = new RunnableWithMessageHistory({ - // runnable: agentExecutor, - // getMessageHistory: (_sessionId) => config.messageHistory, - // inputMessagesKey: "input", - // historyMessagesKey: "chat_history", - // }); } // chatGpt is a function that completes the chat. @@ -135,26 +47,31 @@ export async function chatGpt(provider: ChatGptViewProvider, question: string, u throw new Error("apiChat is undefined"); } - logger.appendLine(`INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question}`); - provider.chatHistory.push({ role: "user", content: question }); + try { + logger.appendLine(`INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question}`); + provider.chatHistory.push({ role: "user", content: question }); - const chunks = []; - const result = await streamText({ - system: provider.modelConfig.systemPrompt, - model: provider.apiChat, - messages: provider.chatHistory, - maxTokens: provider.modelConfig.maxTokens, - topP: provider.modelConfig.topP, - temperature: provider.modelConfig.temperature, - }); - for await (const textPart of result.textStream) { - // logger.appendLine( - // `INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question} response: ${JSON.stringify(textPart, null, 2)}` - // ); - updateResponse(textPart); - chunks.push(textPart); + const chunks = []; + const result = await streamText({ + system: provider.modelConfig.systemPrompt, + model: provider.apiChat, + messages: provider.chatHistory, + maxTokens: provider.modelConfig.maxTokens, + topP: provider.modelConfig.topP, + temperature: provider.modelConfig.temperature, + }); + for await (const textPart of result.textStream) { + // logger.appendLine( + // `INFO: chatgpt.model: ${provider.model} chatgpt.question: ${question} response: ${JSON.stringify(textPart, null, 2)}` + // ); + updateResponse(textPart); + chunks.push(textPart); + } + provider.response = chunks.join(""); + provider.chatHistory.push({ role: "assistant", content: chunks.join("") }); + logger.appendLine(`INFO: chatgpt.response: ${provider.response}`); + } catch (error) { + logger.appendLine(`ERROR: chatgpt.model: ${provider.model} response: ${error}`); + throw error; } - provider.response = chunks.join(""); - provider.chatHistory.push({ role: "assistant", content: chunks.join("") }); - logger.appendLine(`INFO: chatgpt.response: ${provider.response}`); } diff --git a/yarn.lock b/yarn.lock index 908db2f..c0f48a1 100644 --- a/yarn.lock +++ b/yarn.lock @@ -19,6 +19,14 @@ "@ai-sdk/provider" "0.0.12" "@ai-sdk/provider-utils" "1.0.2" +"@ai-sdk/google@^0.0.27": + version "0.0.27" + resolved "https://registry.yarnpkg.com/@ai-sdk/google/-/google-0.0.27.tgz#f30c41f2ace51f149edd2a1e7daa6907fb3a6775" + integrity sha512-HVkTrHq0TA0ynPZk3UBdhIoiAASQSILNS+qIpHbKISlEgxihlbcjO9qneaM8shl4LQL69eH6cM6YTgzAsx22mA== + dependencies: + "@ai-sdk/provider" "0.0.12" + "@ai-sdk/provider-utils" "1.0.2" + "@ai-sdk/openai@0.0.37", "@ai-sdk/openai@^0.0.37": version "0.0.37" resolved "https://registry.npmjs.org/@ai-sdk/openai/-/openai-0.0.37.tgz"