AsynchronousSocketChannelExt.kt
package org.knio.core.nio
import org.knio.core.utils.fromResult
import java.io.IOException
import java.net.SocketException
import java.net.SocketTimeoutException
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.InterruptedByTimeoutException
import java.util.concurrent.TimeUnit
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine
/**
* Handles errors by throwing a `SocketTimeoutException` if the error is an `InterruptedByTimeoutException`,
* otherwise rethrows the original exception.
*
* @param e The throwable to handle.
* @return Nothing, as this function always throws an exception.
* @throws SocketTimeoutException if the throwable is an `InterruptedByTimeoutException`.
* @throws Throwable the original throwable if it is not an `InterruptedByTimeoutException`.
*/
private fun <T> errorHandler(e: Throwable): T {
when (e) {
is InterruptedByTimeoutException -> throw SocketTimeoutException("Connection timed out")
is IOException -> throw SocketException(e.message, e)
else -> throw e
}
}
/**
* Suspends the coroutine and reads data from the `AsynchronousSocketChannel` into the provided `ByteBuffer`.
*
* @param b The `ByteBuffer` to read data into.
* @param timeout The timeout in milliseconds, or `null` for no timeout.
* @return The number of bytes read.
* @throws SocketTimeoutException if the read operation times out.
* @throws Throwable if any other error occurs during the read operation.
*/
suspend fun AsynchronousSocketChannel.readSuspend(
b: ByteBuffer,
timeout: Long? = null
): Int = suspendCoroutine {
try {
// Call the callback version of the non-blocking read function, un-suspending the coroutine when complete.
if (timeout != null && timeout > 0) {
// with timeout
read(b, timeout, TimeUnit.MILLISECONDS, it, fromResult(onFail = ::errorHandler))
} else {
// without timeout
read(b, it, fromResult(onFail = ::errorHandler))
}
} catch (e: Throwable) {
it.resumeWithException(e)
}
}
/**
* Suspends the coroutine and writes data from the provided `ByteBuffer` to the `AsynchronousSocketChannel`.
*
* @param b The `ByteBuffer` containing data to write.
* @param timeout The timeout in milliseconds, or `null` for no timeout.
* @return The number of bytes written.
* @throws SocketTimeoutException if the write operation times out.
* @throws Throwable if any other error occurs during the write operation.
*/
suspend fun AsynchronousSocketChannel.writeSuspend(
b: ByteBuffer,
timeout: Long? = null
): Int = suspendCoroutine {
try {
// Call the callback version of the non-blocking write function, un-suspending the coroutine when complete.
if (timeout != null && timeout > 0) {
// with timeout
write(b, timeout, TimeUnit.MILLISECONDS, it, fromResult(onFail = ::errorHandler))
} else {
// without timeout
write(b, it, fromResult(onFail = ::errorHandler))
}
} catch (e: Throwable) {
it.resumeWithException(e)
}
}