Add OpenAI provider #19

Merged
merged 8 commits on Feb 18, 2025
1 change: 1 addition & 0 deletions package.json
@@ -65,6 +65,7 @@
"@langchain/community": "^0.3.31",
"@langchain/core": "^0.3.40",
"@langchain/mistralai": "^0.1.1",
"@langchain/openai": "^0.4.4",
"@lumino/coreutils": "^2.1.2",
"@lumino/polling": "^2.1.2",
"@lumino/signaling": "^2.1.2",
2 changes: 1 addition & 1 deletion schema/ai-provider.json
@@ -10,7 +10,7 @@
"title": "The AI provider",
"description": "The AI provider to use for chat and completion",
"default": "None",
"enum": ["None", "Anthropic", "ChromeAI", "MistralAI"]
"enum": ["None", "Anthropic", "ChromeAI", "MistralAI", "OpenAI"]
}
},
"additionalProperties": true
44 changes: 26 additions & 18 deletions scripts/settings-generator.js
@@ -43,6 +43,11 @@ const providers = {
path: 'node_modules/@langchain/anthropic/dist/chat_models.d.ts',
type: 'AnthropicInput',
excludedProps: ['clientOptions']
-}
+},
+openAI: {
+path: 'node_modules/@langchain/openai/dist/chat_models.d.ts',
+type: 'ChatOpenAIFields',
+excludedProps: ['configuration']
+}
};

@@ -53,7 +58,8 @@ Object.entries(providers).forEach(([name, desc], index) => {
path: desc.path,
tsconfig: './tsconfig.json',
type: desc.type,
-functions: 'hide'
+functions: 'hide',
+topRef: false
};

const outputPath = path.join(outputDir, `${name}.json`);
@@ -81,33 +87,35 @@ Object.entries(providers).forEach(([name, desc], index) => {
}

// Remove the properties from extended class.
-const providerKeys = Object.keys(schema.definitions[desc.type]['properties']);
+const providerKeys = Object.keys(schema.properties);
Object.keys(
schemaBase.definitions?.['BaseLanguageModelParams']['properties']
).forEach(key => {
if (providerKeys.includes(key)) {
-delete schema.definitions?.[desc.type]['properties'][key];
+delete schema.properties?.[key];
}
});

-// Remove the useless definitions.
-let change = true;
-while (change) {
-change = false;
-const temporarySchemaString = JSON.stringify(schema);
-
-Object.keys(schema.definitions).forEach(key => {
-const index = temporarySchemaString.indexOf(`#/definitions/${key}`);
-if (index === -1) {
-delete schema.definitions?.[key];
-change = true;
-}
-});
-}
+// Replace all the references with their values, and remove the now-useless definitions.
+const defKeys = Object.keys(schema.definitions);
+for (let i = defKeys.length - 1; i >= 0; i--) {
+let schemaString = JSON.stringify(schema);
+const key = defKeys[i];
+const reference = `"$ref":"#/definitions/${key}"`;
+
+// Replace all the references to the definition with its content (after removing the braces).
+const replacement = JSON.stringify(schema.definitions?.[key]).slice(1, -1);
+const temporarySchemaString = schemaString.replaceAll(reference, replacement);
+// Rebuild the schema from the string representation if it changed.
+if (schemaString !== temporarySchemaString) {
+schema = JSON.parse(temporarySchemaString);
+}
+// Remove the definition.
+delete schema.definitions?.[key];
+}

// Transform the default values.
-Object.values(schema.definitions[desc.type]['properties']).forEach(value => {
+Object.values(schema.properties).forEach(value => {
const defaultValue = value.default;
if (!defaultValue) {
return;
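For context on the new loop above: it walks the generated definitions, splices the body of each one into every place a `#/definitions/...` reference points at it, and then drops the definition. A minimal standalone sketch of that string-replacement step, using a made-up toy schema (the names below are illustrative only, not taken from the generator's real output):

// Toy schema invented for illustration.
let schema: any = {
  definitions: { Temperature: { type: 'number', minimum: 0 } },
  properties: { temperature: { $ref: '#/definitions/Temperature' } }
};

const key = 'Temperature';
const reference = `"$ref":"#/definitions/${key}"`;
// Strip the outer braces so the definition body can replace the $ref object in place.
const replacement = JSON.stringify(schema.definitions[key]).slice(1, -1);
const inlined = JSON.stringify(schema).replaceAll(reference, replacement);
if (inlined !== JSON.stringify(schema)) {
  schema = JSON.parse(inlined);
}
delete schema.definitions[key];

// schema.properties.temperature is now { type: 'number', minimum: 0 }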
67 changes: 67 additions & 0 deletions src/llm-models/openai-completer.ts
@@ -0,0 +1,67 @@
import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AIMessage, SystemMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';

import { BaseCompleter, IBaseCompleter } from './base-completer';
import { COMPLETION_SYSTEM_PROMPT } from '../provider';

export class OpenAICompleter implements IBaseCompleter {
constructor(options: BaseCompleter.IOptions) {
this._openAIProvider = new ChatOpenAI({ ...options.settings });
}

get provider(): BaseChatModel {
return this._openAIProvider;
}

/**
* Getter and setter for the initial prompt.
*/
get prompt(): string {
return this._prompt;
}
set prompt(value: string) {
this._prompt = value;
}

async fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
) {
const { text, offset: cursorOffset } = request;
const prompt = text.slice(0, cursorOffset);

const messages = [new SystemMessage(this._prompt), new AIMessage(prompt)];

try {
const response = await this._openAIProvider.invoke(messages);
const items = [];
if (typeof response.content === 'string') {
items.push({
insertText: response.content
});
} else {
response.content.forEach(content => {
if (content.type !== 'text') {
return;
}
items.push({
insertText: content.text,
filterText: prompt.substring(prompt.length)
});
});
}
return { items };
} catch (error) {
console.error('Error fetching completions', error);
return { items: [] };
}
}

private _openAIProvider: ChatOpenAI;
private _prompt: string = COMPLETION_SYSTEM_PROMPT;
}
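A hypothetical usage sketch (not part of the PR) to make the flow above concrete: it instantiates the completer directly and requests a completion for the text before the cursor. It assumes `BaseCompleter.IOptions` only needs the settings object; the model name, API key and the stubbed completion context are placeholders, and the settings accept whatever `ChatOpenAI` accepts.

import {
  CompletionHandler,
  IInlineCompletionContext
} from '@jupyterlab/completer';

import { OpenAICompleter } from './openai-completer';

async function demo(): Promise<void> {
  // Placeholder settings; any ChatOpenAI option could be passed here.
  const completer = new OpenAICompleter({
    settings: { model: 'gpt-4o-mini', apiKey: '<your-api-key>' }
  });

  // The request carries the editor text and the cursor offset.
  const request: CompletionHandler.IRequest = {
    text: 'def fibonacci(n):\n    ',
    offset: 22
  };
  // Minimal stub; a real context comes from the completion manager.
  const context = { triggerKind: 0 } as unknown as IInlineCompletionContext;

  const reply = await completer.fetch(request, context);
  for (const item of reply.items) {
    console.log(item.insertText);
  }
}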
20 changes: 15 additions & 5 deletions src/llm-models/utils.ts
@@ -2,17 +2,19 @@ import { ChatAnthropic } from '@langchain/anthropic';
import { ChromeAI } from '@langchain/community/experimental/llms/chrome_ai';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatMistralAI } from '@langchain/mistralai';
-import { JSONObject } from '@lumino/coreutils';
+import { ChatOpenAI } from '@langchain/openai';

import { IBaseCompleter } from './base-completer';
import { AnthropicCompleter } from './anthropic-completer';
import { CodestralCompleter } from './codestral-completer';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
import { ChromeCompleter } from './chrome-completer';
+import { OpenAICompleter } from './openai-completer';

import chromeAI from '../_provider-settings/chromeAI.json';
import mistralAI from '../_provider-settings/mistralAI.json';
import anthropic from '../_provider-settings/anthropic.json';
+import openAI from '../_provider-settings/openAI.json';

/**
* Get an LLM completer from the name.
@@ -27,6 +29,8 @@ export function getCompleter(
return new AnthropicCompleter({ settings });
} else if (name === 'ChromeAI') {
return new ChromeCompleter({ settings });
+} else if (name === 'OpenAI') {
+return new OpenAICompleter({ settings });
}
return null;
}
@@ -46,6 +50,8 @@ export function getChatModel(
// TODO: fix
// @ts-expect-error: missing properties
return new ChromeAI({ ...settings });
+} else if (name === 'OpenAI') {
+return new ChatOpenAI({ ...settings });
}
return null;
}
@@ -60,20 +66,24 @@ export function getErrorMessage(name: string, error: any): string {
return error.error.error.message;
} else if (name === 'ChromeAI') {
return error.message;
+} else if (name === 'OpenAI') {
+return error.message;
}
return 'Unknown provider';
}

/*
* Get the settings of a provider from its name.
*/
-export function getSettings(name: string): JSONObject | null {
+export function getSettings(name: string): any {
Review comment (Member):
Maybe this any is fine for now, but we should maybe check if we can put JSONObject | null back as the return type.

if (name === 'MistralAI') {
-return mistralAI.definitions.ChatMistralAIInput.properties;
+return mistralAI.properties;
} else if (name === 'Anthropic') {
-return anthropic.definitions.AnthropicInput.properties;
+return anthropic.properties;
} else if (name === 'ChromeAI') {
-return chromeAI.definitions.ChromeAIInputs.properties;
+return chromeAI.properties;
+} else if (name === 'OpenAI') {
+return openAI.properties;
}

return null;
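A possible follow-up to the review comment above, sketched under the assumption that the generated *.json settings modules contain plain JSON data (an explicit cast, or typing the generated files, may be needed for this to compile): keep a JSON return type instead of any.

import { JSONObject } from '@lumino/coreutils';

// Same generated settings files as imported at the top of utils.ts.
import chromeAI from '../_provider-settings/chromeAI.json';
import mistralAI from '../_provider-settings/mistralAI.json';
import anthropic from '../_provider-settings/anthropic.json';
import openAI from '../_provider-settings/openAI.json';

export function getSettings(name: string): JSONObject | null {
  switch (name) {
    case 'MistralAI':
      return mistralAI.properties as JSONObject;
    case 'Anthropic':
      return anthropic.properties as JSONObject;
    case 'ChromeAI':
      return chromeAI.properties as JSONObject;
    case 'OpenAI':
      return openAI.properties as JSONObject;
    default:
      return null;
  }
}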
3 changes: 2 additions & 1 deletion yarn.lock
@@ -1969,6 +1969,7 @@ __metadata:
"@langchain/community": ^0.3.31
"@langchain/core": ^0.3.40
"@langchain/mistralai": ^0.1.1
"@langchain/openai": ^0.4.4
"@lumino/coreutils": ^2.1.2
"@lumino/polling": ^2.1.2
"@lumino/signaling": ^2.1.2
@@ -2449,7 +2450,7 @@
languageName: node
linkType: hard

"@langchain/openai@npm:>=0.2.0 <0.5.0":
"@langchain/openai@npm:>=0.2.0 <0.5.0, @langchain/openai@npm:^0.4.4":
version: 0.4.4
resolution: "@langchain/openai@npm:0.4.4"
dependencies: