Skip to content

Commit

Permalink
feat(gateway): move hard-coded OllamaChat by getInstance(), OllamaGene…
Browse files Browse the repository at this point in the history
…rate by getInstance() to param
  • Loading branch information
hanrw committed Aug 21, 2024
1 parent 634695f commit becf14d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package com.tddworks.ollama.api

import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.common.network.api.ktor.internal.default
import com.tddworks.ollama.api.chat.OllamaChat
import com.tddworks.ollama.api.chat.internal.DefaultOllamaChatApi
import com.tddworks.ollama.api.generate.OllamaGenerate
import com.tddworks.ollama.api.generate.internal.DefaultOllamaGenerateApi
import com.tddworks.ollama.api.internal.OllamaApi
import com.tddworks.ollama.api.json.JsonLenient

/**
* Interface for interacting with the Ollama API.
Expand All @@ -13,6 +19,26 @@ interface Ollama : OllamaChat, OllamaGenerate {
const val BASE_URL = "localhost"
const val PORT = 11434
const val PROTOCOL = "http"

fun create(ollamaConfig: OllamaConfig): Ollama {

val requester = HttpRequester.default(
createHttpClient(
host = ollamaConfig.baseUrl,
port = ollamaConfig.port,
protocol = ollamaConfig.protocol,
json = JsonLenient,
)
)
val ollamaChat = DefaultOllamaChatApi(requester = requester)
val ollamaGenerate = DefaultOllamaGenerateApi(requester = requester)

return OllamaApi(
config = ollamaConfig,
ollamaChat = ollamaChat,
ollamaGenerate = ollamaGenerate
)
}
}

/**
Expand All @@ -35,16 +61,4 @@ interface Ollama : OllamaChat, OllamaGenerate {
* @return a string representing the protocol
*/
fun protocol(): String
}

fun Ollama(
baseUrl: () -> String = { Ollama.BASE_URL },
port: () -> Int = { Ollama.PORT },
protocol: () -> String = { Ollama.PROTOCOL },
): Ollama {
return OllamaApi(
baseUrl = baseUrl(),
port = port(),
protocol = protocol()
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,37 @@ package com.tddworks.ollama.api.internal

import com.tddworks.di.getInstance
import com.tddworks.ollama.api.Ollama
import com.tddworks.ollama.api.OllamaConfig
import com.tddworks.ollama.api.chat.OllamaChat
import com.tddworks.ollama.api.generate.OllamaGenerate

class OllamaApi(
private val baseUrl: String,
private val port: Int,
private val protocol: String,
) : Ollama, OllamaChat by getInstance(), OllamaGenerate by getInstance() {
private val config: OllamaConfig,
private val ollamaChat: OllamaChat,
private val ollamaGenerate: OllamaGenerate
) : Ollama, OllamaChat by ollamaChat, OllamaGenerate by ollamaGenerate {

override fun baseUrl(): String {
return baseUrl
return config.baseUrl()
}

override fun port(): Int {
return port
return config.port()
}

override fun protocol(): String {
return protocol
return config.protocol()
}

}

fun Ollama(
baseUrl: () -> String = { Ollama.BASE_URL },
port: () -> Int = { Ollama.PORT },
protocol: () -> String = { Ollama.PROTOCOL },
): Ollama {
return OllamaApi(
baseUrl = baseUrl(),
port = port(),
protocol = protocol()
)
}

fun Ollama.Companion.create(
baseUrl: () -> String = { BASE_URL },
port: () -> Int = { PORT },
protocol: () -> String = { PROTOCOL },
config: OllamaConfig,
ollamaChat: OllamaChat = getInstance(),
ollamaGenerate: OllamaGenerate = getInstance()
): Ollama {
return OllamaApi(
baseUrl = baseUrl(),
port = port(),
protocol = protocol()
config = config,
ollamaChat = ollamaChat,
ollamaGenerate = ollamaGenerate
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ fun ollamaModules(
) = module {

single<Ollama> {
Ollama(
baseUrl = config.baseUrl,
port = config.port,
protocol = config.protocol
)
Ollama.create(ollamaConfig = config)
}

single<Json>(named("ollamaJson")) { JsonLenient }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.tddworks.openai.gateway.api.internal

import com.tddworks.ollama.api.Ollama
import com.tddworks.ollama.api.OllamaConfig
import com.tddworks.ollama.api.OllamaModel
import com.tddworks.ollama.api.chat.api.*
import com.tddworks.ollama.api.internal.create
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
Expand All @@ -23,9 +23,11 @@ class OllamaOpenAIProvider(
OpenAIModel(it.value)
},
private val client: Ollama = Ollama.create(
baseUrl = config.baseUrl,
port = config.port,
protocol = config.protocol
ollamaConfig = OllamaConfig(
baseUrl = config.baseUrl,
port = config.port,
protocol = config.protocol
)
)
) : OpenAIProvider {
/**
Expand Down Expand Up @@ -75,9 +77,11 @@ fun OpenAIProvider.Companion.ollama(
OpenAIModel(it.value)
},
client: Ollama = Ollama.create(
baseUrl = config.baseUrl,
port = config.port,
protocol = config.protocol
ollamaConfig = OllamaConfig(
baseUrl = config.baseUrl,
port = config.port,
protocol = config.protocol
)
)
): OpenAIProvider {
return OllamaOpenAIProvider(config = config, models = models, client = client)
Expand Down

0 comments on commit becf14d

Please sign in to comment.