KSocketAbstract.kt

package org.knio.core.net

import org.knio.core.annotations.NotSuspended
import org.knio.core.utils.asCompletionHandler
import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.net.SocketException
import java.nio.channels.AsynchronousSocketChannel
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine

/** The IPv4 wildcard address `0.0.0.0`, built from a zero-filled 4-byte array. */
internal val ANY_LOCAL_ADDRESS: InetAddress = InetAddress.getByAddress(ByteArray(4))

/** Port reported for the remote side when no remote address is known. */
internal const val UNDEFINED_PORT = 0

/** Port reported for the local side when no local address is known. */
internal const val UNDEFINED_LOCAL_PORT = -1

/**
 * Shared [KSocket] implementation on top of an [AsynchronousSocketChannel].
 *
 * Socket options are read from and written to [ch] directly. The remote and local
 * [InetSocketAddress]es are cached as soon as they are known (at construction, after [bind]
 * and after [connect]). Once the channel is closed, address queries are answered from that
 * cache, so [isBound] and [isConnected] keep reporting the state from before the close.
 *
 * Read and write timeouts are only recorded here (`null` = no timeout); this class does not
 * apply them to any I/O itself.
 *
 * @property ch the underlying channel; closed by [close].
 */
internal abstract class KSocketAbstract(
    protected val ch: AsynchronousSocketChannel
) : KSocket {

    /** The read timeout in milliseconds, or `null` when no timeout is set. */
    private var rTimeout: Long? = null

    /** The write timeout in milliseconds, or `null` when no timeout is set. */
    private var wTimeout: Long? = null

    /** Cached remote endpoint; `null` until one has been observed on the open channel. */
    private var remoteAddress: InetSocketAddress? = null

    /** Cached local endpoint; `null` until one has been observed on the open channel. */
    private var localAddress: InetSocketAddress? = null

    init {
        // Pick up the addresses of a channel that is already bound/connected when wrapped.
        @OptIn(NotSuspended::class)
        setProperties()
    }

    /**
     * Caches the channel's remote and local [InetSocketAddress]es if the channel is open.
     * Values that are already cached are never overwritten.
     *
     * Not suspending so that it can be called from the `init` block.
     */
    @NotSuspended
    private fun setProperties() {
        if (!ch.isOpen) return

        if (remoteAddress == null) {
            remoteAddress = getRemoteInetSocketAddress()
        }
        if (localAddress == null) {
            localAddress = getLocalInetSocketAddress()
        }
    }


    /** Binds the channel to [local] and caches the resulting addresses. */
    override suspend fun bind(local: SocketAddress?) {
        @Suppress("BlockingMethodInNonBlockingContext")
        ch.bind(local)

        @OptIn(NotSuspended::class)
        setProperties()
    }


    /**
     * Closes the socket. Does nothing if the channel is already closed.
     *
     * If connected, both directions are shut down first on a best-effort basis: any failure
     * while shutting down is deliberately ignored so that the channel itself always gets closed.
     */
    override suspend fun close() {
        if (!ch.isOpen) return

        if (isConnected()) {
            try {
                @Suppress("BlockingMethodInNonBlockingContext")
                this.shutdownInput()
            } catch (e: Throwable) {
                // best-effort: closing the channel below is what matters
            }

            try {
                @Suppress("BlockingMethodInNonBlockingContext")
                this.shutdownOutput()
            } catch (e: Throwable) {
                // best-effort: closing the channel below is what matters
            }
        }

        @Suppress("BlockingMethodInNonBlockingContext")
        ch.close()
    }


    /** Connects the channel to [endpoint], suspending until done, then caches the addresses. */
    override suspend fun connect(endpoint: SocketAddress) {
        connect0(endpoint)

        @OptIn(NotSuspended::class)
        setProperties()
    }

    /**
     * Starts the asynchronous connect on [ch] and suspends until its completion handler fires.
     *
     * The failure callback rethrows an [IOException] as a [SocketException] and any other
     * failure as-is. How that reaches the caller depends on `asCompletionHandler`; it is
     * presumably resumed into the continuation (TODO confirm). A synchronous failure from
     * `ch.connect` resumes the caller with that exception.
     */
    private suspend fun connect0(endpoint: SocketAddress) = suspendCoroutine { cont ->
        try {
            // The handler resumes with its receiver (Unit) upon completion.
            ch.connect(endpoint, cont, Unit.asCompletionHandler(onFail = { e ->
                if (e is IOException) {
                    throw SocketException(e.message, e)
                } else {
                    throw e
                }
            }))
        } catch (e: Throwable) {
            cont.resumeWithException(e)
        }
    }


    /** Remote IP address, or [ANY_LOCAL_ADDRESS] when no remote address has been cached. */
    override suspend fun getInetAddress(): InetAddress {
        return this.remoteAddress?.address ?: ANY_LOCAL_ADDRESS
    }

    /** The channel's remote address if it is an [InetSocketAddress], otherwise `null`. */
    @NotSuspended
    private fun getRemoteInetSocketAddress(): InetSocketAddress? =
        ch.remoteAddress as? InetSocketAddress


    override suspend fun getKeepAlive(): Boolean {
        @Suppress("BlockingMethodInNonBlockingContext")
        return ch.getOption(java.net.StandardSocketOptions.SO_KEEPALIVE)
    }


    /**
     * Local IP address.
     *
     * Returns [ANY_LOCAL_ADDRESS] once the channel is closed. While the channel is open, returns
     * the loopback address if no local address has been cached.
     */
    override suspend fun getLocalAddress(): InetAddress {
        if (!ch.isOpen) {
            return ANY_LOCAL_ADDRESS
        }
        return this.localAddress?.address ?: InetAddress.getLoopbackAddress()
    }

    /** The channel's local address if it is an [InetSocketAddress], otherwise `null`. */
    @NotSuspended
    private fun getLocalInetSocketAddress(): InetSocketAddress? =
        ch.localAddress as? InetSocketAddress


    /** Local port, or [UNDEFINED_LOCAL_PORT] when no local address has been cached. */
    override suspend fun getLocalPort(): Int {
        return this.localAddress?.port ?: UNDEFINED_LOCAL_PORT
    }


    /**
     * Local endpoint, or `null` if no local address has been cached (i.e. not bound).
     *
     * After close, the cached port is reported together with the wildcard address returned by
     * [getLocalAddress].
     */
    override suspend fun getLocalSocketAddress(): SocketAddress? {
        if (ch.isOpen) {
            return this.localAddress
        }
        // Closed without ever being bound: getLocalPort() would be UNDEFINED_LOCAL_PORT (-1),
        // and InetSocketAddress rejects ports outside 0..65535 with IllegalArgumentException.
        if (this.localAddress == null) {
            return null
        }
        assert(ANY_LOCAL_ADDRESS.isAnyLocalAddress)
        return InetSocketAddress(getLocalAddress(), getLocalPort())
    }


    /** Remote port, or [UNDEFINED_PORT] when no remote address has been cached. */
    override suspend fun getPort(): Int {
        return this.remoteAddress?.port ?: UNDEFINED_PORT
    }


    override suspend fun getReceiveBufferSize(): Int {
        @Suppress("BlockingMethodInNonBlockingContext")
        return ch.getOption(java.net.StandardSocketOptions.SO_RCVBUF)
    }


    /** Cached remote endpoint, or `null` if none has been observed. */
    override suspend fun getRemoteSocketAddress(): SocketAddress? = this.remoteAddress


    override suspend fun getReuseAddress(): Boolean {
        @Suppress("BlockingMethodInNonBlockingContext")
        return ch.getOption(java.net.StandardSocketOptions.SO_REUSEADDR)
    }

    override suspend fun getSendBufferSize(): Int {
        @Suppress("BlockingMethodInNonBlockingContext")
        return ch.getOption(java.net.StandardSocketOptions.SO_SNDBUF)
    }

    /** Read timeout in milliseconds; `0` means no timeout. */
    override suspend fun getReadTimeout(): Long =
        this.rTimeout ?: 0L


    /** Write timeout in milliseconds; `0` means no timeout. */
    override suspend fun getWriteTimeout(): Long =
        this.wTimeout ?: 0L


    override suspend fun getTcpNoDelay(): Boolean {
        @Suppress("BlockingMethodInNonBlockingContext")
        return ch.getOption(java.net.StandardSocketOptions.TCP_NODELAY)
    }

    /** Queries the channel while open; after close, answers from the cached local address. */
    override suspend fun isBound(): Boolean =
        if (ch.isOpen) ch.localAddress != null else localAddress != null


    override suspend fun isClosed(): Boolean =
        !ch.isOpen


    /** Queries the channel while open; after close, answers from the cached remote address. */
    override suspend fun isConnected(): Boolean =
        if (ch.isOpen) ch.remoteAddress != null else remoteAddress != null


    override suspend fun setKeepAlive(keepAlive: Boolean) {
        @Suppress("BlockingMethodInNonBlockingContext")
        ch.setOption(java.net.StandardSocketOptions.SO_KEEPALIVE, keepAlive)
    }


    override suspend fun setReceiveBufferSize(size: Int) {
        @Suppress("BlockingMethodInNonBlockingContext")
        ch.setOption(java.net.StandardSocketOptions.SO_RCVBUF, size)
    }


    override suspend fun setReuseAddress(reuse: Boolean) {
        @Suppress("BlockingMethodInNonBlockingContext")
        ch.setOption(java.net.StandardSocketOptions.SO_REUSEADDR, reuse)
    }


    override suspend fun setSendBufferSize(size: Int) {
        @Suppress("BlockingMethodInNonBlockingContext")
        ch.setOption(java.net.StandardSocketOptions.SO_SNDBUF, size)
    }


    /**
     * Sets the read timeout in milliseconds; `null` or `0` clears it.
     *
     * @throws IllegalArgumentException if [timeout] is negative.
     */
    override suspend fun setReadTimeout(timeout: Long?) {
        this.rTimeout = normalizeTimeout(timeout)
    }


    /**
     * Sets the write timeout in milliseconds; `null` or `0` clears it.
     *
     * @throws IllegalArgumentException if [timeout] is negative.
     */
    override suspend fun setWriteTimeout(timeout: Long?) {
        this.wTimeout = normalizeTimeout(timeout)
    }

    /**
     * Validates a timeout and converts it to its stored form: `null` and `0` both mean
     * "no timeout" and are stored as `null`.
     *
     * @throws IllegalArgumentException if [timeout] is negative.
     */
    private fun normalizeTimeout(timeout: Long?): Long? {
        if (timeout == null || timeout == 0L) return null
        require(timeout > 0L) { "Timeout must be greater than or equal to 0" }
        return timeout
    }


    override suspend fun setTcpNoDelay(on: Boolean) {
        @Suppress("BlockingMethodInNonBlockingContext")
        ch.setOption(java.net.StandardSocketOptions.TCP_NODELAY, on)
    }
}