Skip to content

Commit

Permalink
feat(gateway): to support AzureAIProvider
Browse files Browse the repository at this point in the history
 - refactor createHttpClient
  • Loading branch information
hanrw committed Aug 23, 2024
1 parent cb01e8c commit 9386056
Show file tree
Hide file tree
Showing 22 changed files with 375 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ import com.tddworks.anthropic.api.messages.api.Messages
import com.tddworks.anthropic.api.messages.api.internal.DefaultMessagesApi
import com.tddworks.anthropic.api.messages.api.internal.JsonLenient
import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.common.network.api.ktor.internal.default
import com.tddworks.di.createJson
import com.tddworks.di.getInstance
import com.tddworks.common.network.api.ktor.internal.*

/**
* Interface for interacting with the Anthropic API.
Expand All @@ -30,9 +27,10 @@ interface Anthropic : Messages {

val requester = HttpRequester.default(
createHttpClient(
host = anthropicConfig.baseUrl,
connectionConfig = UrlBasedConnectionConfig(anthropicConfig.baseUrl),
authConfig = AuthConfig(anthropicConfig.apiKey),
// get from commonModule
json = JsonLenient,
features = ClientFeatures(json = JsonLenient)
)
)
val messages = DefaultMessagesApi(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ import com.tddworks.anthropic.api.messages.api.Messages
import com.tddworks.anthropic.api.messages.api.internal.DefaultMessagesApi
import com.tddworks.anthropic.api.messages.api.internal.JsonLenient
import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.common.network.api.ktor.internal.default
import com.tddworks.common.network.api.ktor.internal.*
import com.tddworks.di.commonModule
import kotlinx.serialization.json.Json
import org.koin.core.context.startKoin
import org.koin.core.module.Module
import org.koin.core.qualifier.named
import org.koin.dsl.KoinAppDeclaration
import org.koin.dsl.module
Expand All @@ -38,8 +36,10 @@ fun anthropicModules(
single<HttpRequester>(named("anthropicHttpRequester")) {
HttpRequester.default(
createHttpClient(
host = config.baseUrl,
json = get(named("anthropicJson")),
connectionConfig = UrlBasedConnectionConfig(config.baseUrl),
authConfig = AuthConfig(config.apiKey),
// get from commonModule
features = ClientFeatures(json = get(named("anthropicJson")))
)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,122 +3,114 @@ package com.tddworks.common.network.api.ktor.internal
import io.ktor.client.*
import io.ktor.client.engine.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.auth.*
import io.ktor.client.plugins.auth.providers.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.logging.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.*
import io.ktor.util.*
import kotlinx.serialization.json.Json
import kotlin.time.Duration.Companion.seconds


internal expect fun httpClientEngine(): HttpClientEngine

/**
* Creates a new [HttpClient] with [OkHttp] engine and [ContentNegotiation] plugin.
*
* @param protocol the protocol to use - default is HTTPS
* @param host the base URL of the API
* @param port the port to use - default is 443
* @param authToken the authentication token
* @return a new [HttpClient] instance
*/

interface ConnectionConfig {
fun setupUrl(builder: DefaultRequest.DefaultRequestBuilder) {
builder.setupUrl(this)
}
}

data class UrlBasedConnectionConfig(
val baseUrl: () -> String = { "" }
) : ConnectionConfig

data class HostPortConnectionConfig(
val protocol: () -> String? = { null },
val host: () -> String = { "" },
val port: () -> Int? = { null },
) : ConnectionConfig

data class AuthConfig(
val authToken: (() -> String)? = null
)

data class ClientFeatures(
val json: Json = Json,
val queryParams: Map<String, String> = emptyMap(),
val expectSuccess: Boolean = true
)

fun createHttpClient(
protocol: () -> String? = { null },
host: () -> String,
port: () -> Int? = { null },
authToken: (() -> String)? = null,
json: Json = Json,
httpClientEngine: HttpClientEngine = httpClientEngine(),
connectionConfig: ConnectionConfig = UrlBasedConnectionConfig(),
authConfig: AuthConfig = AuthConfig(),
features: ClientFeatures = ClientFeatures(),
httpClientEngine: HttpClientEngine = httpClientEngine()
): HttpClient {
return HttpClient(httpClientEngine) {
// enable proxy in the future
// engine {
// proxy = ProxyBuilder.http(url)
// }

return HttpClient(httpClientEngine) {
install(ContentNegotiation) {
register(ContentType.Application.Json, KotlinxSerializationConverter(json))
register(
ContentType.Application.Json,
KotlinxSerializationConverter(features.json)
)
}

/**
* Support configurable in the future
* Install the Logging module.
* @param logging the logging instance to use
* @return Unit
*/
install(Logging) {
/**
* DEFAULT - default - LoggerFactory.getLogger
* SIMPLE - Logger using println.
* Empty - Empty Logger for test purpose.
*/
logger = Logger.DEFAULT
/**
* ALL - log all
* HEADERS - log headers
* INFO - log info
* NONE - none
*/
level = LogLevel.INFO
}

/**
* Install the Auth module. but can't update on the fly
* @param auth the auth instance to use
* @return Unit
*/
// authToken?.let {
// install(Auth) {
// bearer {
// loadTokens {
// BearerTokens(accessToken = authToken(), refreshToken = "")
// }
// }
// }
// }

/**
* Installs an [HttpRequestRetry] with default maxRetries of 3,
* retryIf checks for rate limit error with status code 429,
* and exponential delay with base 5.0 and max delay of 1 minute.
*
* @param retry [HttpRequestRetry] instance to install
*/
install(HttpRequestRetry) {
maxRetries = 3
// retry on rate limit error.
retryIf { _, response -> response.status.value.let { it == 429 } }
exponentialDelay(base = 5.0, maxDelayMs = 10.seconds.inWholeMilliseconds)
retryIf { _, response -> response.status.value == 429 }
exponentialDelay(base = 5.0, maxDelayMs = 60_000)
}


defaultRequest {
url {
this.protocol = protocol()?.let { URLProtocol.createOrDefault(it) }
?: URLProtocol.HTTPS
this.host = host()
port()?.let { this.port = it }
}
connectionConfig.setupUrl(this)
commonSettings(features.queryParams, authConfig.authToken)
}

expectSuccess = features.expectSuccess
}
}

authToken?.let {
header(HttpHeaders.Authorization, "Bearer ${it()}")
private fun DefaultRequest.DefaultRequestBuilder.setupUrl(connectionConfig: ConnectionConfig) {
when (connectionConfig) {
is HostPortConnectionConfig -> {
url {
protocol =
connectionConfig.protocol()?.let { URLProtocol.createOrDefault(it) }
?: URLProtocol.HTTPS
host = connectionConfig.host()
connectionConfig.port()?.let { port = it }
}
}

header(HttpHeaders.ContentType, ContentType.Application.Json)
contentType(ContentType.Application.Json)
is UrlBasedConnectionConfig -> {
connectionConfig.baseUrl().let { url.takeFrom(it) }
}
}
}

/**
* If set to true, the client will throw an exception if the response from the server is not successful. The definition of successful can vary depending on the HTTP status code. For example, a successful response for a GET request would typically be a status code of 200, while a successful response for a POST request could be a status code of 201.
*
* By setting expectSuccess = true, the developer is indicating that they want to handle non-successful responses explicitly and can throw or handle the exceptions themselves.
*
* If expectSuccess is set to false, the HttpClient will not throw exceptions for non-successful responses and the developer is responsible for parsing and handling any errors or unexpected responses.
*/
expectSuccess = true
private fun DefaultRequest.DefaultRequestBuilder.commonSettings(
queryParams: Map<String, String>,
authToken: (() -> String)?
) {
queryParams.forEach { (key, value) ->
url.parameters.appendIfNameAbsent(
key,
value
)
}

authToken?.let {
header(HttpHeaders.Authorization, "Bearer ${it()}")
}

header(HttpHeaders.ContentType, ContentType.Application.Json)
contentType(ContentType.Application.Json)
}

Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,52 @@ import io.ktor.client.engine.okhttp.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.utils.io.*
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.json.Json
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class HttpClientTest {

@Test
fun `should return correct json response with default settings`() {
runBlocking {
val mockEngine = MockEngine { request ->
respond(
content = ByteReadChannel("""{"ip":"127.0.0.1"}"""),
status = HttpStatusCode.OK,
headers = headersOf(HttpHeaders.ContentType, "application/json")
)
}
val apiClient = createHttpClient(
host = { "some-host" },
httpClientEngine = mockEngine
fun `should return correct response with host and port based config`() = runTest {
val mockEngine = MockEngine { _ ->
respond(
content = ByteReadChannel("""{"ip":"127.0.0.1"}"""),
status = HttpStatusCode.OK,
headers = headersOf(HttpHeaders.ContentType, "application/json")
)

val body = apiClient.get("https://some-host:443").body<String>()
assertEquals("""{"ip":"127.0.0.1"}""", body)
}
val apiClient = createHttpClient(
connectionConfig = HostPortConnectionConfig(
protocol = { "https" },
host = { "some-host" },
port = { 443 }
),
httpClientEngine = mockEngine
)

val body = apiClient.get("https://some-host").body<String>()
assertEquals("""{"ip":"127.0.0.1"}""", body)
}

@Test
fun `should return correct response with url based config`() = runTest {
val mockEngine = MockEngine { _ ->
respond(
content = ByteReadChannel("""{"ip":"127.0.0.1"}"""),
status = HttpStatusCode.OK,
headers = headersOf(HttpHeaders.ContentType, "application/json")
)
}
val apiClient = createHttpClient(
connectionConfig = UrlBasedConnectionConfig { "https://some-host" },
httpClientEngine = mockEngine
)

val body = apiClient.get("https://some-host").body<String>()
assertEquals("""{"ip":"127.0.0.1"}""", body)
}

@Test
fun `should return OkHttp engine`() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package com.tddworks.ollama.api

import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.common.network.api.ktor.internal.default
import com.tddworks.common.network.api.ktor.internal.*
import com.tddworks.ollama.api.chat.OllamaChat
import com.tddworks.ollama.api.chat.internal.DefaultOllamaChatApi
import com.tddworks.ollama.api.generate.OllamaGenerate
Expand All @@ -16,18 +15,16 @@ import com.tddworks.ollama.api.json.JsonLenient
interface Ollama : OllamaChat, OllamaGenerate {

companion object {
const val BASE_URL = "localhost"
const val PORT = 11434
const val PROTOCOL = "http"
const val BASE_URL = "http://localhost:11434"

fun create(ollamaConfig: OllamaConfig): Ollama {

val requester = HttpRequester.default(
createHttpClient(
host = ollamaConfig.baseUrl,
port = ollamaConfig.port,
protocol = ollamaConfig.protocol,
json = JsonLenient,
connectionConfig = UrlBasedConnectionConfig(
baseUrl = ollamaConfig.baseUrl,
),
// get from commonModule
features = ClientFeatures(json = JsonLenient)
)
)
val ollamaChat = DefaultOllamaChatApi(requester = requester)
Expand All @@ -47,18 +44,4 @@ interface Ollama : OllamaChat, OllamaGenerate {
* @return a string representing the base URL
*/
fun baseUrl(): String

/**
* This function returns the port as an integer.
*
* @return an integer representing the port
*/
fun port(): Int

/**
* This function returns the protocol as a string.
*
* @return a string representing the protocol
*/
fun protocol(): String
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,5 @@ package com.tddworks.ollama.api
import org.koin.core.component.KoinComponent

data class OllamaConfig(
val baseUrl: () -> String = { Ollama.BASE_URL },
val protocol: () -> String = { Ollama.PROTOCOL },
val port: () -> Int = { Ollama.PORT },
val baseUrl: () -> String = { Ollama.BASE_URL }
) : KoinComponent
Loading

0 comments on commit 9386056

Please sign in to comment.