Skip to content

Commit ac003ce

Browse files
corybcode-asher
authored andcommitted
feat: add configuration options to support mtls
adding options to support mtls with the coder server. This supports adding PEM certs and keys to the tls requests, and also supports adding a CA cert to the trust store. Also allowing for an alternate hostname that may appear in the certs which is useful for testing or for non-standard cert usage.
1 parent 5e55049 commit ac003ce

11 files changed

+327
-37
lines changed

gradle.properties

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ gradleVersion=7.4
2929
# Opt-out flag for bundling Kotlin standard library.
3030
# See https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library for details.
3131
# suppress inspection "UnusedProperty"
32-
kotlin.stdlib.default.dependency=false
32+
kotlin.stdlib.default.dependency=true

src/main/kotlin/com/coder/gateway/CoderGatewayConnectionProvider.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class CoderGatewayConnectionProvider : GatewayConnectionProvider {
140140
if (token == null) { // User aborted.
141141
throw IllegalArgumentException("Unable to connect to $deploymentURL, $TOKEN is missing")
142142
}
143-
val client = CoderRestClient(deploymentURL, token.first, settings.headerCommand, null)
143+
val client = CoderRestClient(deploymentURL, token.first,null, settings)
144144
return try {
145145
Pair(client, client.me().username)
146146
} catch (ex: AuthenticationResponseException) {

src/main/kotlin/com/coder/gateway/CoderSettingsConfigurable.kt

+29-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") {
3939
.comment(
4040
CoderGatewayBundle.message(
4141
"gateway.connector.settings.binary-source.comment",
42-
CoderCLIManager(URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path,
42+
CoderCLIManager(state, URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path,
4343
)
4444
)
4545
}.layout(RowLayout.PARENT_GRID)
@@ -73,6 +73,34 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") {
7373
CoderGatewayBundle.message("gateway.connector.settings.header-command.comment")
7474
)
7575
}.layout(RowLayout.PARENT_GRID)
76+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.title")) {
77+
textField().resizableColumn().align(AlignX.FILL)
78+
.bindText(state::tlsCertPath)
79+
.comment(
80+
CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.comment")
81+
)
82+
}.layout(RowLayout.PARENT_GRID)
83+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.title")) {
84+
textField().resizableColumn().align(AlignX.FILL)
85+
.bindText(state::tlsKeyPath)
86+
.comment(
87+
CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.comment")
88+
)
89+
}.layout(RowLayout.PARENT_GRID)
90+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.title")) {
91+
textField().resizableColumn().align(AlignX.FILL)
92+
.bindText(state::tlsCAPath)
93+
.comment(
94+
CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.comment")
95+
)
96+
}.layout(RowLayout.PARENT_GRID)
97+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.title")) {
98+
textField().resizableColumn().align(AlignX.FILL)
99+
.bindText(state::tlsAlternateHostname)
100+
.comment(
101+
CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.comment")
102+
)
103+
}.layout(RowLayout.PARENT_GRID)
76104
}
77105
}
78106

src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt

+8-2
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ import java.nio.file.StandardCopyOption
2222
import java.security.DigestInputStream
2323
import java.security.MessageDigest
2424
import java.util.zip.GZIPInputStream
25+
import javax.net.ssl.HttpsURLConnection
2526
import javax.xml.bind.annotation.adapters.HexBinaryAdapter
2627

2728

2829
/**
2930
* Manage the CLI for a single deployment.
3031
*/
3132
class CoderCLIManager @JvmOverloads constructor(
33+
private val settings: CoderSettingsState,
3234
private val deploymentURL: URL,
3335
dataDir: Path,
3436
cliDir: Path? = null,
@@ -104,6 +106,10 @@ class CoderCLIManager @JvmOverloads constructor(
104106
conn.setRequestProperty("If-None-Match", "\"$etag\"")
105107
}
106108
conn.setRequestProperty("Accept-Encoding", "gzip")
109+
if (conn is HttpsURLConnection) {
110+
conn.sslSocketFactory = coderSocketFactory(settings)
111+
conn.hostnameVerifier = CoderHostnameVerifier(settings.tlsAlternateHostname)
112+
}
107113

108114
try {
109115
conn.connect()
@@ -463,7 +469,7 @@ class CoderCLIManager @JvmOverloads constructor(
463469
if (settings.binaryDirectory.isBlank()) null
464470
else Path.of(settings.binaryDirectory).toAbsolutePath()
465471

466-
val cli = CoderCLIManager(deploymentURL, dataDir, binDir, settings.binarySource)
472+
val cli = CoderCLIManager(settings, deploymentURL, dataDir, binDir, settings.binarySource)
467473

468474
// Short-circuit if we already have the expected version. This
469475
// lets us bypass the 304 which is slower and may not be
@@ -490,7 +496,7 @@ class CoderCLIManager @JvmOverloads constructor(
490496
}
491497

492498
// Try falling back to the data directory.
493-
val dataCLI = CoderCLIManager(deploymentURL, dataDir, null, settings.binarySource)
499+
val dataCLI = CoderCLIManager(settings, deploymentURL, dataDir, null, settings.binarySource)
494500
val dataCLIMatches = dataCLI.matchesVersion(buildVersion)
495501
if (dataCLIMatches == true) {
496502
return dataCLI

src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt

+237-5
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,48 @@ import com.coder.gateway.sdk.v2.models.Workspace
1414
import com.coder.gateway.sdk.v2.models.WorkspaceBuild
1515
import com.coder.gateway.sdk.v2.models.WorkspaceTransition
1616
import com.coder.gateway.sdk.v2.models.toAgentModels
17+
import com.coder.gateway.services.CoderSettingsState
1718
import com.google.gson.Gson
1819
import com.google.gson.GsonBuilder
1920
import com.intellij.ide.plugins.PluginManagerCore
2021
import com.intellij.openapi.components.Service
2122
import com.intellij.openapi.extensions.PluginId
2223
import com.intellij.openapi.util.SystemInfo
2324
import okhttp3.OkHttpClient
25+
import okhttp3.internal.tls.OkHostnameVerifier
2426
import okhttp3.logging.HttpLoggingInterceptor
2527
import org.zeroturnaround.exec.ProcessExecutor
2628
import retrofit2.Retrofit
2729
import retrofit2.converter.gson.GsonConverterFactory
30+
import java.io.File
31+
import java.io.FileInputStream
2832
import java.net.HttpURLConnection.HTTP_CREATED
33+
import java.net.InetAddress
34+
import java.net.Socket
2935
import java.net.URL
36+
import java.nio.file.Path
37+
import java.security.KeyFactory
38+
import java.security.KeyStore
39+
import java.security.PrivateKey
40+
import java.security.cert.CertificateException
41+
import java.security.cert.CertificateFactory
42+
import java.security.cert.X509Certificate
43+
import java.security.spec.InvalidKeySpecException
44+
import java.security.spec.PKCS8EncodedKeySpec
3045
import java.time.Instant
46+
import java.util.Base64
47+
import java.util.Locale
3148
import java.util.UUID
49+
import javax.net.ssl.HostnameVerifier
50+
import javax.net.ssl.KeyManagerFactory
51+
import javax.net.ssl.SNIHostName
52+
import javax.net.ssl.SSLContext
53+
import javax.net.ssl.SSLSession
54+
import javax.net.ssl.SSLSocket
55+
import javax.net.ssl.SSLSocketFactory
56+
import javax.net.ssl.TrustManagerFactory
57+
import javax.net.ssl.TrustManager
58+
import javax.net.ssl.X509TrustManager
3259

3360
@Service(Service.Level.APP)
3461
class CoderRestClientService {
@@ -44,18 +71,19 @@ class CoderRestClientService {
4471
*
4572
* @throws [AuthenticationResponseException] if authentication failed.
4673
*/
47-
fun initClientSession(url: URL, token: String, headerCommand: String?): User {
48-
client = CoderRestClient(url, token, headerCommand, null)
74+
fun initClientSession(url: URL, token: String, settings: CoderSettingsState): User {
75+
client = CoderRestClient(url, token, null, settings)
4976
me = client.me()
5077
buildVersion = client.buildInfo().version
5178
isReady = true
5279
return me
5380
}
5481
}
5582

56-
class CoderRestClient(var url: URL, var token: String,
57-
private var headerCommand: String?,
83+
class CoderRestClient(
84+
var url: URL, var token: String,
5885
private var pluginVersion: String?,
86+
private var settings: CoderSettingsState,
5987
) {
6088
private var httpClient: OkHttpClient
6189
private var retroRestClient: CoderV2RestFacade
@@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String,
6694
pluginVersion = PluginManagerCore.getPlugin(PluginId.getId("com.coder.gateway"))!!.version // this is the id from the plugin.xml
6795
}
6896

97+
val socketFactory = coderSocketFactory(settings)
98+
val trustManagers = coderTrustManagers(settings.tlsCAPath)
6999
httpClient = OkHttpClient.Builder()
100+
.sslSocketFactory(socketFactory, trustManagers[0] as X509TrustManager)
101+
.hostnameVerifier(CoderHostnameVerifier(settings.tlsAlternateHostname))
70102
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("Coder-Session-Token", token).build()) }
71103
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("User-Agent", "Coder Gateway/${pluginVersion} (${SystemInfo.getOsNameAndVersion()}; ${SystemInfo.OS_ARCH})").build()) }
72104
.addInterceptor {
73105
var request = it.request()
74-
val headers = getHeaders(url, headerCommand)
106+
val headers = getHeaders(url, settings.headerCommand)
75107
if (headers.size > 0) {
76108
val builder = request.newBuilder()
77109
headers.forEach { h -> builder.addHeader(h.key, h.value) }
@@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String,
218250
}
219251
}
220252
}
253+
254+
fun coderSocketFactory(settings: CoderSettingsState) : SSLSocketFactory {
255+
if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) {
256+
return SSLSocketFactory.getDefault() as SSLSocketFactory
257+
}
258+
259+
val certificateFactory = CertificateFactory.getInstance("X.509")
260+
val certInputStream = FileInputStream(expandPath(settings.tlsCertPath))
261+
val certChain = certificateFactory.generateCertificates(certInputStream)
262+
certInputStream.close()
263+
264+
// ideally we would use something like PemReader from BouncyCastle, but
265+
// BC is used by the IDE. This makes using BC very impractical since
266+
// type casting will mismatch due to the different class loaders.
267+
val privateKeyPem = File(expandPath(settings.tlsKeyPath)).readText()
268+
val start: Int = privateKeyPem.indexOf("-----BEGIN PRIVATE KEY-----")
269+
val end: Int = privateKeyPem.indexOf("-----END PRIVATE KEY-----", start)
270+
val pemBytes: ByteArray = Base64.getDecoder().decode(
271+
privateKeyPem.substring(start + "-----BEGIN PRIVATE KEY-----".length, end)
272+
.replace("\\s+".toRegex(), "")
273+
)
274+
275+
var privateKey : PrivateKey
276+
try {
277+
val kf = KeyFactory.getInstance("RSA")
278+
val keySpec = PKCS8EncodedKeySpec(pemBytes)
279+
privateKey = kf.generatePrivate(keySpec)
280+
} catch (e: InvalidKeySpecException) {
281+
val kf = KeyFactory.getInstance("EC")
282+
val keySpec = PKCS8EncodedKeySpec(pemBytes)
283+
privateKey = kf.generatePrivate(keySpec)
284+
}
285+
286+
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType())
287+
keyStore.load(null)
288+
certChain.withIndex().forEach {
289+
keyStore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
290+
}
291+
keyStore.setKeyEntry("key", privateKey, null, certChain.toTypedArray())
292+
293+
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
294+
keyManagerFactory.init(keyStore, null)
295+
296+
val sslContext = SSLContext.getInstance("TLS")
297+
298+
val trustManagers = coderTrustManagers(settings.tlsCAPath)
299+
sslContext.init(keyManagerFactory.keyManagers, trustManagers, null)
300+
301+
if (settings.tlsAlternateHostname.isBlank()) {
302+
return sslContext.socketFactory
303+
}
304+
305+
return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.tlsAlternateHostname)
306+
}
307+
308+
fun coderTrustManagers(tlsCAPath: String) : Array<TrustManager> {
309+
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
310+
if (tlsCAPath.isBlank()) {
311+
// return default trust managers
312+
trustManagerFactory.init(null as KeyStore?)
313+
return trustManagerFactory.trustManagers
314+
}
315+
316+
317+
val certificateFactory = CertificateFactory.getInstance("X.509")
318+
val caInputStream = FileInputStream(expandPath(tlsCAPath))
319+
val certChain = certificateFactory.generateCertificates(caInputStream)
320+
321+
val truststore = KeyStore.getInstance(KeyStore.getDefaultType())
322+
truststore.load(null)
323+
certChain.withIndex().forEach {
324+
truststore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
325+
}
326+
trustManagerFactory.init(truststore)
327+
return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray()
328+
}
329+
330+
fun expandPath(path: String): String {
331+
if (path.startsWith("~/")) {
332+
return Path.of(System.getProperty("user.home"), path.substring(1)).toString()
333+
}
334+
if (path.startsWith("\$HOME/")) {
335+
return Path.of(System.getProperty("user.home"), path.substring(5)).toString()
336+
}
337+
if (path.startsWith("\${user.home}/")) {
338+
return Path.of(System.getProperty("user.home"), path.substring(12)).toString()
339+
}
340+
return path
341+
}
342+
343+
class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : SSLSocketFactory() {
344+
override fun getDefaultCipherSuites(): Array<String> {
345+
return delegate.defaultCipherSuites
346+
}
347+
348+
override fun getSupportedCipherSuites(): Array<String> {
349+
return delegate.supportedCipherSuites
350+
}
351+
352+
override fun createSocket(): Socket {
353+
val socket = delegate.createSocket() as SSLSocket
354+
customizeSocket(socket)
355+
return socket
356+
}
357+
358+
override fun createSocket(host: String?, port: Int): Socket {
359+
val socket = delegate.createSocket(host, port) as SSLSocket
360+
customizeSocket(socket)
361+
return socket
362+
}
363+
364+
override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket {
365+
val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
366+
customizeSocket(socket)
367+
return socket
368+
}
369+
370+
override fun createSocket(host: InetAddress?, port: Int): Socket {
371+
val socket = delegate.createSocket(host, port) as SSLSocket
372+
customizeSocket(socket)
373+
return socket
374+
}
375+
376+
override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket {
377+
val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
378+
customizeSocket(socket)
379+
return socket
380+
}
381+
382+
override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket {
383+
val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
384+
customizeSocket(socket)
385+
return socket
386+
}
387+
388+
private fun customizeSocket(socket: SSLSocket) {
389+
val params = socket.sslParameters
390+
params.serverNames = listOf(SNIHostName(alternateName))
391+
socket.sslParameters = params
392+
}
393+
}
394+
395+
class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier {
396+
override fun verify(host: String, session: SSLSession): Boolean {
397+
if (alternateName.isEmpty()) {
398+
println("using default hostname verifier, alternateName is empty")
399+
return OkHostnameVerifier.verify(host, session)
400+
}
401+
println("Looking for alternate hostname: $alternateName")
402+
val certs = session.peerCertificates ?: return false
403+
for (cert in certs) {
404+
if (cert !is X509Certificate) {
405+
continue
406+
}
407+
val entries = cert.subjectAlternativeNames ?: continue
408+
for (entry in entries) {
409+
val kind = entry[0] as Int
410+
if (kind != 2) { // DNS Name
411+
continue
412+
}
413+
val hostname = entry[1] as String
414+
println("Found cert hostname: $hostname")
415+
if (hostname.lowercase(Locale.getDefault()) == alternateName) {
416+
return true
417+
}
418+
}
419+
}
420+
println("No matching hostname found")
421+
return false
422+
}
423+
}
424+
425+
class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : X509TrustManager {
426+
private val systemTrustManager : X509TrustManager
427+
init {
428+
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
429+
trustManagerFactory.init(null as KeyStore?)
430+
systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager
431+
}
432+
433+
override fun checkClientTrusted(chain: Array<out X509Certificate>, authType: String?) {
434+
try {
435+
otherTrustManager.checkClientTrusted(chain, authType)
436+
} catch (e: CertificateException) {
437+
systemTrustManager.checkClientTrusted(chain, authType)
438+
}
439+
}
440+
441+
override fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String?) {
442+
try {
443+
otherTrustManager.checkServerTrusted(chain, authType)
444+
} catch (e: CertificateException) {
445+
systemTrustManager.checkServerTrusted(chain, authType)
446+
}
447+
}
448+
449+
override fun getAcceptedIssuers(): Array<X509Certificate> {
450+
return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers
451+
}
452+
}

0 commit comments

Comments
 (0)