Skip to content

Commit

Permalink
refactor: remove duplicate logic for DefaultChatApi and AzureChatApi
Browse files Browse the repository at this point in the history
  • Loading branch information
hanrw committed Aug 26, 2024
1 parent ebbad1c commit 7207253
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,41 @@ import com.tddworks.common.network.api.ktor.internal.*
import com.tddworks.di.createJson
import com.tddworks.di.getInstance
import com.tddworks.openai.api.chat.api.Chat
import com.tddworks.openai.api.chat.internal.DefaultChatApi
import com.tddworks.openai.api.chat.internal.default
import com.tddworks.openai.api.images.api.Images
import com.tddworks.openai.api.images.internal.DefaultImagesApi
import com.tddworks.openai.api.images.internal.default
import com.tddworks.openai.api.legacy.completions.api.Completions
import com.tddworks.openai.api.legacy.completions.api.internal.DefaultCompletionsApi
import com.tddworks.openai.api.legacy.completions.api.internal.default

interface OpenAI : Chat, Images, Completions {
companion object {
const val BASE_URL = "https://api.openai.com"

fun create(config: OpenAIConfig): OpenAI {
fun default(config: OpenAIConfig): OpenAI {
val requester = HttpRequester.default(
createHttpClient(
connectionConfig = UrlBasedConnectionConfig(config.baseUrl),
authConfig = AuthConfig(config.apiKey),
features = ClientFeatures(json = createJson())
)
)
return create(requester)
return default(requester)
}

fun create(
fun default(
requester: HttpRequester,
chatCompletionPath: String = Chat.CHAT_COMPLETIONS_PATH
): OpenAI {
val chatApi = DefaultChatApi(
val chatApi = Chat.default(
requester = requester,
chatCompletionPath = chatCompletionPath
)

val imagesApi = DefaultImagesApi(
val imagesApi = Images.default(
requester = requester
)

val completionsApi = DefaultCompletionsApi(
val completionsApi = Completions.default(
requester = requester
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@ import kotlinx.serialization.ExperimentalSerializationApi
* @property requester The HttpRequester to use for performing HTTP requests.
*/
@OptIn(ExperimentalSerializationApi::class)
class DefaultChatApi(
internal class DefaultChatApi(
private val requester: HttpRequester,
private val chatCompletionPath: String = CHAT_COMPLETIONS_PATH
private val chatCompletionPath: String = CHAT_COMPLETIONS_PATH,
private val extraHeaders: Map<String, String> = mapOf()
) : Chat {
override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
return requester.performRequest<ChatCompletion> {
method = HttpMethod.Post
url(path = chatCompletionPath)
setBody(request)
contentType(ContentType.Application.Json)
headers {
extraHeaders.forEach { (key, value) -> append(key, value) }
}
}
}

Expand All @@ -42,8 +46,14 @@ class DefaultChatApi(
headers {
append(HttpHeaders.CacheControl, "no-cache")
append(HttpHeaders.Connection, "keep-alive")
extraHeaders.forEach { (key, value) -> append(key, value) }
}
}
}
}

}
fun Chat.Companion.default(
requester: HttpRequester,
chatCompletionPath: String = CHAT_COMPLETIONS_PATH,
extraHeaders: Map<String, String> = mapOf()
): Chat = DefaultChatApi(requester, chatCompletionPath, extraHeaders)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import io.ktor.client.request.*
import io.ktor.http.*
import kotlinx.serialization.ExperimentalSerializationApi

class DefaultImagesApi(private val requester: HttpRequester) : Images {
internal class DefaultImagesApi(private val requester: HttpRequester) : Images {
@OptIn(ExperimentalSerializationApi::class)
override suspend fun generate(request: ImageCreate): ListResponse<Image> {
return requester.performRequest<ListResponse<Image>> {
Expand All @@ -20,4 +20,7 @@ class DefaultImagesApi(private val requester: HttpRequester) : Images {
contentType(ContentType.Application.Json)
}
}
}
}

fun Images.Companion.default(requester: HttpRequester): Images =
DefaultImagesApi(requester)
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ package com.tddworks.openai.api.legacy.completions.api
*/
interface Completions {
suspend fun completions(request: CompletionRequest): Completion
companion object
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import com.tddworks.openai.api.legacy.completions.api.Completions
import io.ktor.client.request.*
import io.ktor.http.*

class DefaultCompletionsApi(
internal class DefaultCompletionsApi(
private val requester: HttpRequester,
) : Completions {
override suspend fun completions(request: CompletionRequest): Completion {
Expand All @@ -23,4 +23,7 @@ class DefaultCompletionsApi(
companion object {
const val COMPLETIONS_PATH = "/v1/completions"
}
}
}

fun Completions.Companion.default(requester: HttpRequester): Completions =
DefaultCompletionsApi(requester)
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class OpenAITest {

@Test
fun `should create openai instance`() {
val openAI = OpenAI.create(OpenAIConfig())
val openAI = OpenAI.default(OpenAIConfig())

assertNotNull(openAI)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,20 @@
package com.tddworks.azure.api

import com.tddworks.azure.api.internal.AzureChatApi
import com.tddworks.azure.api.internal.azure
import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.api.performRequest
import com.tddworks.common.network.api.ktor.api.streamRequest
import com.tddworks.common.network.api.ktor.internal.ClientFeatures
import com.tddworks.common.network.api.ktor.internal.UrlBasedConnectionConfig
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.common.network.api.ktor.internal.default
import com.tddworks.di.createJson
import com.tddworks.openai.api.OpenAI
import com.tddworks.openai.api.chat.api.Chat
import com.tddworks.openai.api.chat.api.Chat.Companion.CHAT_COMPLETIONS_PATH
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.images.api.Images
import com.tddworks.openai.api.images.internal.DefaultImagesApi
import com.tddworks.openai.api.images.internal.default
import com.tddworks.openai.api.legacy.completions.api.Completions
import com.tddworks.openai.api.legacy.completions.api.internal.DefaultCompletionsApi
import com.tddworks.openai.api.legacy.completions.api.internal.default
import com.tddworks.openai.gateway.api.OpenAIProviderConfig
import io.ktor.client.request.*
import io.ktor.http.*
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.ExperimentalSerializationApi

data class AzureAIProviderConfig(
override val apiKey: () -> String,
Expand All @@ -43,6 +35,7 @@ fun OpenAIProviderConfig.Companion.azure(
apiVersion = apiVersion
)


/**
* Authentication
* Azure OpenAI provides two methods for authentication. You can use either API Keys or Microsoft Entra ID.
Expand All @@ -63,69 +56,21 @@ fun OpenAI.Companion.azure(config: AzureAIProviderConfig): OpenAI {
)
)
)
return azure(
config = config,
requester = requester,
chatCompletionPath = "chat/completions"
)
}

fun azure(
config: AzureAIProviderConfig,
requester: HttpRequester,
chatCompletionPath: String
): OpenAI {
val chatApi = AzureChatApi(
config = config,
val chatApi = Chat.azure(
apiKey = config.apiKey,
requester = requester,
chatCompletionPath = chatCompletionPath
chatCompletionPath = AzureChatApi.CHAT_COMPLETIONS
)

val imagesApi = DefaultImagesApi(
//TODO implement the rest of the APIs for Azure Images.azure
val imagesApi = Images.default(
requester = requester
)

val completionsApi = DefaultCompletionsApi(
//TODO implement the rest of the APIs for Azure Completions.azure
val completionsApi = Completions.default(
requester = requester
)

return object : OpenAI, Chat by chatApi, Images by imagesApi,
Completions by completionsApi {}
}

@OptIn(ExperimentalSerializationApi::class)
class AzureChatApi(
private val config: AzureAIProviderConfig,
private val requester: HttpRequester,
private val chatCompletionPath: String = CHAT_COMPLETIONS_PATH
) : Chat {

companion object {
const val BASE_URL = "https://YOUR_RESOURCE_NAME.openai.azure.com"
}

override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
return requester.performRequest<ChatCompletion> {
method = HttpMethod.Post
url(path = chatCompletionPath)
setBody(request)
contentType(ContentType.Application.Json)
}
}

override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
return requester.streamRequest<ChatCompletionChunk> {
method = HttpMethod.Post
url(path = chatCompletionPath)
setBody(request.copy(stream = true))
contentType(ContentType.Application.Json)
accept(ContentType.Text.EventStream)
headers {
append("api-key", config.apiKey())
append(HttpHeaders.CacheControl, "no-cache")
append(HttpHeaders.Connection, "keep-alive")
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.tddworks.azure.api.internal

import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.openai.api.chat.api.Chat
import com.tddworks.openai.api.chat.internal.default

internal class AzureChatApi(
private val chatCompletionPath: String,
private val requester: HttpRequester,
private val extraHeaders: Map<String, String> = mapOf(),
private val chatApi: Chat = Chat.default(
requester = requester,
chatCompletionPath = chatCompletionPath,
extraHeaders = extraHeaders
)
) : Chat by chatApi {
companion object {
const val BASE_URL = "https://YOUR_RESOURCE_NAME.openai.azure.com"
const val CHAT_COMPLETIONS = "chat/completions"
}
}

fun Chat.Companion.azure(
apiKey: () -> String,
requester: HttpRequester,
chatCompletionPath: String = AzureChatApi.CHAT_COMPLETIONS,
): Chat = AzureChatApi(
requester = requester,
chatCompletionPath = chatCompletionPath,
extraHeaders = mapOf("api-key" to apiKey())
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

package com.tddworks.openai.gateway.api.internal

import com.tddworks.azure.api.AzureAIProviderConfig
import com.tddworks.azure.api.azure
import com.tddworks.common.network.api.ktor.api.ListResponse
import com.tddworks.openai.api.OpenAI
import com.tddworks.openai.api.chat.api.ChatCompletion
Expand All @@ -23,7 +25,7 @@ class DefaultOpenAIProvider(
override val name: String = "OpenAI",
override val models: List<OpenAIModel> = availableModels,
override val config: OpenAIProviderConfig,
private val openAI: OpenAI = OpenAI.create(config.toOpenAIConfig()),
private val openAI: OpenAI = OpenAI.default(config.toOpenAIConfig()),
) : OpenAIProvider {

override fun supports(model: OpenAIModel): Boolean {
Expand Down Expand Up @@ -51,7 +53,19 @@ fun OpenAIProvider.Companion.openAI(
id: String = "openai",
config: OpenAIProviderConfig,
models: List<OpenAIModel>,
openAI: OpenAI = OpenAI.create(config.toOpenAIConfig())
openAI: OpenAI = OpenAI.default(config.toOpenAIConfig())
): OpenAIProvider {
return DefaultOpenAIProvider(
id = id,
config = config, models = models, openAI = openAI
)
}

fun OpenAIProvider.Companion.azure(
id: String = "azure",
config: OpenAIProviderConfig,
models: List<OpenAIModel>,
openAI: OpenAI = OpenAI.azure(config as AzureAIProviderConfig)
): OpenAIProvider {
return DefaultOpenAIProvider(
id = id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ class OllamaOpenAIProvider(
) : OpenAIProvider {
/**
* Check if the given OpenAIModel is supported by the available models.
* @param openAIModel The OpenAIModel to check for support.
* @param model The OpenAIModel to check for support.
* @return true if the model is supported, false otherwise.
*/
override fun supports(openAIModel: OpenAIModel): Boolean {
return models.any { it.value == openAIModel.value }
override fun supports(model: OpenAIModel): Boolean {
return models.any { it.value == model.value }
}

/**
Expand All @@ -46,9 +46,7 @@ class OllamaOpenAIProvider(
*/
override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
val ollamaChatRequest = request.toOllamaChatRequest()
return client.request(ollamaChatRequest).let {
it.toOpenAIChatCompletion()
}
return client.request(ollamaChatRequest).toOpenAIChatCompletion()
}

/**
Expand All @@ -65,9 +63,7 @@ class OllamaOpenAIProvider(
}

override suspend fun completions(request: CompletionRequest): Completion {
return client.request(request.toOllamaGenerateRequest()).let {
it.toOpenAICompletion()
}
return client.request(request.toOllamaGenerateRequest()).toOpenAICompletion()
}

override suspend fun generate(request: ImageCreate): ListResponse<Image> {
Expand Down

0 comments on commit 7207253

Please sign in to comment.