// KSSLSocketAbstract.kt
package org.knio.core.net.ssl
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.knio.core.net.KSocketAbstract
import java.nio.channels.AsynchronousSocketChannel
import java.util.function.BiFunction
import javax.net.ssl.*
import kotlin.coroutines.coroutineContext
/**
 * Base [KSSLSocket] implementation that forwards TLS configuration to an [SSLEngine].
 *
 * Each configuration accessor below acquires [lock] before touching [sslEngine].
 *
 * @param channel the asynchronous channel handed to [KSocketAbstract]
 * @property sslEngine the engine that stores the TLS configuration and session state
 * @param useClientMode client/server mode applied to [sslEngine] during construction
 */
internal abstract class KSSLSocketAbstract(
    channel: AsynchronousSocketChannel,
    protected val sslEngine: SSLEngine,
    useClientMode: Boolean,
) : KSSLSocket, KSocketAbstract(channel) {

    /** Guards access to [sslEngine] and [handshakeListeners]. */
    protected val lock: Mutex = Mutex()

    /** Listeners to notify once a handshake has completed. */
    private val handshakeListeners: MutableSet<KHandshakeCompletedListener> = mutableSetOf()

    init {
        sslEngine.useClientMode = useClientMode
    }

    /** Returns the cipher suites the engine supports. */
    override suspend fun getSupportedCipherSuites(): Array<String> = lock.withLock {
        sslEngine.supportedCipherSuites
    }

    /** Returns the cipher suites currently enabled on the engine. */
    override suspend fun getEnabledCipherSuites(): Array<String> = lock.withLock {
        sslEngine.enabledCipherSuites
    }

    /** Replaces the engine's enabled cipher suites with [suites]. */
    override suspend fun setEnabledCipherSuites(suites: Array<String>) = lock.withLock {
        sslEngine.enabledCipherSuites = suites
    }

    /** Returns the protocols the engine supports. */
    override suspend fun getSupportedProtocols(): Array<String> = lock.withLock {
        sslEngine.supportedProtocols
    }

    /** Returns the protocols currently enabled on the engine. */
    override suspend fun getEnabledProtocols(): Array<String> = lock.withLock {
        sslEngine.enabledProtocols
    }

    /** Replaces the engine's enabled protocols with [protocols]. */
    override suspend fun setEnabledProtocols(protocols: Array<String>) = lock.withLock {
        sslEngine.enabledProtocols = protocols
    }

    /**
     * Returns the engine's session, first starting the handshake via [softStartHandshake]
     * if it has not been started yet.
     */
    override suspend fun getSession(): SSLSession = lock.withLock {
        softStartHandshake()
        sslEngine.session
    }

    /** Returns the engine's in-progress handshake session, if any. */
    override suspend fun getHandshakeSession(): SSLSession? = lock.withLock {
        sslEngine.handshakeSession
    }

    /** Registers [listener] to be notified when a handshake completes. */
    override suspend fun addHandshakeCompletedListener(listener: KHandshakeCompletedListener): Unit = lock.withLock {
        handshakeListeners.add(listener)
    }

    /** Unregisters [listener]; a listener that was never registered is ignored. */
    override suspend fun removeHandshakeCompletedListener(listener: KHandshakeCompletedListener): Unit = lock.withLock {
        handshakeListeners.remove(listener)
    }

    /** Switches the engine between client mode (`true`) and server mode (`false`). */
    override suspend fun setUseClientMode(mode: Boolean) = lock.withLock {
        sslEngine.useClientMode = mode
    }

    /** Returns `true` when the engine is configured for client mode. */
    override suspend fun getUseClientMode(): Boolean = lock.withLock {
        sslEngine.useClientMode
    }

    /** Sets whether the engine requires client authentication. */
    override suspend fun setNeedClientAuth(need: Boolean) = lock.withLock {
        sslEngine.needClientAuth = need
    }

    /** Returns whether the engine requires client authentication. */
    override suspend fun getNeedClientAuth(): Boolean = lock.withLock {
        sslEngine.needClientAuth
    }

    /** Sets whether the engine requests (but does not require) client authentication. */
    override suspend fun setWantClientAuth(want: Boolean) = lock.withLock {
        sslEngine.wantClientAuth = want
    }

    /** Returns whether the engine requests client authentication. */
    override suspend fun getWantClientAuth(): Boolean = lock.withLock {
        sslEngine.wantClientAuth
    }

    /** Sets whether the engine may establish new sessions. */
    override suspend fun setEnableSessionCreation(flag: Boolean) = lock.withLock {
        sslEngine.enableSessionCreation = flag
    }

    /** Returns whether the engine may establish new sessions. */
    override suspend fun getEnableSessionCreation(): Boolean = lock.withLock {
        sslEngine.enableSessionCreation
    }

    /** Returns the application protocol reported by the engine for the connection. */
    override suspend fun getApplicationProtocol(): String? = lock.withLock {
        sslEngine.applicationProtocol
    }

    /**
     * Returns the application protocol reported by the engine for the handshake in progress.
     *
     * NOTE(review): the JDK documents that the underlying engine call may return `null`,
     * while this override is declared non-null; confirm whether the [KSSLSocket] contract
     * should be nullable.
     */
    override suspend fun getHandshakeApplicationProtocol(): String = lock.withLock {
        sslEngine.handshakeApplicationProtocol
    }

    /**
     * Installs [selector] as the engine's application protocol selector, or clears the
     * current one when [selector] is `null`.
     *
     * The selector is wrapped so that it is invoked with this socket rather than with
     * the raw [SSLEngine].
     */
    override suspend fun setHandshakeApplicationProtocolSelector(
        selector: BiFunction<KSSLSocket, List<String>, String?>?
    ) = lock.withLock {
        sslEngine.handshakeApplicationProtocolSelector =
            selector?.let { HandshakeApplicationProtocolSelector(this@KSSLSocketAbstract, it) }
    }

    /**
     * Returns the selector previously installed through
     * [setHandshakeApplicationProtocolSelector], or `null` when none is set or when the
     * engine holds a selector that was not installed through this class.
     */
    override suspend fun getHandshakeApplicationProtocolSelector(): BiFunction<KSSLSocket, List<String>, String?>? =
        lock.withLock {
            (sslEngine.handshakeApplicationProtocolSelector as? HandshakeApplicationProtocolSelector)?.selector
        }

    /** Returns the engine's current [SSLParameters]. */
    override suspend fun getSSLParameters(): SSLParameters = lock.withLock {
        sslEngine.sslParameters
    }

    /** Applies [params] to the engine. */
    override suspend fun setSSLParameters(params: SSLParameters) = lock.withLock {
        sslEngine.sslParameters = params
    }

    /**
     * Starts the handshake process if it has not already been started.
     *
     * Note: This function should not acquire the lock. The calling function will have already acquired the lock.
     */
    protected abstract suspend fun softStartHandshake()

    /**
     * Notifies the registered handshake listeners. Must be called after the handshake is complete.
     *
     * The listeners are invoked sequentially in a coroutine launched from the caller's
     * coroutine context, so this function does not wait for them.
     *
     * @param session the session passed to each listener inside a [KHandshakeCompletedEvent]
     */
    protected suspend fun triggerHandshakeCompletion(session: SSLSession) {
        // Copy the listeners before launching: the launched coroutine runs without the
        // lock, and iterating the live set could collide with add/remove calls.
        // NOTE(review): presumably invoked with `lock` held by the caller; confirm.
        val listeners = handshakeListeners.toList()
        if (listeners.isEmpty()) {
            return
        }

        // run the listeners in a separate coroutine
        CoroutineScope(coroutineContext).launch {
            listeners.forEach { runHandshakeCompletedListener(it, session) }
        }
    }

    /**
     * Invokes a single [listener], logging (and otherwise ignoring) any failure so one
     * faulty listener does not prevent the others from running.
     */
    private suspend fun runHandshakeCompletedListener(listener: KHandshakeCompletedListener, session: SSLSession) {
        try {
            listener.handshakeCompleted(KHandshakeCompletedEvent(this, session))
        } catch (cancellation: kotlin.coroutines.cancellation.CancellationException) {
            // Cancellation is cooperative and must propagate rather than be logged away.
            throw cancellation
        } catch (th: Throwable) {
            // nothing left to do. print the stack trace and move on
            th.printStackTrace()
        }
    }

    /**
     * A wrapper for the application protocol selector function to work with the SSLEngine.
     *
     * The engine argument supplied by the TLS implementation is ignored; the wrapped
     * [selector] is typed on [KSSLSocket], so it is called with the owning [socket].
     *
     * @property socket the socket passed to [selector] as its first argument
     * @property selector the user-supplied selector, exposed so it can be returned unchanged
     * @see SSLEngine.getHandshakeApplicationProtocolSelector
     * @see SSLEngine.setHandshakeApplicationProtocolSelector
     */
    private class HandshakeApplicationProtocolSelector(
        private val socket: KSSLSocket,
        val selector: BiFunction<KSSLSocket, List<String>, String?>,
    ) : BiFunction<SSLEngine, List<String>, String?> {
        override fun apply(t: SSLEngine, u: List<String>): String? = selector.apply(socket, u)
    }
}