diff --git a/ksrpc/build.gradle.kts b/ksrpc/build.gradle.kts index 17f8939c..3fce3dad 100644 --- a/ksrpc/build.gradle.kts +++ b/ksrpc/build.gradle.kts @@ -59,11 +59,12 @@ kotlin { sourceSets["commonMain"].dependencies { implementation("org.jetbrains.kotlinx:kotlinx-serialization-core:1.3.3") implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.3.3") - implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.3-native-mt") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4") implementation("org.jetbrains.kotlinx:atomicfu:0.17.1") implementation("io.ktor:ktor-client-core:2.0.2") implementation("io.ktor:ktor-client-websockets:2.0.2") implementation("io.ktor:ktor-http:2.0.2") + implementation("io.ktor:ktor-serialization-kotlinx-json:2.0.2") } sourceSets["commonTest"].dependencies { implementation(kotlin("test")) @@ -88,7 +89,7 @@ kotlin { } sourceSets["jvmTest"].dependencies { implementation(kotlin("test-junit")) - implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.3-native-mt") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4") implementation("io.ktor:ktor-server-core:2.0.2") implementation("io.ktor:ktor-server-netty:2.0.2") implementation("io.ktor:ktor-serialization-jackson:2.0.2") @@ -97,7 +98,7 @@ kotlin { implementation("io.ktor:ktor-client-okhttp:2.0.2") implementation("io.ktor:ktor-server-websockets:2.0.2") implementation("io.ktor:ktor-client-websockets:2.0.2") - implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.3-native-mt") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.4") } sourceSets["jsTest"].dependencies { implementation(kotlin("test-js")) @@ -109,7 +110,7 @@ kotlin { } sourceSets["nativeMain"].dependencies { implementation(kotlin("stdlib")) - implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.3-native-mt") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4") implementation("io.ktor:ktor-client-curl:2.0.2") } } @@ -128,7 +129,6 @@ kotlin.targets.withType(org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarge } } - val dokkaJavadoc = tasks.create("dokkaJavadocCustom", org.jetbrains.dokka.gradle.DokkaTask::class) { dependencies { plugins("org.jetbrains.dokka:kotlin-as-java-plugin:1.4.10.2") diff --git a/ksrpc/src/commonMain/kotlin/internal/ByteChannelUtils.kt b/ksrpc/src/commonMain/kotlin/EpochMillis.kt similarity index 81% rename from ksrpc/src/commonMain/kotlin/internal/ByteChannelUtils.kt rename to ksrpc/src/commonMain/kotlin/EpochMillis.kt index 59bc62ae..f7fc6186 100644 --- a/ksrpc/src/commonMain/kotlin/internal/ByteChannelUtils.kt +++ b/ksrpc/src/commonMain/kotlin/EpochMillis.kt @@ -13,8 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.monkopedia.ksrpc.internal +package com.monkopedia.ksrpc -import io.ktor.websocket.Frame -import io.ktor.websocket.Frame.Text -import io.ktor.websocket.readText +internal expect fun epochMillis(): Long diff --git a/ksrpc/src/commonMain/kotlin/KsrpcEnvironment.kt b/ksrpc/src/commonMain/kotlin/KsrpcEnvironment.kt index 735f8677..04bed6bd 100644 --- a/ksrpc/src/commonMain/kotlin/KsrpcEnvironment.kt +++ b/ksrpc/src/commonMain/kotlin/KsrpcEnvironment.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -28,7 +28,6 @@ interface KsrpcEnvironment { val serialization: StringFormat val defaultScope: CoroutineScope val errorListener: ErrorListener - val maxParallelReceives: Int val coroutineExceptionHandler: CoroutineExceptionHandler interface Element { @@ -64,12 +63,11 @@ fun ksrpcEnvironment(builder: KsrpcEnvironmentBuilder.() -> Unit): KsrpcEnvironm data class KsrpcEnvironmentBuilder internal constructor( override var serialization: StringFormat = Json, override var defaultScope: CoroutineScope = GlobalScope, - override var errorListener: ErrorListener = ErrorListener { }, - override var maxParallelReceives: Int = 5 + override var errorListener: ErrorListener = ErrorListener { } ) : KsrpcEnvironment { override val coroutineExceptionHandler: CoroutineExceptionHandler by lazy { CoroutineExceptionHandler { _, throwable -> errorListener.onError(throwable) } } -} \ No newline at end of file +} diff --git a/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt b/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt index 9ed9ea68..2f5773b9 100644 --- a/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt +++ b/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/ksrpc/src/commonMain/kotlin/channels/JsonRpcChannels.kt b/ksrpc/src/commonMain/kotlin/channels/JsonRpcChannels.kt index 9acdd95e..c729f28b 100644 --- a/ksrpc/src/commonMain/kotlin/channels/JsonRpcChannels.kt +++ b/ksrpc/src/commonMain/kotlin/channels/JsonRpcChannels.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,8 +21,8 @@ import com.monkopedia.ksrpc.internal.jsonrpc.jsonHeader import com.monkopedia.ksrpc.internal.jsonrpc.jsonLine import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel -import kotlinx.coroutines.CoroutineScope import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.CoroutineScope suspend fun Pair.asJsonRpcConnection( env: KsrpcEnvironment, diff --git a/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt b/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt index 0807350e..4d286459 100644 --- a/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt +++ b/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,8 +19,8 @@ import com.monkopedia.ksrpc.KsrpcEnvironment import com.monkopedia.ksrpc.internal.ReadWritePacketChannel import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel -import kotlinx.coroutines.CoroutineScope import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.CoroutineScope internal const val CONTENT_LENGTH = "Content-Length" internal const val CONTENT_TYPE = "Content-Type" diff --git a/ksrpc/src/commonMain/kotlin/internal/HostSerializedServiceImpl.kt b/ksrpc/src/commonMain/kotlin/internal/HostSerializedServiceImpl.kt index 0c27de4b..d4337888 100644 --- a/ksrpc/src/commonMain/kotlin/internal/HostSerializedServiceImpl.kt +++ b/ksrpc/src/commonMain/kotlin/internal/HostSerializedServiceImpl.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -30,9 +30,9 @@ import com.monkopedia.ksrpc.channels.Connection import com.monkopedia.ksrpc.channels.SerializedChannel import com.monkopedia.ksrpc.channels.SerializedService import com.monkopedia.ksrpc.channels.randomUuid +import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.withContext -import kotlin.coroutines.CoroutineContext internal class HostSerializedChannelImpl( override val env: KsrpcEnvironment, @@ -55,7 +55,12 @@ internal class HostSerializedChannelImpl( serviceMap[channelId.id] ?: error("Cannot find service ${channelId.id}") } withContext(context) { - channel.call(endpoint, data) + if (endpoint.isEmpty()) { + close(channelId) + CallData.create("{}") + } else { + channel.call(endpoint, data) + } } } catch (t: Throwable) { env.errorListener.onError(t) diff --git a/ksrpc/src/commonMain/kotlin/internal/HttpSerializedChannel.kt b/ksrpc/src/commonMain/kotlin/internal/HttpSerializedChannel.kt index 96872e41..97a8ca76 100644 --- a/ksrpc/src/commonMain/kotlin/internal/HttpSerializedChannel.kt +++ b/ksrpc/src/commonMain/kotlin/internal/HttpSerializedChannel.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,6 +19,7 @@ import com.monkopedia.ksrpc.ERROR_PREFIX import com.monkopedia.ksrpc.KSRPC_BINARY import com.monkopedia.ksrpc.KSRPC_CHANNEL import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.RpcEndpointException import com.monkopedia.ksrpc.RpcFailure import com.monkopedia.ksrpc.channels.CallData import com.monkopedia.ksrpc.channels.ChannelClient @@ -35,8 +36,8 @@ import io.ktor.http.ContentType import io.ktor.http.HttpStatusCode import io.ktor.http.encodeURLPath import io.ktor.utils.io.ByteReadChannel -import kotlinx.serialization.json.Json import kotlin.coroutines.CoroutineContext +import kotlinx.serialization.json.Json internal class HttpSerializedChannel( private val httpClient: HttpClient, @@ -48,14 +49,14 @@ internal class HttpSerializedChannel( override val context: CoroutineContext = ClientChannelContext(this) + env.coroutineExceptionHandler - override suspend fun call(channelId: ChannelId, endpoint: String, input: CallData): CallData { + override suspend fun call(channelId: ChannelId, endpoint: String, data: CallData): CallData { val response = httpClient.post( "$baseStripped/call/${endpoint.encodeURLPath()}" ) { accept(ContentType.Application.Json) - headers[KSRPC_BINARY] = input.isBinary.toString() + headers[KSRPC_BINARY] = data.isBinary.toString() headers[KSRPC_CHANNEL] = channelId.id - setBody(if (input.isBinary) input.readBinary() else input.readSerialized()) + setBody(if (data.isBinary) data.readBinary() else data.readSerialized()) } response.checkErrors() if (response.headers[KSRPC_BINARY]?.toBoolean() == true) { @@ -92,5 +93,7 @@ internal suspend fun HttpResponse.checkErrors() { } else { throw IllegalStateException("Can't parse error $this") } + } else if (status == HttpStatusCode.NotFound) { + throw RpcEndpointException("Url not found $this") } } diff --git a/ksrpc/src/commonMain/kotlin/internal/MultiChannel.kt b/ksrpc/src/commonMain/kotlin/internal/MultiChannel.kt new file mode 100644 index 00000000..05cb1341 --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/internal/MultiChannel.kt @@ -0,0 +1,84 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal + +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.sync.Mutex + +internal class MultiChannel { + + private var isClosed: Boolean = false + private val lock = Mutex() + private val pending = mutableListOf>>() + private var id = 1 + + private fun checkClosed() { + require(!isClosed) { + "$this has already been closed" + } + } + + suspend fun send(id: String, response: T) { + checkClosed() + lock.lock() + try { + val hasPending = pending.consume(matcher = { it.first == id }) { (_, pendingItem) -> + pendingItem.complete(response) + } + if (!hasPending) { + error("No pending receiver for $id and $response") + } + } finally { + lock.unlock() + } + } + + suspend fun allocateReceive(): Pair> { + checkClosed() + lock.lock() + try { + val id = this.id++ + val completable = CompletableDeferred() + pending.add(id.toString() to completable) + return id to completable + } finally { + lock.unlock() + } + } + + suspend fun close(t: CancellationException? = null) { + lock.lock() + isClosed = true + pending.forEach { + it.second.completeExceptionally(t ?: CancellationException("Closing MultiChannel")) + } + lock.unlock() + } +} + +internal inline fun MutableList.consume( + crossinline matcher: (T) -> Boolean, + crossinline consumer: (T) -> Unit +): Boolean { + return removeAll { + if (matcher(it)) { + consumer(it) + true + } else false + } +} diff --git a/ksrpc/src/commonMain/kotlin/internal/Packet.kt b/ksrpc/src/commonMain/kotlin/internal/Packet.kt index c9574446..74bf9f64 100644 --- a/ksrpc/src/commonMain/kotlin/internal/Packet.kt +++ b/ksrpc/src/commonMain/kotlin/internal/Packet.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/ksrpc/src/commonMain/kotlin/internal/PacketChannelBase.kt b/ksrpc/src/commonMain/kotlin/internal/PacketChannelBase.kt index d86f9f44..f2597503 100644 --- a/ksrpc/src/commonMain/kotlin/internal/PacketChannelBase.kt +++ b/ksrpc/src/commonMain/kotlin/internal/PacketChannelBase.kt @@ -1,3 +1,18 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.monkopedia.ksrpc.internal import com.monkopedia.ksrpc.KsrpcEnvironment @@ -12,18 +27,11 @@ import io.ktor.util.encodeBase64 import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.close import io.ktor.utils.io.core.readBytes +import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.Semaphore -import kotlinx.coroutines.sync.withLock -import kotlinx.coroutines.sync.withPermit import kotlinx.coroutines.withContext -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind.EXACTLY_ONCE -import kotlin.contracts.contract -import kotlin.coroutines.CoroutineContext private const val DEFAULT_MAX_SIZE = 16 * 1024L @@ -46,15 +54,18 @@ internal abstract class PacketChannelBase( private val binaryChannelLock = Mutex() private val binaryChannels = mutableMapOf() - private val receiveChannels = arrayOfNulls(env.maxParallelReceives) - private var receiveLock = Semaphore(env.maxParallelReceives) - private val acquireChannelLock = Mutex() + + private val multiChannel = MultiChannel() init { scope.launch { withContext(context) { executeReceive(this) } + }.also { + onCloseObservers.add { + it.cancel() + } } } @@ -66,7 +77,12 @@ internal abstract class PacketChannelBase( coroutineScope.launch { if (p.binary) { val channel = getBinaryChannel(p.id) - binaryChannelLock.withLock { channel.handlePacket(p) } + binaryChannelLock.lock() + try { + channel.handlePacket(p) + } finally { + binaryChannelLock.unlock() + } removeBinaryChannelIfDone(channel) } else if (p.input) { val callData = getCallData(p) @@ -80,22 +96,14 @@ internal abstract class PacketChannelBase( sendPacket(false, p.id, p.messageId, p.endpoint, response) } else { - val channel = receiveChannels[p.messageId.toInt()] - - if (channel != null) { - channel.channel.send(p) - } else { - env.errorListener.onError( - IllegalStateException( - "Got packet $p for unexpected message id ${p.messageId}" - ) - ) - } + multiChannel.send(p.messageId, p) } } } } catch (t: Throwable) { - receiveChannels.filterNotNull().forEach { it.channel.close(t) } + t.printStackTrace() + binaryChannels.values.forEach { it.channel.close(t) } + multiChannel.close() } } @@ -179,16 +187,22 @@ internal abstract class PacketChannelBase( if (!channel.isDone) { return } - binaryChannelLock.withLock { + binaryChannelLock.lock() + try { binaryChannels.remove(channel.id) + } finally { + binaryChannelLock.unlock() } } private suspend fun getBinaryChannel(id: String): BinaryChannel { - return binaryChannelLock.withLock { - binaryChannels.getOrPut(id) { + binaryChannelLock.lock() + try { + return binaryChannels.getOrPut(id) { BinaryChannel(id) } + } finally { + binaryChannelLock.unlock() } } @@ -204,62 +218,9 @@ internal abstract class PacketChannelBase( } override suspend fun call(channelId: ChannelId, endpoint: String, data: CallData): CallData { - return withChannel { channel -> - val messageId = channel.id.toString() - scope.sendPacket(true, channelId.id, messageId, endpoint, data) - getCallData(channel.channel.receive()) - } - } - - @OptIn(ExperimentalContracts::class) - private suspend inline fun withChannel(withChannel: suspend (ReceiveChannel) -> T): T { - contract { - callsInPlace(withChannel, EXACTLY_ONCE) - } - if (receiveChannels.size == 1) { - return acquireChannelLock.withLock { - val channel = channelFor(0) - withChannel(channel) - } - } - return receiveLock.withPermit { - acquireChannelLock.lock() - val channel = try { - acquireChannel() - } finally { - acquireChannelLock.unlock() - } - try { - withChannel(channel) - } finally { - acquireChannelLock.lock() - try { - releaseChannel(channel) - } finally { - acquireChannelLock.unlock() - } - } - } - } - - private fun channelFor(index: Int) = - receiveChannels[index] ?: ReceiveChannel(index, Channel()).also { - receiveChannels[index] = it - } - - private fun acquireChannel(): ReceiveChannel { - for (i in receiveChannels.indices) { - if (receiveChannels[i]?.isLocked != true) { - return channelFor(i).also { - it.isLocked = true - } - } - } - error("Holding semaphore $receiveLock but no channels available") - } - - private fun releaseChannel(channel: ReceiveChannel) { - channel.isLocked = false + val (messageId, response) = multiChannel.allocateReceive() + scope.sendPacket(true, channelId.id, messageId.toString(), endpoint, data) + return getCallData(response.await()) } override suspend fun close(id: ChannelId) { @@ -279,14 +240,18 @@ internal abstract class PacketChannelBase( } override suspend fun close() { - callLock.withLock { + callLock.lock() + try { if (isClosed) return - receiveChannels.filterNotNull().forEach { + multiChannel.close() + binaryChannels.values.forEach { it.channel.close() } serviceChannel.close() isClosed = true onCloseObservers.forEach { it.invoke() } + } finally { + callLock.unlock() } } @@ -331,10 +296,8 @@ internal abstract class PacketChannelBase( } } - private data class ReceiveChannel( - val id: Int, - val channel: Channel - ) { - var isLocked: Boolean = false - } + private data class PendingPacket( + val receivedAt: Long, + val packet: Packet + ) } diff --git a/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt b/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt index df232b2e..79a57416 100644 --- a/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt +++ b/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,12 +19,12 @@ import com.monkopedia.ksrpc.KsrpcEnvironment import com.monkopedia.ksrpc.channels.CONTENT_LENGTH import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.errors.IOException import io.ktor.utils.io.readFully import io.ktor.utils.io.readUTF8Line import io.ktor.utils.io.writeStringUtf8 import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock import kotlinx.serialization.StringFormat import kotlinx.serialization.decodeFromString import kotlinx.serialization.encodeToString @@ -39,14 +39,20 @@ internal class ReadWritePacketChannel( private val receiveLock = Mutex() override suspend fun send(packet: Packet) { - sendLock.withLock { + sendLock.lock() + try { write.send(packet, env.serialization) + } finally { + sendLock.unlock() } } override suspend fun receive(): Packet { - receiveLock.withLock { + receiveLock.lock() + try { return read.readPacket(env.serialization) + } finally { + receiveLock.unlock() } } @@ -88,12 +94,10 @@ private suspend fun ByteReadChannel.readContent( internal suspend fun ByteReadChannel.readFields(): Map { val fields = mutableListOf() - var line = readUTF8Line() - while (line == null || line.isNotEmpty()) { - if (line != null) { - fields.add(line) - } - line = readUTF8Line() + var line = readUTF8Line() ?: throw IOException("$this is closed for reading") + while (line.isNotEmpty()) { + fields.add(line) + line = readUTF8Line() ?: "" } return parseParams(fields) } diff --git a/ksrpc/src/commonMain/kotlin/internal/WebsocketPacketChannel.kt b/ksrpc/src/commonMain/kotlin/internal/WebsocketPacketChannel.kt index 74cde7e3..2f0e4023 100644 --- a/ksrpc/src/commonMain/kotlin/internal/WebsocketPacketChannel.kt +++ b/ksrpc/src/commonMain/kotlin/internal/WebsocketPacketChannel.kt @@ -16,16 +16,14 @@ package com.monkopedia.ksrpc.internal import com.monkopedia.ksrpc.KsrpcEnvironment +import io.ktor.serialization.kotlinx.KotlinxWebsocketSerializationConverter +import io.ktor.utils.io.charsets.Charsets import io.ktor.websocket.DefaultWebSocketSession -import io.ktor.websocket.Frame -import io.ktor.websocket.Frame.Text import io.ktor.websocket.close -import io.ktor.websocket.readText +import io.ktor.websocket.serialization.receiveDeserializedBase +import io.ktor.websocket.serialization.sendSerializedBase import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock -import kotlinx.serialization.decodeFromString -import kotlinx.serialization.encodeToString internal class WebsocketPacketChannel( scope: CoroutineScope, @@ -34,6 +32,7 @@ internal class WebsocketPacketChannel( ) : PacketChannelBase(scope, env) { private val sendLock = Mutex() private val receiveLock = Mutex() + private val converter = KotlinxWebsocketSerializationConverter(env.serialization) // Use socket max frame with some room for padding. // Divide by 2 to allow for manual base64-ing. @@ -41,17 +40,18 @@ internal class WebsocketPacketChannel( get() = socketSession.maxFrameSize / 2 - 1024 override suspend fun send(packet: Packet) { - sendLock.withLock { - val serialized = env.serialization.encodeToString(packet) - socketSession.send(Text(serialized)) + sendLock.lock() + try { + socketSession.sendSerializedBase(packet, converter, Charsets.UTF_8) + } finally { + sendLock.unlock() } } override suspend fun receive(): Packet { receiveLock.lock() try { - val packetText = socketSession.incoming.receive() - return env.serialization.decodeFromString(packetText.expectText()) + return socketSession.receiveDeserializedBase(converter, Charsets.UTF_8) } finally { receiveLock.unlock() } @@ -62,11 +62,3 @@ internal class WebsocketPacketChannel( socketSession.close() } } - -private fun Frame.expectText(): String { - if (this is Text) { - return readText() - } else { - throw IllegalStateException("Unexpected frame $this") - } -} diff --git a/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcWriterBase.kt b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcWriterBase.kt index ddfc4e9a..7228251b 100644 --- a/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcWriterBase.kt +++ b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcWriterBase.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,9 +20,10 @@ import com.monkopedia.ksrpc.RpcException import com.monkopedia.ksrpc.asString import com.monkopedia.ksrpc.channels.SerializedService import com.monkopedia.ksrpc.channels.SingleChannelConnection +import com.monkopedia.ksrpc.internal.MultiChannel +import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Deferred import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import kotlinx.serialization.json.Json @@ -31,7 +32,6 @@ import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.json.encodeToJsonElement -import kotlin.coroutines.CoroutineContext internal class JsonRpcWriterBase( private val scope: CoroutineScope, @@ -40,10 +40,9 @@ internal class JsonRpcWriterBase( private val comm: JsonRpcTransformer ) : JsonRpcChannel, SingleChannelConnection { private val json = (env.serialization as? Json) ?: Json - private var id = 1 private var baseChannel = CompletableDeferred() - private val completions = mutableMapOf>() + private val multiChannel = MultiChannel() init { scope.launch { @@ -56,8 +55,7 @@ internal class JsonRpcWriterBase( launchRequestHandler(baseChannel.await(), request) } else { val response = json.decodeFromJsonElement(p) - completions.remove(response.id.toString())?.complete(response) - ?: println("Warning, no completion found for $p") + multiChannel.send(response.id.toString(), response) } } } catch (t: Throwable) { @@ -101,27 +99,19 @@ internal class JsonRpcWriterBase( } } - private fun allocateResponse(isNotify: Boolean): Pair?, Int?> { - if (isNotify) return null to null - val id = id++ - return (CompletableDeferred() to id).also { - completions[JsonPrimitive(it.second).toString()] = it.first - } - } - override suspend fun execute( method: String, message: JsonElement?, isNotify: Boolean ): JsonElement? { - val (responseHolder, id) = allocateResponse(isNotify) + val (id, pending) = if (isNotify) null to null else multiChannel.allocateReceive() val request = JsonRpcRequest( method = method, params = message, id = JsonPrimitive(id) ) comm.send(json.encodeToJsonElement(request)) - val response = responseHolder?.await() ?: return null + val response = pending?.await() ?: return null if (response.error != null) { val error = response.error.data?.let { json.decodeFromJsonElement(it) @@ -139,6 +129,11 @@ internal class JsonRpcWriterBase( } catch (t: IllegalStateException) { // Sometimes expected } + try { + multiChannel.close() + } catch (t: Throwable) { + // Thats fine, just pending messages getting unhappy. + } } override suspend fun registerDefault(service: SerializedService) { diff --git a/ksrpc/src/commonTest/kotlin/BinaryTest.kt b/ksrpc/src/commonTest/kotlin/BinaryTest.kt index e75c6e9f..6a87abc5 100644 --- a/ksrpc/src/commonTest/kotlin/BinaryTest.kt +++ b/ksrpc/src/commonTest/kotlin/BinaryTest.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/ksrpc/src/commonTest/kotlin/ConnectionTest.kt b/ksrpc/src/commonTest/kotlin/ConnectionTest.kt index 79a0aee7..8476028f 100644 --- a/ksrpc/src/commonTest/kotlin/ConnectionTest.kt +++ b/ksrpc/src/commonTest/kotlin/ConnectionTest.kt @@ -26,7 +26,6 @@ import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.launch -import kotlin.native.concurrent.SharedImmutable @KsService interface ChildInterface : RpcService { @@ -71,14 +70,18 @@ class ConnectionTest { } ) + private var pendingFinish: CompletableDeferred? = null + @Test fun testReverse() = executePipe( serviceJob = { c -> val service = c.defaultChannel().toStub() service.basicCall("Hello world") + pendingFinish?.complete(Unit) }, clientJob = { c -> val callComplete = CompletableDeferred() + pendingFinish = CompletableDeferred() c.registerDefault(object : PrimaryInterface { override suspend fun basicCall(input: String): String { callComplete.complete(input) @@ -92,6 +95,7 @@ class ConnectionTest { error("Not implemented") }) assertEquals("Hello world", callComplete.await()) + pendingFinish?.await() } ) @@ -144,10 +148,12 @@ class ConnectionTest { error("Not implemented") }) clientService.basicCall("Hello world") + pendingFinish?.complete(Unit) }, clientJob = { c -> val service = c.defaultChannel().toStub() val callComplete = CompletableDeferred() + pendingFinish = CompletableDeferred() c.registerDefault(object : PrimaryInterface { override suspend fun basicCall(input: String): String { return "Client: ${service.basicCall(input)}".also { callComplete.complete(it) } @@ -160,6 +166,7 @@ class ConnectionTest { error("Not implemented") }) assertEquals("Client: Respond: Hello world", callComplete.await()) + pendingFinish?.await() } ) @@ -198,9 +205,11 @@ class ConnectionTest { return "Respond: $input" } }) + pendingFinish?.complete(Unit) }, clientJob = { c -> val callComplete = CompletableDeferred() + pendingFinish = CompletableDeferred() c.registerDefault(object : PrimaryInterface { override suspend fun basicCall(input: String): String = error("Not implemented") @@ -213,6 +222,7 @@ class ConnectionTest { error("Not implemented") }) assertEquals("Respond: Hello world", callComplete.await()) + pendingFinish?.await() } ) @@ -253,9 +263,11 @@ class ConnectionTest { assertEquals("Second service: Hello trees", secondService.rpc("Hello trees")) assertEquals("First service: Hello trees", firstService.rpc("Hello trees")) service.basicCall("Done") + pendingFinish?.complete(Unit) }, clientJob = { c -> val callComplete = CompletableDeferred() + pendingFinish = CompletableDeferred() c.registerDefault(object : PrimaryInterface { override suspend fun basicCall(input: String): String = input.also { callComplete.complete(input) } @@ -271,6 +283,7 @@ class ConnectionTest { } }) assertEquals("Done", callComplete.await()) + pendingFinish?.await() } ) diff --git a/ksrpc/src/commonTest/kotlin/JsonRpcTest.kt b/ksrpc/src/commonTest/kotlin/JsonRpcTest.kt index 97e3377b..6b0e0446 100644 --- a/ksrpc/src/commonTest/kotlin/JsonRpcTest.kt +++ b/ksrpc/src/commonTest/kotlin/JsonRpcTest.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -29,6 +29,9 @@ import com.monkopedia.ksrpc.internal.jsonrpc.jsonLine import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.close import io.ktor.utils.io.readUTF8Line +import kotlin.coroutines.coroutineContext +import kotlin.test.Test +import kotlin.test.assertEquals import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.GlobalScope @@ -40,9 +43,6 @@ import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.encodeToJsonElement -import kotlin.coroutines.coroutineContext -import kotlin.test.Test -import kotlin.test.assertEquals class JsonRpcTest { @@ -214,7 +214,11 @@ class JsonRpcTest { val jsonChannel = JsonRpcWriterBase( CoroutineScope(jsonChannelContext), jsonChannelContext, - ksrpcEnvironment { }, + ksrpcEnvironment { + errorListener = ErrorListener { + it.printStackTrace() + } + }, (inIn to outOut).jsonLine(ksrpcEnvironment { }) ) assertEquals( diff --git a/ksrpc/src/commonTest/kotlin/MyJson.kt b/ksrpc/src/commonTest/kotlin/MyJson.kt index 69dc240f..fd639d1f 100644 --- a/ksrpc/src/commonTest/kotlin/MyJson.kt +++ b/ksrpc/src/commonTest/kotlin/MyJson.kt @@ -1,3 +1,18 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.monkopedia.ksrpc import kotlinx.serialization.Serializable @@ -7,4 +22,4 @@ data class MyJson( val str: String, val int: Int, val nFloat: Float? -) \ No newline at end of file +) diff --git a/ksrpc/src/commonTest/kotlin/RpcFunctionalityTest.kt b/ksrpc/src/commonTest/kotlin/RpcFunctionalityTest.kt index 96f37301..57030b85 100644 --- a/ksrpc/src/commonTest/kotlin/RpcFunctionalityTest.kt +++ b/ksrpc/src/commonTest/kotlin/RpcFunctionalityTest.kt @@ -15,7 +15,6 @@ */ package com.monkopedia.ksrpc -import com.monkopedia.ksrpc.channels.Connection import com.monkopedia.ksrpc.channels.SerializedService import com.monkopedia.ksrpc.channels.asConnection import com.monkopedia.ksrpc.channels.asWebsocketConnection @@ -25,16 +24,15 @@ import io.ktor.client.HttpClient import io.ktor.client.plugins.websocket.WebSockets import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel +import kotlin.test.Test import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.launch -import kotlin.test.Test abstract class RpcFunctionalityTest( private val supportedTypes: List = TestType.values().toList(), - private val serializedChannel: suspend () -> SerializedService, - private val verifyOnChannel: suspend (SerializedService) -> Unit + private val serializedChannel: suspend CoroutineScope.() -> SerializedService, + private val verifyOnChannel: suspend CoroutineScope.(SerializedService) -> Unit ) { enum class TestType { SERIALIZE, @@ -47,7 +45,7 @@ abstract class RpcFunctionalityTest( fun testSerializePassthrough() = runBlockingUnit { if (TestType.SERIALIZE !in supportedTypes) return@runBlockingUnit val serializedChannel = serializedChannel() - val channel = HostSerializedChannelImpl(ksrpcEnvironment { }) + val channel = HostSerializedChannelImpl(createEnv()) channel.registerDefault(serializedChannel) verifyOnChannel(channel.asClient.defaultChannel()) @@ -61,16 +59,12 @@ abstract class RpcFunctionalityTest( launch(Dispatchers.Default) { val serializedChannel = serializedChannel() val connection = (si to output).asConnection( - ksrpcEnvironment { - errorListener = ErrorListener { - it.printStackTrace() - } - } + createEnv() ) connection.registerDefault(serializedChannel) } try { - verifyOnChannel((input to so).asConnection(ksrpcEnvironment { }).defaultChannel()) + verifyOnChannel((input to so).asConnection(createEnv()).defaultChannel()) } finally { try { input.cancel(null) @@ -95,17 +89,13 @@ abstract class RpcFunctionalityTest( val routing = testServe( path, serializedChannel, - ksrpcEnvironment { - errorListener = ErrorListener { - it.printStackTrace() - } - } + createEnv() ) routing() }, test = { val client = HttpClient() - client.asConnection("http://localhost:$it$path", ksrpcEnvironment { }) + client.asConnection("http://localhost:$it$path", createEnv()) .use { channel -> verifyOnChannel(channel.defaultChannel()) } @@ -123,24 +113,22 @@ abstract class RpcFunctionalityTest( testServeWebsocket( path, serializedChannel, - ksrpcEnvironment { - errorListener = ErrorListener { - it.printStackTrace() - } - } + createEnv() ) }, test = { val client = HttpClient { install(WebSockets) } - client.asWebsocketConnection("http://localhost:$it$path", ksrpcEnvironment { }) + client.asWebsocketConnection("http://localhost:$it$path", createEnv()) .use { channel -> verifyOnChannel(channel.defaultChannel()) } } ) } + + protected open fun createEnv() = ksrpcEnvironment { } } internal expect fun runBlockingUnit(function: suspend CoroutineScope.() -> Unit) diff --git a/ksrpc/src/commonTest/kotlin/RpcServiceTest.kt b/ksrpc/src/commonTest/kotlin/RpcServiceTest.kt index 519c99a8..93070d91 100644 --- a/ksrpc/src/commonTest/kotlin/RpcServiceTest.kt +++ b/ksrpc/src/commonTest/kotlin/RpcServiceTest.kt @@ -19,10 +19,13 @@ import com.monkopedia.ksrpc.annotation.KsMethod import com.monkopedia.ksrpc.annotation.KsService import com.monkopedia.ksrpc.channels.CallData import com.monkopedia.ksrpc.channels.SerializedService +import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.launch import kotlinx.serialization.builtins.PairSerializer import kotlinx.serialization.builtins.serializer import kotlinx.serialization.json.Json @@ -184,3 +187,51 @@ class RpcServiceTwoCallsTest : RpcFunctionalityTest( ) } ) + +private var cancelSignal: CompletableDeferred>? = null + +class RpcServiceCancelTest : RpcFunctionalityTest( + serializedChannel = { + val channel: TestInterface = object : TestInterface { + override suspend fun rpc(u: Pair): String { + val completion = CompletableDeferred() + cancelSignal?.complete(completion) + completion.await() + return "${u.first} ${u.second}" + } + } + channel.serialized(ksrpcEnvironment { }) + }, + verifyOnChannel = { serializedChannel -> + val stub = serializedChannel.toStub() + val rpcJob = launch { + try { + stub.rpc("Hello" to "world") + cancelSignal!!.completeExceptionally(RuntimeException("Test failure")) + } finally { + } + } + val continueSignal = cancelSignal!!.await() + rpcJob.cancel() + continueSignal.complete(Unit) + + cancelSignal = CompletableDeferred>().also { + launch { + it.await().complete(Unit) + } + } + + assertEquals( + "Hello world", + stub.rpc("Hello" to "world") + ) + }, + supportedTypes = TestType.values().toList() - TestType.SERIALIZE +) { + @BeforeTest + fun setup() { + cancelSignal = CompletableDeferred() + } + + @Test fun testNothing() = Unit +} diff --git a/ksrpc/src/commonTest/kotlin/RpcSubserviceTest.kt b/ksrpc/src/commonTest/kotlin/RpcSubserviceTest.kt index 6861297f..e08e8c9a 100644 --- a/ksrpc/src/commonTest/kotlin/RpcSubserviceTest.kt +++ b/ksrpc/src/commonTest/kotlin/RpcSubserviceTest.kt @@ -20,6 +20,7 @@ import com.monkopedia.ksrpc.annotation.KsService import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull +import kotlinx.coroutines.CompletableDeferred @KsService interface TestSubInterface : RpcService { @@ -84,3 +85,49 @@ class RpcSubserviceTwoCallsTest : RpcFunctionalityTest( ) } ) + +private var closeCompletion: CompletableDeferred? = null + +class RpcSubserviceCloseTest : RpcFunctionalityTest( + serializedChannel = { + val channel: TestRootInterface = object : TestRootInterface { + override suspend fun rpc(u: Pair): String { + return "${u.first} ${u.second}" + } + + override suspend fun subservice(prefix: String): TestSubInterface { + return object : TestSubInterface { + override suspend fun rpc(u: Pair): String { + return "$prefix ${u.first} ${u.second}" + } + + override suspend fun close() { + closeCompletion?.complete(Unit) + } + } + } + } + channel.serialized(ksrpcEnvironment { }) + }, + verifyOnChannel = { serializedChannel -> + val stub = serializedChannel.toStub() + closeCompletion = CompletableDeferred() + assertEquals( + "oh, Hello world", + stub.subservice("oh,").run { + rpc("Hello" to "world").also { + close() + } + } + ) + closeCompletion?.await() + } +) { + override fun createEnv(): KsrpcEnvironment { + return ksrpcEnvironment { + errorListener = ErrorListener { + closeCompletion?.completeExceptionally(it) + } + } + } +} diff --git a/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt b/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt index 2e997736..3b5b1b82 100644 --- a/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt +++ b/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,10 +18,10 @@ package com.monkopedia.ksrpc import com.monkopedia.ksrpc.annotation.KsMethod import com.monkopedia.ksrpc.annotation.KsService import com.monkopedia.ksrpc.channels.SerializedService +import kotlin.test.assertEquals import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.atomic import kotlinx.coroutines.CompletableDeferred -import kotlin.test.assertEquals @KsService interface TestTypesInterface : RpcService { diff --git a/ksrpc/src/jsMain/kotlin/EpochMillis.kt b/ksrpc/src/jsMain/kotlin/EpochMillis.kt new file mode 100644 index 00000000..0ec8931a --- /dev/null +++ b/ksrpc/src/jsMain/kotlin/EpochMillis.kt @@ -0,0 +1,20 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc + +import kotlin.js.Date + +internal actual fun epochMillis(): Long = Date.now().toLong() diff --git a/ksrpc/src/jvmMain/kotlin/EpochMillis.kt b/ksrpc/src/jvmMain/kotlin/EpochMillis.kt new file mode 100644 index 00000000..7224cf26 --- /dev/null +++ b/ksrpc/src/jvmMain/kotlin/EpochMillis.kt @@ -0,0 +1,18 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc + +internal actual fun epochMillis(): Long = System.currentTimeMillis() diff --git a/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt b/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt index 6fe33e96..d2041e89 100644 --- a/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt +++ b/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt @@ -16,7 +16,6 @@ package com.monkopedia.ksrpc import com.monkopedia.ksrpc.channels.ChannelClient -import com.monkopedia.ksrpc.channels.Connection import com.monkopedia.ksrpc.channels.asConnection import com.monkopedia.ksrpc.channels.asWebsocketConnection import com.monkopedia.ksrpc.channels.registerDefault diff --git a/ksrpc/src/jvmMain/kotlin/ServiceApp.kt b/ksrpc/src/jvmMain/kotlin/ServiceApp.kt index 4b24c826..4f220666 100644 --- a/ksrpc/src/jvmMain/kotlin/ServiceApp.kt +++ b/ksrpc/src/jvmMain/kotlin/ServiceApp.kt @@ -27,8 +27,9 @@ import com.monkopedia.ksrpc.channels.stdInConnection import io.ktor.server.application.install import io.ktor.server.engine.embeddedServer import io.ktor.server.netty.Netty -import io.ktor.server.plugins.cors.routing.* -import io.ktor.server.routing.* +import io.ktor.server.plugins.cors.routing.CORS +import io.ktor.server.routing.Routing +import io.ktor.server.routing.routing import java.net.ServerSocket import kotlin.concurrent.thread import kotlin.system.exitProcess @@ -83,12 +84,13 @@ abstract class ServiceApp(val appName: String) : CliktCommand() { } runBlocking { for (h in http) { - val routes = serve("/${appName.decapitalize()}", createChannel(), env) embeddedServer(Netty, h) { install(CORS) { anyHost() } - routing(routes) + routing { + createRouting() + } }.start() } if (stdOut) { @@ -97,6 +99,10 @@ abstract class ServiceApp(val appName: String) : CliktCommand() { } } + protected open fun Routing.createRouting() { + serve("/${appName.decapitalize()}", createChannel(), env)(this) + } + open val env: KsrpcEnvironment by lazy { ksrpcEnvironment {} } diff --git a/ksrpc/src/jvmMain/kotlin/channels/HttpStream.kt b/ksrpc/src/jvmMain/kotlin/channels/HttpStream.kt index 60dc7060..1dd0eb7d 100644 --- a/ksrpc/src/jvmMain/kotlin/channels/HttpStream.kt +++ b/ksrpc/src/jvmMain/kotlin/channels/HttpStream.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -27,6 +27,7 @@ import com.monkopedia.ksrpc.internal.WebsocketPacketChannel import com.monkopedia.ksrpc.serialized import io.ktor.http.HttpStatusCode import io.ktor.http.decodeURLPart +import io.ktor.server.application.ApplicationCall import io.ktor.server.application.call import io.ktor.server.request.receive import io.ktor.server.response.respond @@ -34,30 +35,34 @@ import io.ktor.server.response.respondBytesWriter import io.ktor.server.routing.Routing import io.ktor.server.routing.post import io.ktor.server.websocket.webSocket +import io.ktor.util.pipeline.PipelineContext import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.copyTo import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch import kotlinx.serialization.json.Json -suspend inline fun Routing.serve( +inline fun Routing.serve( basePath: String, service: T, env: KsrpcEnvironment -) = serve(basePath, service.serialized(env), env) +) = serve(basePath, service.serialized(env), env).invoke(this) -suspend inline fun Routing.serveWebsocket( +inline fun Routing.serveWebsocket( basePath: String, service: T, env: KsrpcEnvironment -) = serveWebsocket(basePath, service.serialized(env), env) +) = serveWebsocket(basePath, service.serialized(env), env) -suspend fun serve( +fun serve( basePath: String, service: SerializedService, env: KsrpcEnvironment ): Routing.() -> Unit { val channel = HostSerializedChannelImpl(env).also { - it.registerDefault(service) + env.defaultScope.launch { + it.registerDefault(service) + } } return { serve(basePath, channel, env) @@ -70,31 +75,52 @@ fun Routing.serve( env: KsrpcEnvironment ) { val baseStripped = basePath.trimEnd('/') + post("$baseStripped/call/") { + runCatching(env) { + execCall(channel, "") + } + } post("$baseStripped/call/{method}") { - try { + runCatching(env) { val method = call.parameters["method"]?.decodeURLPart() ?: error("Missing method") - val content = if (call.request.headers[KSRPC_BINARY]?.toBoolean() == true) { - CallData.create(call.receive()) - } else { - CallData.create(call.receive()) - } - val channelId = call.request.headers[KSRPC_CHANNEL] ?: ChannelClient.DEFAULT - val response = channel.call(ChannelId(channelId), method, content) - if (response.isBinary) { - call.response.headers.append(KSRPC_BINARY, "true") - call.respondBytesWriter { - response.readBinary().copyTo(this) - } - } else { - call.respond(response.readSerialized()) - } - } catch (t: Throwable) { - env.errorListener.onError(t) - call.respond( - ERROR_PREFIX + Json.encodeToString(RpcFailure.serializer(), RpcFailure(t.asString)) - ) - call.response.status(HttpStatusCode.InternalServerError) + execCall(channel, method) + } + } +} + +private suspend fun PipelineContext.execCall( + channel: SerializedChannel, + method: String +) { + val content = if (call.request.headers[KSRPC_BINARY]?.toBoolean() == true) { + CallData.create(call.receive()) + } else { + CallData.create(call.receive()) + } + val channelId = call.request.headers[KSRPC_CHANNEL] ?: ChannelClient.DEFAULT + val response = channel.call(ChannelId(channelId), method, content) + if (response.isBinary) { + call.response.headers.append(KSRPC_BINARY, "true") + call.respondBytesWriter { + response.readBinary().copyTo(this) } + } else { + call.respond(response.readSerialized()) + } +} + +private suspend inline fun PipelineContext.runCatching( + env: KsrpcEnvironment, + exec: suspend () -> Unit +) { + try { + exec() + } catch (t: Throwable) { + env.errorListener.onError(t) + call.respond( + ERROR_PREFIX + Json.encodeToString(RpcFailure.serializer(), RpcFailure(t.asString)) + ) + call.response.status(HttpStatusCode.InternalServerError) } } diff --git a/ksrpc/src/jvmTest/kotlin/OverlappingTest.kt b/ksrpc/src/jvmTest/kotlin/OverlappingTest.kt index 351faef0..ccf594ff 100644 --- a/ksrpc/src/jvmTest/kotlin/OverlappingTest.kt +++ b/ksrpc/src/jvmTest/kotlin/OverlappingTest.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,11 +15,11 @@ */ package com.monkopedia.ksrpc +import kotlin.test.assertEquals import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.coroutineScope -import kotlin.test.assertEquals class OverlappingTest() : OverlappingTestBase() diff --git a/ksrpc/src/jvmTest/kotlin/TestUtilsJvm.kt b/ksrpc/src/jvmTest/kotlin/TestUtilsJvm.kt index b571e384..34011374 100644 --- a/ksrpc/src/jvmTest/kotlin/TestUtilsJvm.kt +++ b/ksrpc/src/jvmTest/kotlin/TestUtilsJvm.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,7 +16,9 @@ package com.monkopedia.ksrpc import com.monkopedia.ksrpc.channels.SerializedService +import com.monkopedia.ksrpc.channels.serve as jvmServe import com.monkopedia.ksrpc.channels.serveWebsocket +import io.ktor.serialization.kotlinx.KotlinxWebsocketSerializationConverter import io.ktor.server.application.install import io.ktor.server.engine.ApplicationEngine import io.ktor.server.engine.embeddedServer @@ -26,6 +28,7 @@ import io.ktor.server.websocket.WebSockets import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel +import java.util.concurrent.CountDownLatch import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.DelicateCoroutinesApi @@ -34,8 +37,7 @@ import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.launch import kotlinx.coroutines.newSingleThreadContext import kotlinx.coroutines.runBlocking -import java.util.concurrent.CountDownLatch -import com.monkopedia.ksrpc.channels.serve as jvmServe +import kotlinx.serialization.json.Json var PORT = 8081 @@ -51,7 +53,9 @@ actual suspend inline fun httpTest( try { serverCompletion.complete( embeddedServer(Netty, port) { - install(WebSockets) + install(WebSockets) { + contentConverter = KotlinxWebsocketSerializationConverter(Json) + } routing { runBlocking { serve() diff --git a/ksrpc/src/nativeMain/kotlin/EpochMillis.kt b/ksrpc/src/nativeMain/kotlin/EpochMillis.kt new file mode 100644 index 00000000..2d3cae6a --- /dev/null +++ b/ksrpc/src/nativeMain/kotlin/EpochMillis.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc + +import kotlinx.cinterop.alloc +import kotlinx.cinterop.memScoped +import kotlinx.cinterop.ptr +import platform.posix.gettimeofday +import platform.posix.timeval + +internal actual fun epochMillis(): Long = memScoped { + val timeVal = alloc() + gettimeofday(timeVal.ptr, null) + (timeVal.tv_sec * 1000) + (timeVal.tv_usec / 1000) +} diff --git a/ksrpc/src/nativeMain/kotlin/KsrpcUriNative.kt b/ksrpc/src/nativeMain/kotlin/KsrpcUriNative.kt index aa63ab82..59c1de4f 100644 --- a/ksrpc/src/nativeMain/kotlin/KsrpcUriNative.kt +++ b/ksrpc/src/nativeMain/kotlin/KsrpcUriNative.kt @@ -1,12 +1,12 @@ /* * Copyright 2021 Jason Monk - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * https://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/ksrpc/src/nativeTest/kotlin/NativeTestUtil.kt b/ksrpc/src/nativeTest/kotlin/NativeTestUtil.kt index 256b863b..70f34d4b 100644 --- a/ksrpc/src/nativeTest/kotlin/NativeTestUtil.kt +++ b/ksrpc/src/nativeTest/kotlin/NativeTestUtil.kt @@ -18,6 +18,8 @@ package com.monkopedia.ksrpc import com.monkopedia.ksrpc.channels.SerializedService import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel +import kotlin.test.Test +import kotlin.test.fail import kotlinx.cinterop.IntVar import kotlinx.cinterop.allocArray import kotlinx.cinterop.get @@ -28,8 +30,6 @@ import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext import platform.posix.pipe import platform.posix.pthread_self -import kotlin.test.Test -import kotlin.test.fail actual suspend inline fun httpTest( crossinline serve: suspend Routing.() -> Unit,