forked from microsoft/kernel-memory
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOpenAITextGenerator.cs
158 lines (137 loc) · 5.48 KB
/
OpenAITextGenerator.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Azure.Core.Pipeline;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.Diagnostics;
namespace Microsoft.KernelMemory.AI.OpenAI;
/// <summary>
/// <see cref="ITextGenerator"/> implementation backed by OpenAI. Models listed in
/// <c>s_textModels</c> are called through the Completions API; every other model name
/// is called through the Chat Completions API. Output is streamed as text chunks.
/// </summary>
public class OpenAITextGenerator : ITextGenerator
{
    // Model names routed to the (non-chat) Completions API. Matching is case-insensitive;
    // any model not listed here is treated as a chat model.
    private static readonly HashSet<string> s_textModels = new(StringComparer.OrdinalIgnoreCase)
    {
        "text-ada-001",
        "text-babbage-001",
        "text-curie-001",
        "text-davinci-001",
        "text-davinci-002",
        "text-davinci-003",
        "gpt-3.5-turbo-instruct",
    };

    private readonly ILogger<OpenAITextGenerator> _log;
    private readonly ITextTokenizer _textTokenizer;
    private readonly OpenAIClient _client;
    private readonly bool _isTextModel;
    private readonly string _model;

    /// <inheritdoc/>
    public int MaxTokenTotal { get; }

    /// <summary>
    /// Create a generator, deriving a typed logger from <paramref name="loggerFactory"/> when one is given.
    /// </summary>
    /// <param name="config">OpenAI settings: model name, max token total, API key and retry count.</param>
    /// <param name="textTokenizer">Tokenizer used by <see cref="CountTokens"/>; a default GPT tokenizer is used when null.</param>
    /// <param name="loggerFactory">Optional factory used to create the logger.</param>
    /// <param name="httpClient">Optional HTTP client used as the transport for the OpenAI client.</param>
    public OpenAITextGenerator(
        OpenAIConfig config,
        ITextTokenizer? textTokenizer = null,
        ILoggerFactory? loggerFactory = null,
        HttpClient? httpClient = null)
        : this(config, textTokenizer, loggerFactory?.CreateLogger<OpenAITextGenerator>(), httpClient)
    {
    }

    /// <summary>
    /// Create a generator for the model named in <paramref name="config"/>.
    /// </summary>
    /// <param name="config">OpenAI settings: model name, max token total, API key and retry count.</param>
    /// <param name="textTokenizer">Tokenizer used by <see cref="CountTokens"/>; when null a default GPT tokenizer is used and a warning is logged.</param>
    /// <param name="log">Optional logger; a default logger is used when null.</param>
    /// <param name="httpClient">Optional HTTP client used as the transport for the OpenAI client.</param>
    /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
    /// <exception cref="ConfigurationException">The configured model name is null, empty or whitespace.</exception>
    public OpenAITextGenerator(
        OpenAIConfig config,
        ITextTokenizer? textTokenizer = null,
        ILogger<OpenAITextGenerator>? log = null,
        HttpClient? httpClient = null)
    {
        if (config is null)
        {
            throw new ArgumentNullException(nameof(config));
        }

        this._log = log ?? DefaultLogger<OpenAITextGenerator>.Instance;

        if (textTokenizer is null)
        {
            this._log.LogWarning(
                "Tokenizer not specified, will use {TokenizerName}. The token count might be incorrect, causing unexpected errors",
                nameof(DefaultGPTTokenizer));
            textTokenizer = new DefaultGPTTokenizer();
        }

        this._textTokenizer = textTokenizer;

        // A whitespace-only name would otherwise only fail later, at request time.
        if (string.IsNullOrWhiteSpace(config.TextModel))
        {
            throw new ConfigurationException("The OpenAI model name is empty");
        }

        this._isTextModel = s_textModels.Contains(config.TextModel);
        this._model = config.TextModel;
        this.MaxTokenTotal = config.TextModelMaxTokenTotal;

        OpenAIClientOptions options = new()
        {
            // Negative retry counts from configuration are clamped to zero.
            RetryPolicy = new RetryPolicy(maxRetries: Math.Max(0, config.MaxRetries), new SequentialDelayStrategy()),
            Diagnostics =
            {
                IsTelemetryEnabled = Telemetry.IsTelemetryEnabled,
                ApplicationId = Telemetry.HttpUserAgent,
            }
        };

        if (httpClient is not null)
        {
            options.Transport = new HttpClientTransport(httpClient);
        }

        this._client = new OpenAIClient(config.APIKey, options);
    }

    /// <inheritdoc/>
    public int CountTokens(string text)
    {
        return this._textTokenizer.CountTokens(text);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<string> GenerateTextAsync(
        string prompt,
        TextGenerationOptions options,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (this._isTextModel)
        {
            var openaiOptions = new CompletionsOptions
            {
                DeploymentName = this._model,
                MaxTokens = options.MaxTokens,
                Temperature = (float)options.Temperature,
                NucleusSamplingFactor = (float)options.TopP,
                FrequencyPenalty = (float)options.FrequencyPenalty,
                PresencePenalty = (float)options.PresencePenalty,
                ChoicesPerPrompt = 1,
            };

            if (options.StopSequences is { Count: > 0 })
            {
                foreach (var s in options.StopSequences) { openaiOptions.StopSequences.Add(s); }
            }

            // The streaming response owns the underlying HTTP response: dispose it when the
            // iterator finishes, including when the consumer stops early or cancels.
            using StreamingResponse<Completions> response = await this._client.GetCompletionsStreamingAsync(openaiOptions, cancellationToken).ConfigureAwait(false);
            await foreach (Completions completions in response.EnumerateValues().WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                foreach (Choice choice in completions.Choices)
                {
                    // Skip chunks that carry no text.
                    if (!string.IsNullOrEmpty(choice.Text))
                    {
                        yield return choice.Text;
                    }
                }
            }
        }
        else
        {
            var openaiOptions = new ChatCompletionsOptions
            {
                DeploymentName = this._model,
                MaxTokens = options.MaxTokens,
                Temperature = (float)options.Temperature,
                NucleusSamplingFactor = (float)options.TopP,
                FrequencyPenalty = (float)options.FrequencyPenalty,
                PresencePenalty = (float)options.PresencePenalty,
            };

            if (options.StopSequences is { Count: > 0 })
            {
                foreach (var s in options.StopSequences) { openaiOptions.StopSequences.Add(s); }
            }

            // The whole prompt is sent as a single system message.
            openaiOptions.Messages.Add(new ChatRequestSystemMessage(prompt));

            // Dispose the streaming response (and its HTTP response) when enumeration ends.
            using StreamingResponse<StreamingChatCompletionsUpdate> response = await this._client.GetChatCompletionsStreamingAsync(openaiOptions, cancellationToken).ConfigureAwait(false);
            await foreach (StreamingChatCompletionsUpdate update in response.EnumerateValues().WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                // Metadata-only updates (e.g. role or finish reason) have no content; don't yield null.
                if (!string.IsNullOrEmpty(update.ContentUpdate))
                {
                    yield return update.ContentUpdate;
                }
            }
        }
    }
}