KSSLSocketImpl.kt
package org.knio.core.net.ssl
import kotlinx.coroutines.delay
import kotlinx.coroutines.sync.withLock
import org.knio.core.nio.readSuspend
import org.knio.core.nio.writeSuspend
import org.knio.core.context.KnioContext
import org.knio.core.context.ReleasableBuffer
import org.knio.core.context.acquireReleasableByteBuffer
import org.knio.core.io.KInputStream
import org.knio.core.io.KOutputStream
import java.io.IOException
import java.net.SocketException
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.ClosedChannelException
import javax.net.ssl.*
import kotlin.math.min
/**
 * TLS socket implemented with an [SSLEngine] on top of an [AsynchronousSocketChannel].
 *
 * Three pooled buffers carry the traffic:
 * - [networkRead]: encrypted bytes read from the channel that have not been unwrapped yet.
 * - [application]: decrypted bytes produced by `unwrap` that have not been handed to a reader yet.
 * - [networkWrite]: encrypted bytes produced by `wrap` that have not been written to the channel yet.
 *
 * Public entry points acquire `lock` (except [softStartHandshake]); the private `*0` variants assume the caller
 * already holds it.
 *
 * @param channel the underlying channel, passed to [KSSLSocketAbstract]
 * @param sslEngine the engine running the TLS protocol; its session sizes the initial buffers
 * @param useClientMode passed to [KSSLSocketAbstract]
 * @param context supplies the byte-buffer pool the buffers are acquired from
 */
internal class KSSLSocketImpl (
    channel: AsynchronousSocketChannel,
    sslEngine: SSLEngine,
    useClientMode: Boolean,
    private val context: KnioContext
): KSSLSocketAbstract(
    channel,
    sslEngine,
    useClientMode
) {
    private var isInputShutdown = false

    /** Encrypted bytes read from the channel that have not been unwrapped yet. */
    private val networkRead = ReadWriteBuffer(context.byteBufferPool.acquireReleasableByteBuffer(sslEngine.session.packetBufferSize))

    /** Decrypted application bytes that have not been delivered to a reader yet. */
    private val application = ReadWriteBuffer(context.byteBufferPool.acquireReleasableByteBuffer(sslEngine.session.applicationBufferSize))

    private var isOutputShutdown = false

    /** Encrypted bytes produced by `wrap` that have not been written to the channel yet. */
    private val networkWrite = ReadWriteBuffer(context.byteBufferPool.acquireReleasableByteBuffer(sslEngine.session.packetBufferSize))

    /** Input-stream view of this socket; delegates to [read] and [close]. */
    private val inputStream = object : KInputStream(context) {
        override suspend fun read(b: ByteBuffer): Int {
            return this@KSSLSocketImpl.read(b)
        }
        override suspend fun close() {
            this@KSSLSocketImpl.close()
        }
    }

    /** Output-stream view of this socket; delegates to [write] and [close]. */
    private val outputStream = object : KOutputStream() {
        override suspend fun write(b: ByteBuffer) {
            this@KSSLSocketImpl.write(b)
        }
        override suspend fun close() {
            this@KSSLSocketImpl.close()
        }
    }

    /**
     * @throws SocketException if the channel is closed or the input side has been shut down
     */
    override suspend fun getInputStream(): KInputStream = lock.withLock {
        if (!ch.isOpen) {
            throw SocketException("Socket is closed")
        }
        if (isInputShutdown) {
            throw SocketException("Socket input is shutdown")
        }
        inputStream
    }

    /**
     * @throws SocketException if the channel is closed or the output side has been shut down
     */
    override suspend fun getOutputStream(): KOutputStream = lock.withLock {
        // Same closed-socket check as getInputStream.
        if (!ch.isOpen) {
            throw SocketException("Socket is closed")
        }
        if (isOutputShutdown) {
            throw SocketException("Socket output is shutdown")
        }
        outputStream
    }

    override suspend fun softStartHandshake() {
        // For internal use only. This should not acquire the lock.
        if (!sslEngine.session.isValid) {
            startHandshake0()
        }
    }

    override suspend fun startHandshake() = lock.withLock {
        // initiates or renegotiates the SSL handshake
        startHandshake0()
    }

    /**
     * Same as [KSSLSocket.startHandshake] except that this is an internal function that executes without
     * acquiring the lock.
     *
     * @see [KSSLSocket.startHandshake]
     */
    private suspend fun startHandshake0() {
        @Suppress("BlockingMethodInNonBlockingContext")
        sslEngine.beginHandshake()
        handleHandshake0()
    }

    /**
     * Handles the handshake process.
     *
     * A handshake may be initiated at any time, and may be initiated multiple times. This method keeps
     * processing handshake steps until the handshake is complete. If a handshake session was observed, it is
     * passed to `triggerHandshakeCompletion` at the end.
     */
    private suspend fun handleHandshake0() {
        /*
         * In rare situations, a handshake may be triggered with NEED_TASK, NEED_WRAP or NEED_UNWRAP but will
         * never materialize into a full handshake session.
         *
         * Perform the required task then return unless a handshake session is available. If a handshake
         * session is available, then the handshake is in progress and we should continue.
         */
        if (!sslEngine.isHandshaking) {
            return
        }
        var handshakeSession: SSLSession? = sslEngine.handshakeSession
        handshakeSession?.let { initBuffersForHandshake(it) }
        do {
            handshakeIteration0()
            handshakeSession = handshakeSession ?: sslEngine.handshakeSession?.also { initBuffersForHandshake(it) }
        } while (sslEngine.isHandshaking && (handshakeSession != null || !sslEngine.session.isValid))
        if (handshakeSession != null) {
            super.triggerHandshakeCompletion(handshakeSession)
        }
    }

    /** Grows each buffer that is smaller than the sizes requested by [session]. */
    private suspend fun initBuffersForHandshake(session: SSLSession) {
        if (networkRead.value.capacity() < session.packetBufferSize) {
            networkRead.releasable.resize(session.packetBufferSize)
        }
        if (networkWrite.value.capacity() < session.packetBufferSize) {
            networkWrite.releasable.resize(session.packetBufferSize)
        }
        if (application.value.capacity() < session.applicationBufferSize) {
            application.releasable.resize(session.applicationBufferSize)
        }
    }

    /** Performs the single handshake step requested by the engine's current handshake status. */
    private suspend fun handshakeIteration0() {
        when (sslEngine.handshakeStatus!!) {
            SSLEngineResult.HandshakeStatus.NEED_TASK -> {
                runHandshakeTasks()
            }
            SSLEngineResult.HandshakeStatus.NEED_WRAP -> {
                wrapHandshake()
            }
            SSLEngineResult.HandshakeStatus.NEED_UNWRAP,
            SSLEngineResult.HandshakeStatus.NEED_UNWRAP_AGAIN -> {
                unwrapHandshake()
            }
            SSLEngineResult.HandshakeStatus.FINISHED,
            SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING -> {
                // DONE!
            }
        }
    }

    /** Runs every delegated task the engine currently has, on the calling thread. */
    private fun runHandshakeTasks() {
        while (true) {
            val task = sslEngine.delegatedTask ?: break
            task.run()
        }
    }

    /** Produces and sends handshake data by wrapping an empty source buffer. */
    private suspend fun wrapHandshake() {
        wrap(src = ByteBuffer.wrap(ByteArray(0)))
    }

    /**
     * Wraps (encrypts) bytes from [src] into [dst] and writes everything produced to the channel.
     *
     * The loop only repeats after [dst] has been enlarged on BUFFER_OVERFLOW. The first successful wrap is
     * flushed and the function returns, so callers loop if [src] may need more than one wrap.
     *
     * @throws SSLException on BUFFER_UNDERFLOW or CLOSED, or when the channel write reports -1 or 0
     */
    private suspend fun wrap (
        src: ByteBuffer,
        dst: ReadWriteBuffer = networkWrite
    ) {
        while (true) {
            @Suppress("BlockingMethodInNonBlockingContext")
            val result = sslEngine.wrap(src, dst.write)
            when (result.status!!) {
                SSLEngineResult.Status.BUFFER_UNDERFLOW -> {
                    // Dummy buffer used. If thrown, bad assumptions made
                    throw SSLException("Buffer underflow while wrapping in handshake")
                }
                SSLEngineResult.Status.BUFFER_OVERFLOW -> {
                    handleOverflow(dst)
                }
                SSLEngineResult.Status.OK -> {
                    // Wrap was successful. Write the produced data to the channel.
                    dst.toMode(ReadWriteBuffer.Mode.READ)
                    while (dst.value.hasRemaining()) {
                        val written = ch.writeSuspend(dst.value, getWriteTimeout())
                        if (written == -1) {
                            throw SSLException("Connection closed during handshake")
                        }
                        if (written == 0) {
                            // TODO
                            throw SSLException("?? no data written during handshake. try again or error ??")
                        }
                    }
                    break
                }
                SSLEngineResult.Status.CLOSED -> {
                    // closed
                    throw SSLException("Connection closed during handshake")
                }
            }
        }
    }

    /**
     * Unwraps handshake data from [networkRead] into [application], reading from the channel as needed.
     *
     * Any application data the engine produces stays in [application] for a later read.
     *
     * @throws SSLException if the channel reports -1 or 0 bytes, or the engine is closed
     */
    private suspend fun unwrapHandshake() {
        while (true) {
            // try to unwrap data from the network buffer
            @Suppress("BlockingMethodInNonBlockingContext")
            val result = sslEngine.unwrap(networkRead.read, application.write)
            when (result.status!!) {
                SSLEngineResult.Status.BUFFER_UNDERFLOW -> {
                    // If there's no room to read, increase buffer size
                    if (!networkRead.write.hasRemaining()) {
                        val buffer = networkRead.releasable
                        buffer.resize(buffer.value.capacity() + sslEngine.session.packetBufferSize)
                    }
                    // Read more data from the channel
                    val count = ch.readSuspend(networkRead.write, getReadTimeout())
                    if (count == -1) {
                        throw SSLException("Connection closed during handshake")
                    }
                    if (count == 0) {
                        throw SSLException("?? no data read during handshake. try again or error ??")
                    }
                }
                SSLEngineResult.Status.BUFFER_OVERFLOW -> {
                    // The destination (application) buffer is too small. networkRead is the source and is
                    // in read mode here, so it is not the buffer to grow.
                    handleOverflow(application)
                }
                SSLEngineResult.Status.OK -> {
                    // unwrap was successful. leave the data in the network buffer for the next unwrap
                    break
                }
                SSLEngineResult.Status.CLOSED -> {
                    // closed
                    throw SSLException("Connection closed during handshake")
                }
            }
        }
    }

    /**
     * Handles the BUFFER_OVERFLOW scenario when wrapping or unwrapping ssl content.
     *
     * If the limit is below capacity, the limit is raised to capacity. Otherwise the buffer is enlarged by
     * the session's packet buffer size.
     *
     * @param buffer the destination buffer of the failed wrap/unwrap; must be in write mode
     */
    private fun handleOverflow(buffer: ReadWriteBuffer) {
        /*
         * The buffer is in write mode, i.e. data is being added to it:
         * - [0, position()) holds data already written and waiting to be processed,
         * - [position(), limit()) is the space the next write may use,
         * - [limit(), capacity()) is space excluded by the current limit.
         */
        require(buffer.mode.isWrite()) { "handleOverflow requires a buffer in write mode" }
        val limit = buffer.value.limit()
        val capacity = buffer.value.capacity()
        if (limit == capacity) {
            // The wrap/unwrap failed with the maximum amount of space: make the buffer bigger.
            buffer.releasable.resize(capacity + sslEngine.session.packetBufferSize)
        } else {
            // There's unused space past the limit. Utilize that space and try again.
            buffer.value.limit(capacity)
        }
    }

    override suspend fun isInputShutdown(): Boolean = lock.withLock { isInputShutdown }

    override suspend fun isOutputShutdown(): Boolean = lock.withLock { isOutputShutdown }

    override suspend fun shutdownInput() = lock.withLock {
        shutdownInput0()
    }

    /**
     * Closes the engine's inbound side and releases [networkRead]. It does nothing if the buffer has already
     * been released.
     *
     * The input is marked shut down and the buffer released even if `closeInbound` throws.
     */
    private suspend fun shutdownInput0() {
        if (networkRead.releasable.released) {
            return
        }
        try {
            @Suppress("BlockingMethodInNonBlockingContext")
            sslEngine.closeInbound()
            // Clear buffer for reuse or release
            networkRead.value.clear()
        } finally {
            isInputShutdown = true
            networkRead.releasable.release()
        }
    }

    override suspend fun shutdownOutput() = lock.withLock {
        shutdownOutput0()
    }

    /**
     * Closes the engine's outbound side, sends the resulting closing record(s), and shuts down the channel
     * output. It does nothing if [networkWrite] has already been released.
     *
     * The output is marked shut down and [networkWrite] released even if an exception is thrown.
     *
     * @throws SSLException if the engine reports BUFFER_UNDERFLOW while wrapping
     */
    private suspend fun shutdownOutput0() {
        if (networkWrite.releasable.released) {
            // Already shut down; the buffer has been returned to the pool.
            return
        }
        try {
            sslEngine.closeOutbound()
            out@ while (true) {
                @Suppress("BlockingMethodInNonBlockingContext")
                val result = sslEngine.wrap(ByteBuffer.allocate(0), networkWrite.write)
                when (result.status!!) {
                    SSLEngineResult.Status.BUFFER_OVERFLOW -> {
                        handleOverflow(networkWrite)
                    }
                    SSLEngineResult.Status.OK,
                    SSLEngineResult.Status.CLOSED -> {
                        // The wrap that closes the outbound side reports CLOSED while still producing the
                        // close_notify record, so the produced bytes are flushed in both cases.
                        if (!flushForShutdown()) {
                            break@out
                        }
                        if (result.status == SSLEngineResult.Status.CLOSED ||
                            sslEngine.isOutboundDone ||
                            result.bytesProduced() == 0
                        ) {
                            break@out
                        }
                    }
                    SSLEngineResult.Status.BUFFER_UNDERFLOW -> {
                        throw SSLException("Unexpected SSL wrap status: ${result.status}")
                    }
                }
            }
            try {
                @Suppress("BlockingMethodInNonBlockingContext")
                ch.shutdownOutput()
            } catch (e: ClosedChannelException) {
                // ignore: the channel is already closed
            }
        } finally {
            isOutputShutdown = true
            networkWrite.releasable.release()
        }
    }

    /**
     * Writes everything pending in [networkWrite] to the channel as part of [shutdownOutput0].
     *
     * A write that accepts no bytes is retried up to two more times, after waiting 100 ms and then 200 ms.
     * On success the buffer is left in read mode and fully drained, so the next `networkWrite.write` access
     * compacts it back into an empty write buffer.
     *
     * @return `true` if all pending bytes were written; `false` if the channel stopped accepting data or is
     *   closed
     */
    private suspend fun flushForShutdown(): Boolean {
        networkWrite.toMode(ReadWriteBuffer.Mode.READ)
        val pending = networkWrite.value
        try {
            while (pending.hasRemaining()) {
                var written = 0
                for (attempt in 0 until 3) {
                    if (attempt > 0) {
                        delay(100L * attempt) // Backoff delay before retrying
                    }
                    written = ch.writeSuspend(pending)
                    if (written > 0) {
                        break
                    }
                }
                if (written <= 0) {
                    return false
                }
            }
        } catch (e: ClosedChannelException) {
            return false
        }
        return true
    }

    private suspend fun read(b: ByteBuffer): Int = lock.withLock {
        read0(b)
    }

    /**
     * Reads decrypted application data into [b].
     *
     * Once at least one byte has been delivered, the method returns instead of suspending on further network
     * reads or handshake steps. Callers therefore get the data that is available now rather than waiting for
     * [b] to fill.
     *
     * @return the number of bytes placed in [b]; -1 if the input side is shut down and nothing was delivered;
     *   0 if nothing was delivered and the channel read returned 0
     *
     * @implNote The `application` buffer must be empty or in a "read state" when exiting this method
     *
     * Buffer States:
     * - Undefined: Buffer Empty
     * - Read State: Bytes are read FROM the buffer
     * - Write State: Bytes are written TO the buffer
     */
    private suspend fun read0(b: ByteBuffer): Int {
        if (isInputShutdown && !application.read.hasRemaining()) {
            return -1
        }
        if (application.releasable.released || networkRead.releasable.released) {
            return -1
        }
        val app = application
        val net = networkRead
        if (!sslEngine.session.isValid) {
            startHandshake0() // <-- flips application buffer to read mode
        }
        val start = b.position()
        // True once this call has handed at least one byte to the caller.
        fun delivered() = b.position() != start

        input@ while (b.hasRemaining()) {
            // Add remaining application data to the buffer
            if (app.read.hasRemaining()) {
                val count = min(app.value.remaining(), b.remaining())
                b.put(b.position(), app.value, app.value.position(), count)
                app.value.position(app.value.position() + count)
                b.position(b.position() + count)
                continue@input
            }
            // Check if we're handshaking (could be initiated at any time, any number of times)
            if (sslEngine.isHandshaking) {
                if (delivered()) {
                    break@input
                }
                handleHandshake0()
                continue@input
            }
            if (!net.read.hasRemaining()) {
                // Nothing buffered to decrypt: read from the channel unless data was already delivered.
                if (delivered() || readChannel() <= 0) {
                    break@input
                }
                continue@input
            }
            unwrap@ while (true) {
                @Suppress("BlockingMethodInNonBlockingContext")
                val result = sslEngine.unwrap(net.read, app.write)
                when (result.status!!) {
                    SSLEngineResult.Status.BUFFER_UNDERFLOW -> {
                        // Not enough data for the unwrap operation. The partial record stays buffered.
                        // Stop if data was already delivered; otherwise grow the buffer if it is full and
                        // read more from the channel.
                        if (delivered()) {
                            break@input
                        }
                        if (!net.write.hasRemaining()) {
                            val buffer = net.releasable
                            buffer.resize(buffer.value.capacity() + sslEngine.session.packetBufferSize)
                        }
                        if (readChannel() <= 0) {
                            break@input
                        }
                    }
                    SSLEngineResult.Status.BUFFER_OVERFLOW -> {
                        handleOverflow(app)
                    }
                    SSLEngineResult.Status.OK -> {
                        break@unwrap
                    }
                    SSLEngineResult.Status.CLOSED -> {
                        shutdownInput0()
                        break@input
                    }
                }
            }
        }
        // In all cases the method exits with the application buffer in read mode or empty
        return when {
            delivered() -> b.position() - start
            isInputShutdown -> -1
            else -> 0
        }
    }

    private suspend fun write(b: ByteBuffer) = lock.withLock {
        write0(b)
    }

    /**
     * Encrypts and sends all remaining bytes of [b], servicing any handshake the engine requests in between.
     *
     * @throws SocketException if the output side has been shut down
     */
    private suspend fun write0(b: ByteBuffer) {
        if (isOutputShutdown) {
            throw SocketException("Socket output is shutdown")
        }
        if (!sslEngine.session.isValid) {
            startHandshake0()
        }
        while (b.hasRemaining()) {
            // Check if we're handshaking (could be initiated at any time, any number of times)
            if (sslEngine.isHandshaking) {
                handleHandshake0()
                continue
            }
            wrap(src = b)
        }
    }

    /**
     * Reads from the channel into [b], switching it to write mode and passing [getReadTimeout].
     * End-of-stream (-1) shuts down the input side via [shutdownInput0].
     *
     * @return the count reported by the channel read
     */
    private suspend fun readChannel(b: ReadWriteBuffer = networkRead): Int {
        val count = ch.readSuspend(b.write, getReadTimeout())
        if (count == -1) {
            shutdownInput0()
        }
        return count
    }

    /**
     * Returns true if the SSLEngine is handshaking, i.e. its status is neither FINISHED nor NOT_HANDSHAKING.
     */
    private val SSLEngine.isHandshaking: Boolean
        get() = this.handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED
            && this.handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING

    /**
     * Keeps track of the read/write state of the buffer.
     *
     * - *Read Mode* is defined as the state in which data is read from this buffer
     * - *Write Mode* is defined as the state in which data is written into this buffer
     *
     * @param releasable The buffer to manage
     * @param mode The initial mode of the buffer - Defaults to WRITE. Default assumes that
     * the buffer is clear. The position is set to 0 and the limit is set to the capacity.
     */
    private class ReadWriteBuffer(
        val releasable: ReleasableBuffer<ByteBuffer>,
        var mode: Mode = Mode.WRITE
    ) {
        enum class Mode {
            READ, WRITE;
            fun isRead() = this === READ
            fun isWrite() = this === WRITE
        }

        /** Returns the buffer without changing the mode */
        val value: ByteBuffer get() = releasable.value

        /** Returns the buffer in read-mode */
        val read: ByteBuffer get() {
            toMode(Mode.READ)
            return releasable.value
        }

        /** Returns the buffer in write-mode */
        val write: ByteBuffer get() {
            toMode(Mode.WRITE)
            return releasable.value
        }

        /**
         * Swaps the buffer between read and write mode, preparing it for the opposite operation.
         * Read -> write uses `compact()` (unread bytes are kept); write -> read uses `flip()`.
         */
        fun swap(): ByteBuffer {
            val buffer = releasable.value
            if (mode.isRead()) {
                mode = Mode.WRITE
                return buffer.compact()
            } else {
                mode = Mode.READ
                return buffer.flip()
            }
        }

        /**
         * Sets the buffer to the specified mode. If the buffer is already in the specified mode, this method
         * does nothing, otherwise it swaps the buffer to the opposite mode.
         */
        fun toMode(mode: Mode) {
            if (this.mode != mode) {
                swap()
            }
        }
    }
}