@@ -2,51 +2,66 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
-import { BaseLLM } from '@langchain/core/language_models/llms';
-import { OpenAI } from '@langchain/openai';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { AIMessage, SystemMessage } from '@langchain/core/messages';
+import { ChatOpenAI } from '@langchain/openai';
 
 import { BaseCompleter, IBaseCompleter } from './base-completer';
+import { COMPLETION_SYSTEM_PROMPT } from '../provider';
 
 export class OpenAICompleter implements IBaseCompleter {
   constructor(options: BaseCompleter.IOptions) {
-    this._gptProvider = new OpenAI({ ...options.settings });
+    this._gptProvider = new ChatOpenAI({ ...options.settings });
   }
 
-  get provider(): BaseLLM {
+  get provider(): BaseChatModel {
     return this._gptProvider;
   }
 
+  /**
+   * Getter and setter for the initial prompt.
+   */
+  get prompt(): string {
+    return this._prompt;
+  }
+  set prompt(value: string) {
+    this._prompt = value;
+  }
+
   async fetch(
     request: CompletionHandler.IRequest,
     context: IInlineCompletionContext
   ) {
     const { text, offset: cursorOffset } = request;
     const prompt = text.slice(0, cursorOffset);
-    const suffix = text.slice(cursorOffset);
-
-    const data = {
-      prompt,
-      suffix,
-      model: this._gptProvider.model,
-      // temperature: 0,
-      // top_p: 1,
-      // max_tokens: 1024,
-      // min_tokens: 0,
-      // random_seed: 1337,
-      stop: []
-    };
+
+    const messages = [new SystemMessage(this._prompt), new AIMessage(prompt)];
 
     try {
-      const response = await this._gptProvider.completionWithRetry(data, {});
-      const items = response.choices.map((choice: any) => {
-        return { insertText: choice.message.content as string };
-      });
+      const response = await this._gptProvider.invoke(messages);
+      const items = [];
+      if (typeof response.content === 'string') {
+        items.push({
+          insertText: response.content
+        });
+      } else {
+        response.content.forEach(content => {
+          if (content.type !== 'text') {
+            return;
+          }
+          items.push({
+            insertText: content.text,
+            filterText: prompt.substring(prompt.length)
+          });
+        });
+      }
       return items;
     } catch (error) {
       console.error('Error fetching completions', error);
      return { items: [] };
     }
   }
 
-  private _gptProvider: OpenAI;
+  private _gptProvider: ChatOpenAI;
+  private _prompt: string = COMPLETION_SYSTEM_PROMPT;
 }
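
For context, the new fetch path builds a chat exchange instead of a raw completion request: a system message carries the completion instructions, an AIMessage carries the text before the cursor, and the reply's content arrives either as a plain string or as an array of typed content blocks. A minimal standalone sketch of that flow, assuming a current @langchain/openai release, an OPENAI_API_KEY in the environment, and a hypothetical systemPrompt standing in for COMPLETION_SYSTEM_PROMPT:

    import { AIMessage, SystemMessage } from '@langchain/core/messages';
    import { ChatOpenAI } from '@langchain/openai';

    // Hypothetical stand-in for COMPLETION_SYSTEM_PROMPT from '../provider'.
    const systemPrompt =
      'Complete the following code. Return only the continuation, no prose.';

    async function completePrefix(prefix: string): Promise<string[]> {
      // Reads OPENAI_API_KEY from the environment; the model name here is
      // an assumption, not something the diff pins down.
      const model = new ChatOpenAI({ model: 'gpt-4o-mini' });

      // Same message shape as the diff: system instructions first, then the
      // text before the cursor presented as an assistant turn to continue.
      const response = await model.invoke([
        new SystemMessage(systemPrompt),
        new AIMessage(prefix)
      ]);

      // A chat reply's content is a string for plain-text answers, or an
      // array of typed content blocks; keep only the text blocks.
      if (typeof response.content === 'string') {
        return [response.content];
      }
      return response.content
        .filter(block => block.type === 'text')
        .map(block => (block as { type: 'text'; text: string }).text);
    }

Presenting the prefix as an AIMessage rather than a HumanMessage plausibly nudges the model to continue the text in place instead of answering it, which suits inline completion.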