@@ -14,21 +14,48 @@ import com.coder.gateway.sdk.v2.models.Workspace
14
14
import com.coder.gateway.sdk.v2.models.WorkspaceBuild
15
15
import com.coder.gateway.sdk.v2.models.WorkspaceTransition
16
16
import com.coder.gateway.sdk.v2.models.toAgentModels
17
+ import com.coder.gateway.services.CoderSettingsState
17
18
import com.google.gson.Gson
18
19
import com.google.gson.GsonBuilder
19
20
import com.intellij.ide.plugins.PluginManagerCore
20
21
import com.intellij.openapi.components.Service
21
22
import com.intellij.openapi.extensions.PluginId
22
23
import com.intellij.openapi.util.SystemInfo
23
24
import okhttp3.OkHttpClient
25
+ import okhttp3.internal.tls.OkHostnameVerifier
24
26
import okhttp3.logging.HttpLoggingInterceptor
25
27
import org.zeroturnaround.exec.ProcessExecutor
26
28
import retrofit2.Retrofit
27
29
import retrofit2.converter.gson.GsonConverterFactory
30
+ import java.io.File
31
+ import java.io.FileInputStream
28
32
import java.net.HttpURLConnection.HTTP_CREATED
33
+ import java.net.InetAddress
34
+ import java.net.Socket
29
35
import java.net.URL
36
+ import java.nio.file.Path
37
+ import java.security.KeyFactory
38
+ import java.security.KeyStore
39
+ import java.security.PrivateKey
40
+ import java.security.cert.CertificateException
41
+ import java.security.cert.CertificateFactory
42
+ import java.security.cert.X509Certificate
43
+ import java.security.spec.InvalidKeySpecException
44
+ import java.security.spec.PKCS8EncodedKeySpec
30
45
import java.time.Instant
46
+ import java.util.Base64
47
+ import java.util.Locale
31
48
import java.util.UUID
49
+ import javax.net.ssl.HostnameVerifier
50
+ import javax.net.ssl.KeyManagerFactory
51
+ import javax.net.ssl.SNIHostName
52
+ import javax.net.ssl.SSLContext
53
+ import javax.net.ssl.SSLSession
54
+ import javax.net.ssl.SSLSocket
55
+ import javax.net.ssl.SSLSocketFactory
56
+ import javax.net.ssl.TrustManagerFactory
57
+ import javax.net.ssl.TrustManager
58
+ import javax.net.ssl.X509TrustManager
32
59
33
60
@Service(Service .Level .APP )
34
61
class CoderRestClientService {
@@ -44,18 +71,19 @@ class CoderRestClientService {
44
71
*
45
72
* @throws [AuthenticationResponseException] if authentication failed.
46
73
*/
47
- fun initClientSession (url : URL , token : String , headerCommand : String? ): User {
48
- client = CoderRestClient (url, token, headerCommand, null )
74
+ fun initClientSession (url : URL , token : String , settings : CoderSettingsState ): User {
75
+ client = CoderRestClient (url, token, null , settings )
49
76
me = client.me()
50
77
buildVersion = client.buildInfo().version
51
78
isReady = true
52
79
return me
53
80
}
54
81
}
55
82
56
- class CoderRestClient (var url : URL , var token : String ,
57
- private var headerCommand : String? ,
83
+ class CoderRestClient (
84
+ var url : URL , var token : String ,
58
85
private var pluginVersion : String? ,
86
+ private var settings : CoderSettingsState ,
59
87
) {
60
88
private var httpClient: OkHttpClient
61
89
private var retroRestClient: CoderV2RestFacade
@@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String,
66
94
pluginVersion = PluginManagerCore .getPlugin(PluginId .getId(" com.coder.gateway" ))!! .version // this is the id from the plugin.xml
67
95
}
68
96
97
+ val socketFactory = coderSocketFactory(settings)
98
+ val trustManagers = coderTrustManagers(settings.tlsCAPath)
69
99
httpClient = OkHttpClient .Builder ()
100
+ .sslSocketFactory(socketFactory, trustManagers[0 ] as X509TrustManager )
101
+ .hostnameVerifier(CoderHostnameVerifier (settings.tlsAlternateHostname))
70
102
.addInterceptor { it.proceed(it.request().newBuilder().addHeader(" Coder-Session-Token" , token).build()) }
71
103
.addInterceptor { it.proceed(it.request().newBuilder().addHeader(" User-Agent" , " Coder Gateway/${pluginVersion} (${SystemInfo .getOsNameAndVersion()} ; ${SystemInfo .OS_ARCH } )" ).build()) }
72
104
.addInterceptor {
73
105
var request = it.request()
74
- val headers = getHeaders(url, headerCommand)
106
+ val headers = getHeaders(url, settings. headerCommand)
75
107
if (headers.size > 0 ) {
76
108
val builder = request.newBuilder()
77
109
headers.forEach { h -> builder.addHeader(h.key, h.value) }
@@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String,
218
250
}
219
251
}
220
252
}
253
+
254
+ fun coderSocketFactory (settings : CoderSettingsState ) : SSLSocketFactory {
255
+ if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) {
256
+ return SSLSocketFactory .getDefault() as SSLSocketFactory
257
+ }
258
+
259
+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
260
+ val certInputStream = FileInputStream (expandPath(settings.tlsCertPath))
261
+ val certChain = certificateFactory.generateCertificates(certInputStream)
262
+ certInputStream.close()
263
+
264
+ // ideally we would use something like PemReader from BouncyCastle, but
265
+ // BC is used by the IDE. This makes using BC very impractical since
266
+ // type casting will mismatch due to the different class loaders.
267
+ val privateKeyPem = File (expandPath(settings.tlsKeyPath)).readText()
268
+ val start: Int = privateKeyPem.indexOf(" -----BEGIN PRIVATE KEY-----" )
269
+ val end: Int = privateKeyPem.indexOf(" -----END PRIVATE KEY-----" , start)
270
+ val pemBytes: ByteArray = Base64 .getDecoder().decode(
271
+ privateKeyPem.substring(start + " -----BEGIN PRIVATE KEY-----" .length, end)
272
+ .replace(" \\ s+" .toRegex(), " " )
273
+ )
274
+
275
+ var privateKey : PrivateKey
276
+ try {
277
+ val kf = KeyFactory .getInstance(" RSA" )
278
+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
279
+ privateKey = kf.generatePrivate(keySpec)
280
+ } catch (e: InvalidKeySpecException ) {
281
+ val kf = KeyFactory .getInstance(" EC" )
282
+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
283
+ privateKey = kf.generatePrivate(keySpec)
284
+ }
285
+
286
+ val keyStore = KeyStore .getInstance(KeyStore .getDefaultType())
287
+ keyStore.load(null )
288
+ certChain.withIndex().forEach {
289
+ keyStore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
290
+ }
291
+ keyStore.setKeyEntry(" key" , privateKey, null , certChain.toTypedArray())
292
+
293
+ val keyManagerFactory = KeyManagerFactory .getInstance(KeyManagerFactory .getDefaultAlgorithm())
294
+ keyManagerFactory.init (keyStore, null )
295
+
296
+ val sslContext = SSLContext .getInstance(" TLS" )
297
+
298
+ val trustManagers = coderTrustManagers(settings.tlsCAPath)
299
+ sslContext.init (keyManagerFactory.keyManagers, trustManagers, null )
300
+
301
+ if (settings.tlsAlternateHostname.isBlank()) {
302
+ return sslContext.socketFactory
303
+ }
304
+
305
+ return AlternateNameSSLSocketFactory (sslContext.socketFactory, settings.tlsAlternateHostname)
306
+ }
307
+
308
+ fun coderTrustManagers (tlsCAPath : String ) : Array <TrustManager > {
309
+ val trustManagerFactory = TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())
310
+ if (tlsCAPath.isBlank()) {
311
+ // return default trust managers
312
+ trustManagerFactory.init (null as KeyStore ? )
313
+ return trustManagerFactory.trustManagers
314
+ }
315
+
316
+
317
+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
318
+ val caInputStream = FileInputStream (expandPath(tlsCAPath))
319
+ val certChain = certificateFactory.generateCertificates(caInputStream)
320
+
321
+ val truststore = KeyStore .getInstance(KeyStore .getDefaultType())
322
+ truststore.load(null )
323
+ certChain.withIndex().forEach {
324
+ truststore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
325
+ }
326
+ trustManagerFactory.init (truststore)
327
+ return trustManagerFactory.trustManagers.map { MergedSystemTrustManger (it as X509TrustManager ) }.toTypedArray()
328
+ }
329
+
330
+ fun expandPath (path : String ): String {
331
+ if (path.startsWith(" ~/" )) {
332
+ return Path .of(System .getProperty(" user.home" ), path.substring(1 )).toString()
333
+ }
334
+ if (path.startsWith(" \$ HOME/" )) {
335
+ return Path .of(System .getProperty(" user.home" ), path.substring(5 )).toString()
336
+ }
337
+ if (path.startsWith(" \$ {user.home}/" )) {
338
+ return Path .of(System .getProperty(" user.home" ), path.substring(12 )).toString()
339
+ }
340
+ return path
341
+ }
342
+
343
+ class AlternateNameSSLSocketFactory (private val delegate : SSLSocketFactory , private val alternateName : String ) : SSLSocketFactory() {
344
+ override fun getDefaultCipherSuites (): Array <String > {
345
+ return delegate.defaultCipherSuites
346
+ }
347
+
348
+ override fun getSupportedCipherSuites (): Array <String > {
349
+ return delegate.supportedCipherSuites
350
+ }
351
+
352
+ override fun createSocket (): Socket {
353
+ val socket = delegate.createSocket() as SSLSocket
354
+ customizeSocket(socket)
355
+ return socket
356
+ }
357
+
358
+ override fun createSocket (host : String? , port : Int ): Socket {
359
+ val socket = delegate.createSocket(host, port) as SSLSocket
360
+ customizeSocket(socket)
361
+ return socket
362
+ }
363
+
364
+ override fun createSocket (host : String? , port : Int , localHost : InetAddress ? , localPort : Int ): Socket {
365
+ val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
366
+ customizeSocket(socket)
367
+ return socket
368
+ }
369
+
370
+ override fun createSocket (host : InetAddress ? , port : Int ): Socket {
371
+ val socket = delegate.createSocket(host, port) as SSLSocket
372
+ customizeSocket(socket)
373
+ return socket
374
+ }
375
+
376
+ override fun createSocket (address : InetAddress ? , port : Int , localAddress : InetAddress ? , localPort : Int ): Socket {
377
+ val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
378
+ customizeSocket(socket)
379
+ return socket
380
+ }
381
+
382
+ override fun createSocket (s : Socket ? , host : String? , port : Int , autoClose : Boolean ): Socket {
383
+ val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
384
+ customizeSocket(socket)
385
+ return socket
386
+ }
387
+
388
+ private fun customizeSocket (socket : SSLSocket ) {
389
+ val params = socket.sslParameters
390
+ params.serverNames = listOf (SNIHostName (alternateName))
391
+ socket.sslParameters = params
392
+ }
393
+ }
394
+
395
+ class CoderHostnameVerifier (private val alternateName : String ) : HostnameVerifier {
396
+ override fun verify (host : String , session : SSLSession ): Boolean {
397
+ if (alternateName.isEmpty()) {
398
+ println (" using default hostname verifier, alternateName is empty" )
399
+ return OkHostnameVerifier .verify(host, session)
400
+ }
401
+ println (" Looking for alternate hostname: $alternateName " )
402
+ val certs = session.peerCertificates ? : return false
403
+ for (cert in certs) {
404
+ if (cert !is X509Certificate ) {
405
+ continue
406
+ }
407
+ val entries = cert.subjectAlternativeNames ? : continue
408
+ for (entry in entries) {
409
+ val kind = entry[0 ] as Int
410
+ if (kind != 2 ) { // DNS Name
411
+ continue
412
+ }
413
+ val hostname = entry[1 ] as String
414
+ println (" Found cert hostname: $hostname " )
415
+ if (hostname.lowercase(Locale .getDefault()) == alternateName) {
416
+ return true
417
+ }
418
+ }
419
+ }
420
+ println (" No matching hostname found" )
421
+ return false
422
+ }
423
+ }
424
+
425
+ class MergedSystemTrustManger (private val otherTrustManager : X509TrustManager ) : X509TrustManager {
426
+ private val systemTrustManager : X509TrustManager
427
+ init {
428
+ val trustManagerFactory = TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())
429
+ trustManagerFactory.init (null as KeyStore ? )
430
+ systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager
431
+ }
432
+
433
+ override fun checkClientTrusted (chain : Array <out X509Certificate >, authType : String? ) {
434
+ try {
435
+ otherTrustManager.checkClientTrusted(chain, authType)
436
+ } catch (e: CertificateException ) {
437
+ systemTrustManager.checkClientTrusted(chain, authType)
438
+ }
439
+ }
440
+
441
+ override fun checkServerTrusted (chain : Array <out X509Certificate >, authType : String? ) {
442
+ try {
443
+ otherTrustManager.checkServerTrusted(chain, authType)
444
+ } catch (e: CertificateException ) {
445
+ systemTrustManager.checkServerTrusted(chain, authType)
446
+ }
447
+ }
448
+
449
+ override fun getAcceptedIssuers (): Array <X509Certificate > {
450
+ return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers
451
+ }
452
+ }
0 commit comments