diff --git a/.gitignore b/.gitignore index 1d9fd4f1..6b5ce054 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ build/ local.properties **/build +**/*.swp # Eclipse .classpath diff --git a/gradle.properties b/gradle.properties index 1f39b913..26317815 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,3 +1,3 @@ -version=0.5.0 +version=0.5.1 plugin_version=0.5.0 signing.gnupg.keyName=5B83421E2338B907 diff --git a/ksrpc/src/commonMain/kotlin/RpcMethod.kt b/ksrpc/src/commonMain/kotlin/RpcMethod.kt index f5ed1093..69507963 100644 --- a/ksrpc/src/commonMain/kotlin/RpcMethod.kt +++ b/ksrpc/src/commonMain/kotlin/RpcMethod.kt @@ -28,6 +28,9 @@ import kotlinx.serialization.StringFormat import kotlinx.serialization.builtins.serializer internal sealed interface Transformer { + val hasContent: Boolean + get() = true + suspend fun transform(input: T, channel: SerializedService): CallData suspend fun untransform(data: CallData, channel: SerializedService): T @@ -44,6 +47,9 @@ internal sealed interface Transformer { } internal class SerializerTransformer(private val serializer: KSerializer) : Transformer { + override val hasContent: Boolean + get() = serializer != Unit.serializer() + override suspend fun transform(input: I, channel: SerializedService): CallData { return CallData.create(channel.env.serialization.encodeToString(serializer, input)) } @@ -95,11 +101,15 @@ internal interface ServiceExecutor { * A wrapper around calling into or from stubs/serialization. */ class RpcMethod internal constructor( - private val endpoint: String, + val endpoint: String, private val inputTransform: Transformer, private val outputTransform: Transformer, private val method: ServiceExecutor ) { + + internal val hasReturnType: Boolean + get() = outputTransform.hasContent + internal suspend fun call( channel: SerializedService, service: RpcService, @@ -115,7 +125,7 @@ class RpcMethod internal constructor( internal suspend fun callChannel(channel: SerializedService, input: Any?): Any? { return withContext(channel.context) { val input = inputTransform.transform(input as I, channel) - val transformedOutput = channel.call(endpoint, input) + val transformedOutput = channel.call(this@RpcMethod, input) outputTransform.untransform(transformedOutput, channel) } } diff --git a/ksrpc/src/commonMain/kotlin/channels/Connection.kt b/ksrpc/src/commonMain/kotlin/channels/Connection.kt index 03ccec96..3fddcb08 100644 --- a/ksrpc/src/commonMain/kotlin/channels/Connection.kt +++ b/ksrpc/src/commonMain/kotlin/channels/Connection.kt @@ -18,6 +18,9 @@ package com.monkopedia.ksrpc.channels import com.monkopedia.ksrpc.RpcService import com.monkopedia.ksrpc.serialized import com.monkopedia.ksrpc.toStub +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.jvm.JvmName internal interface SuspendInit { @@ -29,7 +32,7 @@ internal interface SuspendInit { * * (Meaning @KsServices can be used for both input and output of any @KsMethod) */ -interface Connection : ChannelHost, ChannelClient +interface Connection : ChannelHost, ChannelClient, SingleChannelConnection internal interface ConnectionInternal : Connection, @@ -40,6 +43,11 @@ internal interface ConnectionInternal : internal interface ConnectionProvider : ChannelHostProvider, ChannelClientProvider +/** + * A bidirectional channel that can host one service in each direction (1 host and 1 client). + */ +interface SingleChannelConnection : SingleChannelHost, SingleChannelClient + // Problems with JS compiler and serialization data class ChannelId(val id: String) @@ -55,19 +63,29 @@ internal expect interface VoidService : RpcService * This is equivalent to calling [registerDefault] for [T] instance and using * [defaultChannel] and [toStub] to create [R]. */ -suspend inline fun Connection.connect( - crossinline host: (R) -> T -) = connect { channel -> - host(channel.toStub()).serialized(env) +@OptIn(ExperimentalContracts::class) +suspend inline fun SingleChannelConnection.connect( + crossinline host: suspend (R) -> T +) { + contract { + callsInPlace(host, InvocationKind.EXACTLY_ONCE) + } + connect { channel -> + host(channel.toStub()).serialized(env) + } } /** * Raw version of [connect], performing the same functionality with [SerializedService] directly. */ @JvmName("connectSerialized") -suspend fun Connection.connect( - host: (SerializedService) -> SerializedService +@OptIn(ExperimentalContracts::class) +suspend fun SingleChannelConnection.connect( + host: suspend (SerializedService) -> SerializedService ) { + contract { + callsInPlace(host, InvocationKind.EXACTLY_ONCE) + } val recv = defaultChannel() val serializedHost = host(recv) registerDefault(serializedHost) diff --git a/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt b/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt index b73cd551..4efe1c29 100644 --- a/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt +++ b/ksrpc/src/commonMain/kotlin/channels/HttpChannels.kt @@ -43,7 +43,7 @@ suspend fun HttpClient.asWebsocketConnection(baseUrl: String, env: KsrpcEnvironm url.takeFrom(baseUrl.trimEnd('/')) url.protocol = URLProtocol.WS } - return threadSafe { context -> + return threadSafe { context -> WebsocketPacketChannel( CoroutineScope(context), context, @@ -51,6 +51,6 @@ suspend fun HttpClient.asWebsocketConnection(baseUrl: String, env: KsrpcEnvironm env ) }.also { - it.init() + (it as SuspendInit).init() } } diff --git a/ksrpc/src/commonMain/kotlin/channels/JsonRpcChannels.kt b/ksrpc/src/commonMain/kotlin/channels/JsonRpcChannels.kt new file mode 100644 index 00000000..970a1ce5 --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/channels/JsonRpcChannels.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.channels + +import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.internal.ThreadSafeManager.threadSafe +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcWriterBase +import com.monkopedia.ksrpc.internal.jsonrpc.jsonHeader +import com.monkopedia.ksrpc.internal.jsonrpc.jsonLine +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.ByteWriteChannel +import kotlinx.coroutines.CoroutineScope + +suspend fun Pair.asJsonRpcConnection( + env: KsrpcEnvironment, + includeContentHeaders: Boolean = true +): SingleChannelConnection { + return threadSafe { context -> + JsonRpcWriterBase( + CoroutineScope(context), + context, + env, + if (includeContentHeaders) jsonHeader(env) else jsonLine(env) + ) + }.also { + (it as? SuspendInit)?.init() + } +} diff --git a/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt b/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt index 3a92fb22..5c067b2a 100644 --- a/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt +++ b/ksrpc/src/commonMain/kotlin/channels/ReadWriteChannels.kt @@ -23,6 +23,8 @@ import io.ktor.utils.io.ByteWriteChannel import kotlinx.coroutines.CoroutineScope internal const val CONTENT_LENGTH = "Content-Length" +internal const val CONTENT_TYPE = "Content-Type" +internal const val DEFAULT_CONTENT_TYPE = "application/vscode-jsonrpc; charset=utf-8" internal const val METHOD = "Method" internal const val INPUT = "Input" internal const val TYPE = "Type" diff --git a/ksrpc/src/commonMain/kotlin/channels/SerializedChannel.kt b/ksrpc/src/commonMain/kotlin/channels/SerializedChannel.kt index 1690670e..a6ee8d26 100644 --- a/ksrpc/src/commonMain/kotlin/channels/SerializedChannel.kt +++ b/ksrpc/src/commonMain/kotlin/channels/SerializedChannel.kt @@ -16,9 +16,12 @@ package com.monkopedia.ksrpc.channels import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.RpcMethod import com.monkopedia.ksrpc.RpcObject import com.monkopedia.ksrpc.RpcService import com.monkopedia.ksrpc.SuspendCloseableObservable +import com.monkopedia.ksrpc.annotation.KsMethod +import com.monkopedia.ksrpc.channels.ChannelClient.Companion.DEFAULT import com.monkopedia.ksrpc.internal.HostSerializedServiceImpl import com.monkopedia.ksrpc.rpcObject import io.ktor.utils.io.ByteReadChannel @@ -37,7 +40,7 @@ suspend inline fun ChannelHost.registerHost( /** * Register a service to be hosted on the default channel. */ -suspend inline fun ChannelHost.registerDefault( +suspend inline fun SingleChannelHost.registerDefault( service: T ) = registerDefault(service, rpcObject()) @@ -56,7 +59,7 @@ suspend fun ChannelHost.registerHost( /** * Register a service to be hosted on the default channel. */ -suspend fun ChannelHost.registerDefault( +suspend fun SingleChannelHost.registerDefault( service: T, obj: RpcObject ) { @@ -67,15 +70,35 @@ internal interface ChannelHostProvider { val host: ChannelHost? } +/** + * A wrapper around a communication pathway that can be turned into a primary + * SerializedService. + */ +interface SingleChannelHost : KsrpcEnvironment.Element { + /** + * Register the primary service to be hosted on this communication channel. + * + * The coroutine context and dispatcher on which calls are executed in on depends + * on the construction of the host. + */ + suspend fun registerDefault(service: SerializedService) +} + /** * A [SerializedChannel] that can host sub-services. * * This could be a bidirectional conduit like a [Connection], or it could be a hosting only * service such as http hosting. */ -interface ChannelHost : SerializedChannel, KsrpcEnvironment.Element { +interface ChannelHost : SerializedChannel, SingleChannelHost, KsrpcEnvironment.Element { + /** + * Add a serialized service that can receive calls on this channel with the returned + * [ChannelId]. The calls will be allowed until [close] is called. + * + * Generally this shouldn't need to be called directly, as services returned from + * [KsMethod]s are automatically registered and translated across a channel. + */ suspend fun registerHost(service: SerializedService): ChannelId - suspend fun registerDefault(service: SerializedService) } internal interface ChannelHostInternal : ChannelHost, ChannelHostProvider { @@ -87,20 +110,38 @@ internal interface ChannelClientProvider { val client: ChannelClient? } +/** + * A wrapper around a communication pathway that can be turned into a primary + * SerializedService. + */ +interface SingleChannelClient { + + /** + * Get a [SerializedService] that is the default on this client + */ + suspend fun defaultChannel(): SerializedService +} + /** * A [SerializedChannel] that can call into sub-services. * * This could be a bidirectional conduit like a [Connection], or it could be a client only * service such as http client. */ -interface ChannelClient : SerializedChannel, KsrpcEnvironment.Element { +interface ChannelClient : SerializedChannel, SingleChannelClient, KsrpcEnvironment.Element { + /** + * Takes a given channel id and creates a service wrapper to make calls on that channel. + * + * Generally this shouldn't be called directly, as services returned from [KsMethod]s + * will automatically be wrapped before being returned from stubs. + */ suspend fun wrapChannel(channelId: ChannelId): SerializedService /** * Get a [SerializedService] that is the default on this client * (i.e. using [DEFAULT] channel id). This should act as the root service for most scenarios. */ - suspend fun defaultChannel() = wrapChannel(ChannelId(DEFAULT)) + override suspend fun defaultChannel() = wrapChannel(ChannelId(DEFAULT)) companion object { /** @@ -146,6 +187,8 @@ interface SerializedService : ContextContainer, KsrpcEnvironment.Element { suspend fun call(endpoint: String, input: CallData): CallData + suspend fun call(endpoint: RpcMethod<*, *, *>, input: CallData): CallData = + call(endpoint.endpoint, input) } internal expect fun randomUuid(): String diff --git a/ksrpc/src/commonMain/kotlin/channels/SingleServiceChannel.kt b/ksrpc/src/commonMain/kotlin/channels/SingleServiceChannel.kt new file mode 100644 index 00000000..2fb5e4e2 --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/channels/SingleServiceChannel.kt @@ -0,0 +1,16 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.channels diff --git a/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt b/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt index 052e2c21..b0302d89 100644 --- a/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt +++ b/ksrpc/src/commonMain/kotlin/internal/ReadWritePacketChannel.kt @@ -66,7 +66,7 @@ internal class ReadWritePacketChannel( } } -private suspend fun ByteWriteChannel.appendLine(s: String = "") = writeStringUtf8("$s\n") +internal suspend fun ByteWriteChannel.appendLine(s: String = "") = writeStringUtf8("$s\n") @OptIn(InternalAPI::class) private suspend fun ByteWriteChannel.send( @@ -113,7 +113,7 @@ private suspend fun ByteReadChannel.readContent( } } -private suspend fun ByteReadChannel.readFields(): Map { +internal suspend fun ByteReadChannel.readFields(): Map { val fields = mutableListOf() var line = readUTF8Line() while (line == null || line.isNotEmpty()) { diff --git a/ksrpc/src/commonMain/kotlin/internal/ThreadUtils.kt b/ksrpc/src/commonMain/kotlin/internal/ThreadUtils.kt index c4f53fc1..0fbe4040 100644 --- a/ksrpc/src/commonMain/kotlin/internal/ThreadUtils.kt +++ b/ksrpc/src/commonMain/kotlin/internal/ThreadUtils.kt @@ -15,6 +15,7 @@ */ package com.monkopedia.ksrpc.internal +import com.monkopedia.ksrpc.KsrpcEnvironment import com.monkopedia.ksrpc.channels.ChannelClientInternal import com.monkopedia.ksrpc.channels.ChannelClientProvider import com.monkopedia.ksrpc.channels.ChannelHostInternal @@ -37,7 +38,9 @@ internal interface ThreadSafeKeyedHost : ThreadSafeKeyed, ChannelHostInternal internal expect object ThreadSafeManager { inline fun createKey(): Any inline fun T.threadSafe(): T - inline fun threadSafe(creator: (CoroutineContext) -> T): T + inline fun threadSafe( + creator: (CoroutineContext) -> T + ): T inline fun ThreadSafeKeyedConnection.threadSafeProvider(): ConnectionProvider inline fun ThreadSafeKeyedClient.threadSafeProvider(): ChannelClientProvider diff --git a/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcChannel.kt b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcChannel.kt new file mode 100644 index 00000000..bf20c492 --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcChannel.kt @@ -0,0 +1,90 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal.jsonrpc + +import com.monkopedia.ksrpc.KsrpcEnvironment.Element +import com.monkopedia.ksrpc.SuspendCloseable +import kotlinx.serialization.Required +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive + +internal interface JsonRpcChannel : SuspendCloseable, Element { + suspend fun execute(method: String, message: JsonElement?, isNotify: Boolean): JsonElement? +} + +@Serializable +internal data class JsonRpcRequest( + @Required + val jsonrpc: String = "2.0", + val method: String, + val params: JsonElement?, + val id: JsonPrimitive?, +) + +@Serializable +internal data class JsonRpcResponse( + @Required + val jsonrpc: String = "2.0", + val result: JsonElement? = null, + val error: JsonRpcError? = null, + val id: JsonPrimitive? = null, +) + +@Serializable +internal data class JsonRpcError( + val code: Int, + val message: String, + val data: JsonElement? = null +) { + companion object { + /** + * Invalid JSON was received by the server. + * An error occurred on the server while parsing the JSON text. + */ + const val PARSE_ERROR = -32700 + + /** + * The JSON sent is not a valid Request object. + */ + const val INVALID_REQUEST = -32600 + + /** + * The method does not exist / is not available. + */ + const val METHOD_NOT_FOUND = -32601 + + /** + * Invalid method parameter(s). + */ + const val INVALID_PARAMS = -32602 + + /** + * Internal JSON-RPC error. + */ + const val INTERNAL_ERROR = -32603 + + /** + * Reserved for implementation-defined server-errors. + */ + const val MIN_SERVER_ERROR = -32000 + + /** + * Reserved for implementation-defined server-errors. + */ + const val MAX_SERVER_ERROR = -32099 + } +} diff --git a/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcSerializedChannel.kt b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcSerializedChannel.kt new file mode 100644 index 00000000..1bc4791f --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcSerializedChannel.kt @@ -0,0 +1,68 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal.jsonrpc + +import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.RpcMethod +import com.monkopedia.ksrpc.channels.CallData +import com.monkopedia.ksrpc.channels.SerializedService +import com.monkopedia.ksrpc.channels.SuspendInit +import kotlin.coroutines.CoroutineContext +import kotlinx.serialization.decodeFromString +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement + +internal class JsonRpcSerializedChannel( + override val context: CoroutineContext, + private val channel: JsonRpcChannel, + override val env: KsrpcEnvironment +) : SerializedService, SuspendInit { + private val onCloseCallbacks = mutableSetOf Unit>() + private val json = (env.serialization as? Json) ?: Json + + override suspend fun call(endpoint: RpcMethod<*, *, *>, input: CallData): CallData { + return call(endpoint.endpoint, input, !endpoint.hasReturnType) + } + + override suspend fun call(endpoint: String, input: CallData): CallData { + return call(endpoint, input, false) + } + + private suspend fun call(endpoint: String, input: CallData, isNotify: Boolean): CallData { + require(!input.isBinary) { + "JsonRpc does not support binary data" + } + val message = json.decodeFromString(input.readSerialized()) + val response = channel.execute(endpoint, message, isNotify) + + return CallData.create( + if (isNotify) json.encodeToString(Unit) + else json.encodeToString(response) + ) + } + + override suspend fun close() { + channel.close() + onCloseCallbacks.forEach { + it.invoke() + } + } + + override suspend fun onClose(onClose: suspend () -> Unit) { + onCloseCallbacks.add(onClose) + } +} diff --git a/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcServiceWrapper.kt b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcServiceWrapper.kt new file mode 100644 index 00000000..bff7450d --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcServiceWrapper.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal.jsonrpc + +import com.monkopedia.ksrpc.KsrpcEnvironment.Element +import com.monkopedia.ksrpc.channels.CallData +import com.monkopedia.ksrpc.channels.SerializedService +import kotlinx.serialization.decodeFromString +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement + +internal class JsonRpcServiceWrapper( + private val channel: SerializedService, +) : JsonRpcChannel, Element by channel { + private val json = (channel.env.serialization as? Json) ?: Json + override suspend fun execute( + method: String, + message: JsonElement?, + isNotify: Boolean + ): JsonElement? { + val response = channel.call(method, CallData.create(json.encodeToString(message))) + require(!response.isBinary) { + "JsonRpc does not support binary data" + } + return json.decodeFromString(response.readSerialized()) + } + + override suspend fun close() { + channel.close() + } +} diff --git a/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcTransformer.kt b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcTransformer.kt new file mode 100644 index 00000000..9417753f --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcTransformer.kt @@ -0,0 +1,123 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal.jsonrpc + +import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.channels.CONTENT_LENGTH +import com.monkopedia.ksrpc.channels.CONTENT_TYPE +import com.monkopedia.ksrpc.channels.DEFAULT_CONTENT_TYPE +import com.monkopedia.ksrpc.internal.appendLine +import com.monkopedia.ksrpc.internal.readFields +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.readFully +import io.ktor.utils.io.readUTF8Line +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement + +internal abstract class JsonRpcTransformer { + abstract val isOpen: Boolean + + abstract suspend fun send(message: JsonElement) + abstract suspend fun receive(): JsonElement? + abstract fun close(cause: Throwable?) +} + +internal fun Pair.jsonHeader( + env: KsrpcEnvironment +): JsonRpcTransformer = JsonRpcHeader(env, first, second) + +internal class JsonRpcHeader( + env: KsrpcEnvironment, + private val input: ByteReadChannel, + private val output: ByteWriteChannel, +) : JsonRpcTransformer() { + private val json = (env.serialization as? Json) ?: Json + private val sendLock = Mutex() + private val receiveLock = Mutex() + private val serializer = JsonElement.serializer() + + override val isOpen: Boolean + get() = !input.isClosedForRead + + override suspend fun send(message: JsonElement) { + sendLock.withLock { + val content = json.encodeToString(serializer, message) + val contentBytes = content.encodeToByteArray() + output.appendLine("$CONTENT_LENGTH: ${contentBytes.size}") + output.appendLine("$CONTENT_TYPE: $DEFAULT_CONTENT_TYPE") + output.appendLine() + output.writeFully(contentBytes, 0, contentBytes.size) + output.flush() + } + } + + override suspend fun receive(): JsonElement? { + receiveLock.withLock { + val params = input.readFields() + val length = params[CONTENT_LENGTH]?.toIntOrNull() ?: return null + var byteArray = ByteArray(length) + input.readFully(byteArray) + return json.decodeFromString(serializer, byteArray.decodeToString()) + } + } + + override fun close(cause: Throwable?) { + output.close(cause) + } +} + +internal fun Pair.jsonLine( + env: KsrpcEnvironment +): JsonRpcTransformer = JsonRpcLine(env, first, second) + +internal class JsonRpcLine( + env: KsrpcEnvironment, + private val input: ByteReadChannel, + private val output: ByteWriteChannel, +) : JsonRpcTransformer() { + private val json = (env.serialization as? Json) ?: Json + private val sendLock = Mutex() + private val receiveLock = Mutex() + private val serializer = JsonElement.serializer() + + override val isOpen: Boolean + get() = !input.isClosedForRead + + override suspend fun send(message: JsonElement) { + sendLock.withLock { + val content = json.encodeToString(serializer, message) + require('\n' !in content) { + "Cannot have new-lines in encoding check environment json config" + } + output.appendLine(content) + output.flush() + } + } + + override suspend fun receive(): JsonElement? { + receiveLock.withLock { + val line = input.readUTF8Line() ?: return null + return json.decodeFromString(serializer, line) + } + } + + override fun close(cause: Throwable?) { + output.close(cause) + } +} diff --git a/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcWriterBase.kt b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcWriterBase.kt new file mode 100644 index 00000000..e0749633 --- /dev/null +++ b/ksrpc/src/commonMain/kotlin/internal/jsonrpc/JsonRpcWriterBase.kt @@ -0,0 +1,153 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal.jsonrpc + +import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.RpcException +import com.monkopedia.ksrpc.asString +import com.monkopedia.ksrpc.channels.SerializedService +import com.monkopedia.ksrpc.channels.SingleChannelConnection +import com.monkopedia.ksrpc.channels.SuspendInit +import io.ktor.utils.io.core.internal.DangerousInternalIoApi +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement + +@OptIn(DangerousInternalIoApi::class) +internal class JsonRpcWriterBase( + private val scope: CoroutineScope, + private val context: CoroutineContext, + override val env: KsrpcEnvironment, + private val comm: JsonRpcTransformer, +) : JsonRpcChannel, SingleChannelConnection, SuspendInit { + private val json = (env.serialization as? Json) ?: Json + private var id = 1 + + private var baseChannel = CompletableDeferred() + private val completions = mutableMapOf>() + + override suspend fun init() { + scope.launch { + withContext(context) { + try { + while (comm.isOpen) { + val p = comm.receive() ?: continue + if ((p as? JsonObject)?.containsKey("method") == true) { + val request = json.decodeFromJsonElement(p) + launchRequestHandler(baseChannel.await(), request) + } else { + val response = json.decodeFromJsonElement(p) + completions.remove(response.id.toString())?.complete(response) + ?: println("Warning, no completion found for $p") + } + } + } catch (t: Throwable) { + try { + close() + } catch (t: Throwable) { + } + } + } + } + } + + private fun launchRequestHandler(channel: JsonRpcChannel, message: JsonRpcRequest) { + scope.launch(context) { + try { + val response = + channel.execute(message.method, message.params, message.id == null) + if (message.id == null) return@launch + comm.send( + json.encodeToJsonElement( + JsonRpcResponse( + result = response, + id = message.id + ) + ) + ) + } catch (t: Throwable) { + env.errorListener.onError(t) + if (message.id != null) { + comm.send( + json.encodeToJsonElement( + JsonRpcResponse( + error = JsonRpcError(JsonRpcError.INTERNAL_ERROR, t.asString), + id = message.id + ) + ) + ) + } + } + } + } + + private fun allocateResponse(isNotify: Boolean): Pair?, Int?> { + if (isNotify) return null to null + val id = id++ + return (CompletableDeferred() to id).also { + completions[JsonPrimitive(it.second).toString()] = it.first + } + } + + override suspend fun execute( + method: String, + message: JsonElement?, + isNotify: Boolean + ): JsonElement? { + val (responseHolder, id) = allocateResponse(isNotify) + val request = JsonRpcRequest( + method = method, + params = message, + id = JsonPrimitive(id) + ) + comm.send(json.encodeToJsonElement(request)) + val response = responseHolder?.await() ?: return null + if (response.error != null) { + val error = response.error.data?.let { + json.decodeFromJsonElement(it) + } ?: IllegalStateException( + "JsonRpcError(${response.error.code}): ${response.error.message}" + ) + throw error + } + return response.result + } + + override suspend fun close() { + try { + comm.close(IllegalStateException("JsonRpcWriter is shutting down")) + } catch (t: IllegalStateException) { + // Sometimes expected + } + } + + override suspend fun registerDefault(service: SerializedService) { + baseChannel.complete(JsonRpcServiceWrapper(service)) + } + + override suspend fun defaultChannel(): SerializedService { + return JsonRpcSerializedChannel(context, this, env) + } +} diff --git a/ksrpc/src/commonTest/kotlin/JsonRpcTest.kt b/ksrpc/src/commonTest/kotlin/JsonRpcTest.kt new file mode 100644 index 00000000..5970f9b3 --- /dev/null +++ b/ksrpc/src/commonTest/kotlin/JsonRpcTest.kt @@ -0,0 +1,461 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc + +import com.monkopedia.ksrpc.channels.DEFAULT_CONTENT_TYPE +import com.monkopedia.ksrpc.channels.SerializedService +import com.monkopedia.ksrpc.channels.SuspendInit +import com.monkopedia.ksrpc.channels.asJsonRpcConnection +import com.monkopedia.ksrpc.internal.ThreadSafeManager.threadSafe +import com.monkopedia.ksrpc.internal.appendLine +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcChannel +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcRequest +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcResponse +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcSerializedChannel +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcServiceWrapper +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcWriterBase +import com.monkopedia.ksrpc.internal.jsonrpc.jsonHeader +import com.monkopedia.ksrpc.internal.jsonrpc.jsonLine +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.close +import io.ktor.utils.io.readUTF8Line +import kotlin.coroutines.coroutineContext +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.launch +import kotlinx.serialization.decodeFromString +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.encodeToJsonElement + +class JsonRpcTest { + + @Test + fun testLineSender_send() = runBlockingUnit { + val expectedMessage = """{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + val sender = (inputChannel to outputChannel).jsonLine(ksrpcEnvironment { }) + sender.send( + Json.encodeToJsonElement( + JsonRpcRequest( + method = "subtract", + params = Json.decodeFromString("[42,23]"), + id = JsonPrimitive(1) + ) + ) + ) + assertEquals(expectedMessage, outputChannel.readUTF8Line()) + } + + @Test + fun testLineReceiver_send() = runBlockingUnit { + val expectedResponse = """{"jsonrpc":"2.0","result":-19,"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + val sender = (inputChannel to outputChannel).jsonLine(ksrpcEnvironment { }) + sender.send( + Json.encodeToJsonElement( + JsonRpcResponse(result = Json.decodeFromString("-19"), id = JsonPrimitive(1)) + ) + ) + assertEquals(expectedResponse, outputChannel.readUTF8Line()) + } + + @Test + fun testLineReceiver_receive() = runBlockingUnit { + val expectedMessage = """{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + inputChannel.appendLine(expectedMessage) + inputChannel.flush() + val sender = (inputChannel to outputChannel).jsonLine(ksrpcEnvironment { }) + val expectedRequest = + Json.encodeToJsonElement( + JsonRpcRequest( + method = "subtract", + params = Json.decodeFromString("[42,23]"), + id = JsonPrimitive(1) + ) + ) + assertEquals(expectedRequest, sender.receive()) + } + + @Test + fun testLineSender_receive() = runBlockingUnit { + val expectedResponse = """{"jsonrpc":"2.0","result":-19,"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + inputChannel.appendLine(expectedResponse) + inputChannel.flush() + val sender = (inputChannel to outputChannel).jsonLine(ksrpcEnvironment { }) + val expectedRequest = + Json.encodeToJsonElement( + JsonRpcResponse(result = Json.decodeFromString("-19"), id = JsonPrimitive(1)) + ) + assertEquals(expectedRequest, sender.receive()) + } + + @Test + fun testHeaderSender_send() = runBlockingUnit { + val expectedMessage = """{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + val sender = (inputChannel to outputChannel).jsonHeader(ksrpcEnvironment { }) + sender.send( + Json.encodeToJsonElement( + JsonRpcRequest( + method = "subtract", + params = Json.decodeFromString("[42,23]"), + id = JsonPrimitive(1) + ) + ) + ) + assertEquals("Content-Length: 61", outputChannel.readUTF8Line()) + assertEquals("Content-Type: $DEFAULT_CONTENT_TYPE", outputChannel.readUTF8Line()) + assertEquals("", outputChannel.readUTF8Line()) + assertEquals( + expectedMessage, + ByteArray(61).apply { outputChannel.readFully(this, 0, 61) }.decodeToString() + ) + } + + @Test + fun testHeaderReceiver_send() = runBlockingUnit { + val expectedResponse = """{"jsonrpc":"2.0","result":-19,"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + val sender = (inputChannel to outputChannel).jsonHeader(ksrpcEnvironment { }) + sender.send( + Json.encodeToJsonElement( + JsonRpcResponse(result = Json.decodeFromString("-19"), id = JsonPrimitive(1)) + ) + ) + assertEquals("Content-Length: 37", outputChannel.readUTF8Line()) + assertEquals("Content-Type: $DEFAULT_CONTENT_TYPE", outputChannel.readUTF8Line()) + assertEquals("", outputChannel.readUTF8Line()) + assertEquals( + expectedResponse, + ByteArray(37).apply { outputChannel.readFully(this, 0, 37) }.decodeToString() + ) + } + + @Test + fun testHeaderReceiver_receive() = runBlockingUnit { + val expectedMessage = """{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + inputChannel.appendLine("Content-Length: 61") + inputChannel.appendLine("Content-Type: $DEFAULT_CONTENT_TYPE") + inputChannel.appendLine() + inputChannel.appendLine(expectedMessage) + inputChannel.flush() + val sender = (inputChannel to outputChannel).jsonHeader(ksrpcEnvironment { }) + val expectedRequest = + Json.encodeToJsonElement( + JsonRpcRequest( + method = "subtract", + params = Json.decodeFromString("[42,23]"), + id = JsonPrimitive(1) + ) + ) + assertEquals(expectedRequest, sender.receive()) + } + + @Test + fun testHeaderSender_receive() = runBlockingUnit { + val expectedResponse = """{"jsonrpc":"2.0","result":-19,"id":1}""" + val outputChannel = ByteChannel() + val inputChannel = ByteChannel() + inputChannel.appendLine("Content-Length: 37") + inputChannel.appendLine("Content-Type: $DEFAULT_CONTENT_TYPE") + inputChannel.appendLine() + inputChannel.appendLine(expectedResponse) + inputChannel.flush() + val sender = (inputChannel to outputChannel).jsonHeader(ksrpcEnvironment { }) + val expectedRequest = + Json.encodeToJsonElement( + JsonRpcResponse(result = Json.decodeFromString("-19"), id = JsonPrimitive(1)) + ) + assertEquals(expectedRequest, sender.receive()) + } + + @Test + fun testSerializedChannel() = runBlockingUnit { + val params = "[42, 23]" + val fakeResponse = """{"jsonrpc": "2.0", "result": -19, "id": 1}""" + val (outOut, outIn) = createPipe() + val (inOut, inIn) = createPipe() + GlobalScope.launch(Dispatchers.Default) { + outIn.readUTF8Line() + inOut.appendLine(fakeResponse) + inOut.flush() + } + val expectedMessage = """{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}""" + val expectedResponse = """-19""" + val jsonChannel = threadSafe { context -> + JsonRpcWriterBase( + CoroutineScope(context), + context, + ksrpcEnvironment { }, + (inIn to outOut).jsonLine(ksrpcEnvironment { }) + ) + }.also { + (it as? SuspendInit)?.init() + } + assertEquals( + expectedResponse, + Json.encodeToString( + jsonChannel.execute( + "subtract", + Json.decodeFromString(params), + false + ) + ) + ) + jsonChannel.close() + try { + inOut.close() + } catch (t: Throwable) { + // Don't care + } + try { + outOut.close() + } catch (t: Throwable) { + // Don't care + } + } +} + +abstract class JsonRpcFunctionalityTest( + private val serializedChannel: suspend () -> SerializedService, + private val verifyOnChannel: suspend (SerializedService) -> Unit +) { + + @Test + fun testJsonRpcChannel() = runBlockingUnit { + val serializedChannel = serializedChannel() + val channel = JsonRpcServiceWrapper(serializedChannel) + + verifyOnChannel( + JsonRpcSerializedChannel( + coroutineContext, + channel, + ksrpcEnvironment { } + ) + ) + } + + @Test + fun testPipePassthrough() = runBlockingUnit { + val (output, input) = createPipe() + val (so, si) = createPipe() + GlobalScope.launch(Dispatchers.Default) { + val serializedChannel = serializedChannel() + val connection = (input to so).asJsonRpcConnection( + ksrpcEnvironment { + errorListener = ErrorListener { + it.printStackTrace() + } + } + ) + connection.registerDefault(serializedChannel) + } + try { + val channel = (si to output).asJsonRpcConnection( + ksrpcEnvironment { + errorListener = ErrorListener { + it.printStackTrace() + } + }, + true + ) + verifyOnChannel(channel.defaultChannel()) + } finally { + try { + input.cancel(null) + } catch (t: Throwable) { + } + try { + si.cancel(null) + } catch (t: Throwable) { + } + output.close(null) + so.close(null) + } + } + + @Test + fun testPipePassthroughWithLines() = runBlockingUnit { + val (output, input) = createPipe() + val (so, si) = createPipe() + GlobalScope.launch(Dispatchers.Default) { + val serializedChannel = serializedChannel() + val connection = (input to so).asJsonRpcConnection( + ksrpcEnvironment { + errorListener = ErrorListener { + it.printStackTrace() + } + }, + false + ) + connection.registerDefault(serializedChannel) + } + try { + val channel = (si to output).asJsonRpcConnection( + ksrpcEnvironment { + errorListener = ErrorListener { + it.printStackTrace() + } + }, + false + ) + verifyOnChannel(channel.defaultChannel()) + } finally { + try { + input.cancel(null) + } catch (t: Throwable) { + } + try { + si.cancel(null) + } catch (t: Throwable) { + } + output.close(null) + so.close(null) + } + } +} + +object JsonRpcTypeTest { + + abstract class JsonRpcTypeFunctionalityTest( + verifyOnChannel: suspend (SerializedService, FakeTestTypes) -> Unit, + private val service: FakeTestTypes = FakeTestTypes() + ) : JsonRpcFunctionalityTest( + serializedChannel = { service.serialized(ksrpcEnvironment { }) }, + verifyOnChannel = { channel -> + verifyOnChannel(channel, service) + } + ) + + class PairStrTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + service.nextReturn.value = "" + stub.rpc("Hello" to "world") + assertEquals("rpc", service.lastCall.value) + assertEquals("Hello" to "world", service.lastInput.value) + } + ) + + class MapTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + val completion = CompletableDeferred() + service.callComplete.value = completion + service.nextReturn.value = Unit + stub.mapRpc( + mutableMapOf( + "First" to MyJson("first", 1, null), + "Second" to MyJson("second", 2, 1.2f), + ) + ) + completion.await() + assertEquals("mapRpc", service.lastCall.value) + assertEquals( + mutableMapOf( + "First" to MyJson("first", 1, null), + "Second" to MyJson("second", 2, 1.2f), + ), + service.lastInput.value + ) + } + ) + + class InputIntTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + val completion = CompletableDeferred() + service.callComplete.value = completion + service.nextReturn.value = Unit + stub.inputInt(42) + completion.await() + assertEquals("inputInt", service.lastCall.value) + assertEquals(42, service.lastInput.value) + } + ) + + class InputIntListTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + val completion = CompletableDeferred() + service.callComplete.value = completion + service.nextReturn.value = Unit + stub.inputIntList(listOf(42)) + completion.await() + assertEquals("inputIntList", service.lastCall.value) + assertEquals(listOf(42), service.lastInput.value) + } + ) + + class InputIntNullableTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + val completion = CompletableDeferred() + service.callComplete.value = completion + service.nextReturn.value = Unit + stub.inputIntNullable(null) + completion.await() + assertEquals("inputIntNullable", service.lastCall.value) + assertEquals(null, service.lastInput.value) + } + ) + + class OutputIntTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + service.nextReturn.value = 42 + assertEquals(42, stub.outputInt(Unit)) + assertEquals("outputInt", service.lastCall.value) + assertEquals(Unit, service.lastInput.value) + } + ) + + class OutputIntNullableTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + service.nextReturn.value = null + assertEquals(null, stub.outputIntNullable(Unit)) + assertEquals("outputIntNullable", service.lastCall.value) + assertEquals(Unit, service.lastInput.value) + } + ) + + class ReturnTypeTest : JsonRpcTypeFunctionalityTest( + verifyOnChannel = { channel, service -> + val stub = channel.toStub() + service.nextReturn.value = MyJson("second", 2, 1.2f) + assertEquals(MyJson("second", 2, 1.2f), stub.returnType(Unit)) + assertEquals("returnType", service.lastCall.value) + assertEquals(Unit, service.lastInput.value) + } + ) +} diff --git a/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt b/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt index 33ad931d..22f1eae3 100644 --- a/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt +++ b/ksrpc/src/commonTest/kotlin/RpcTypeTest.kt @@ -21,6 +21,7 @@ import com.monkopedia.ksrpc.channels.SerializedService import kotlin.test.assertEquals import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.atomic +import kotlinx.coroutines.CompletableDeferred import kotlinx.serialization.Serializable @Serializable @@ -61,52 +62,61 @@ class FakeTestTypes : TestTypesInterface { var lastInput: AtomicRef = atomic(null) var nextReturn: AtomicRef = atomic(null) var lastCall: AtomicRef = atomic(null) + var callComplete: AtomicRef?> = atomic(null) override suspend fun rpc(u: Pair): String { lastInput.value = u lastCall.value = "rpc" + callComplete?.value?.complete(Unit) return nextReturn.value as String } override suspend fun mapRpc(u: Map) { lastInput.value = u lastCall.value = "mapRpc" + callComplete?.value?.complete(Unit) return nextReturn.value as Unit } override suspend fun returnType(u: Unit): MyJson { lastInput.value = u lastCall.value = "returnType" + callComplete?.value?.complete(Unit) return nextReturn.value as MyJson } override suspend fun inputInt(i: Int) { lastInput.value = i lastCall.value = "inputInt" + callComplete?.value?.complete(Unit) return nextReturn.value as Unit } override suspend fun inputIntList(i: List) { lastInput.value = i lastCall.value = "inputIntList" + callComplete?.value?.complete(Unit) return nextReturn.value as Unit } override suspend fun outputInt(u: Unit): Int { lastInput.value = u lastCall.value = "outputInt" + callComplete?.value?.complete(Unit) return nextReturn.value as Int } override suspend fun inputIntNullable(i: Int?) { lastInput.value = i lastCall.value = "inputIntNullable" + callComplete?.value?.complete(Unit) return nextReturn.value as Unit } override suspend fun outputIntNullable(u: Unit): Int? { lastInput.value = u lastCall.value = "outputIntNullable" + callComplete?.value?.complete(Unit) return nextReturn.value as Int? } } @@ -168,7 +178,7 @@ object RpcTypeTest { verifyOnChannel = { channel, service -> val stub = channel.toStub() service.nextReturn.value = Unit - service.inputIntList(listOf(42)) + stub.inputIntList(listOf(42)) assertEquals("inputIntList", service.lastCall.value) assertEquals(listOf(42), service.lastInput.value) } @@ -178,7 +188,7 @@ object RpcTypeTest { verifyOnChannel = { channel, service -> val stub = channel.toStub() service.nextReturn.value = Unit - service.inputIntNullable(null) + stub.inputIntNullable(null) assertEquals("inputIntNullable", service.lastCall.value) assertEquals(null, service.lastInput.value) } diff --git a/ksrpc/src/jsMain/kotlin/internal/ThreadSafe.kt b/ksrpc/src/jsMain/kotlin/internal/ThreadSafe.kt index 62750bfd..e4a957e6 100644 --- a/ksrpc/src/jsMain/kotlin/internal/ThreadSafe.kt +++ b/ksrpc/src/jsMain/kotlin/internal/ThreadSafe.kt @@ -15,6 +15,7 @@ */ package com.monkopedia.ksrpc.internal +import com.monkopedia.ksrpc.KsrpcEnvironment import com.monkopedia.ksrpc.channels.ChannelClientProvider import com.monkopedia.ksrpc.channels.ChannelHostProvider import com.monkopedia.ksrpc.channels.ConnectionProvider @@ -23,7 +24,9 @@ import kotlin.coroutines.EmptyCoroutineContext internal actual object ThreadSafeManager { actual inline fun T.threadSafe(): T = this - actual inline fun threadSafe(creator: (CoroutineContext) -> T): T = + actual inline fun threadSafe( + creator: (CoroutineContext) -> T + ): T = creator(EmptyCoroutineContext) actual inline fun createKey(): Any = Unit diff --git a/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt b/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt index e65e3c84..3a2abd60 100644 --- a/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt +++ b/ksrpc/src/jvmMain/kotlin/KsrpcUriJvm.kt @@ -59,10 +59,6 @@ actual suspend fun KsrpcUri.connect( val client = clientFactory() client.asWebsocketConnection(path, env) } - KsrpcType.WEBSOCKET -> { - val client = clientFactory() - client.asWebsocketConnection(path, env) - } } } diff --git a/ksrpc/src/jvmMain/kotlin/channels/InputOutputStreams.kt b/ksrpc/src/jvmMain/kotlin/channels/InputOutputStreams.kt index f1907131..f40e0553 100644 --- a/ksrpc/src/jvmMain/kotlin/channels/InputOutputStreams.kt +++ b/ksrpc/src/jvmMain/kotlin/channels/InputOutputStreams.kt @@ -25,6 +25,9 @@ import kotlin.concurrent.thread import kotlin.coroutines.coroutineContext import kotlinx.coroutines.Job +/** + * Helper that calls into Pair.asConnection. + */ suspend fun Pair.asConnection(env: KsrpcEnvironment): Connection { val (input, output) = this val channel = ByteChannel(autoFlush = true) @@ -34,3 +37,18 @@ suspend fun Pair.asConnection(env: KsrpcEnvironment): } return (input.toByteReadChannel(coroutineContext) to channel).asConnection(env) } + +/** + * Helper that calls into Pair.asJsonRpcConnection. + */ +suspend fun Pair.asJsonRpcConnection( + env: KsrpcEnvironment +): SingleChannelConnection { + val (input, output) = this + val channel = ByteChannel(autoFlush = true) + val job = coroutineContext[Job] + thread(start = true) { + channel.toInputStream(job).copyTo(output) + } + return (input.toByteReadChannel(coroutineContext) to channel).asJsonRpcConnection(env) +} diff --git a/ksrpc/src/jvmMain/kotlin/channels/ProcessStream.kt b/ksrpc/src/jvmMain/kotlin/channels/ProcessStream.kt index f6b537b2..a4b711ef 100644 --- a/ksrpc/src/jvmMain/kotlin/channels/ProcessStream.kt +++ b/ksrpc/src/jvmMain/kotlin/channels/ProcessStream.kt @@ -38,3 +38,27 @@ suspend fun ProcessBuilder.asConnection(env: KsrpcEnvironment): Connection { val output = process.outputStream return (input to output).asConnection(env) } + +/** + * Create a [SingleChannelConnection] that communicates over the std in/out streams of this process + * using jsonrpc. + */ +suspend fun stdInJsonRpcConnection(env: KsrpcEnvironment): SingleChannelConnection { + val input = System.`in` + val output = System.out + return (input to output).asJsonRpcConnection(env) +} + +/** + * Create a [SingleChannelConnection] that starts the process and uses the + * [Process.getInputStream] and [Process.getOutputStream] as the streams for communication using + * jsonrpc. + */ +suspend fun ProcessBuilder.asJsonRpcConnection(env: KsrpcEnvironment): SingleChannelConnection { + val process = redirectInput(ProcessBuilder.Redirect.PIPE) + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .start() + val input = process.inputStream + val output = process.outputStream + return (input to output).asJsonRpcConnection(env) +} diff --git a/ksrpc/src/jvmMain/kotlin/internal/ThreadSafe.kt b/ksrpc/src/jvmMain/kotlin/internal/ThreadSafe.kt index 62750bfd..e4a957e6 100644 --- a/ksrpc/src/jvmMain/kotlin/internal/ThreadSafe.kt +++ b/ksrpc/src/jvmMain/kotlin/internal/ThreadSafe.kt @@ -15,6 +15,7 @@ */ package com.monkopedia.ksrpc.internal +import com.monkopedia.ksrpc.KsrpcEnvironment import com.monkopedia.ksrpc.channels.ChannelClientProvider import com.monkopedia.ksrpc.channels.ChannelHostProvider import com.monkopedia.ksrpc.channels.ConnectionProvider @@ -23,7 +24,9 @@ import kotlin.coroutines.EmptyCoroutineContext internal actual object ThreadSafeManager { actual inline fun T.threadSafe(): T = this - actual inline fun threadSafe(creator: (CoroutineContext) -> T): T = + actual inline fun threadSafe( + creator: (CoroutineContext) -> T + ): T = creator(EmptyCoroutineContext) actual inline fun createKey(): Any = Unit diff --git a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafe.kt b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafe.kt index 9ecd10f0..f885e980 100644 --- a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafe.kt +++ b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafe.kt @@ -15,22 +15,31 @@ */ package com.monkopedia.ksrpc.internal +import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.SuspendCloseable +import com.monkopedia.ksrpc.SuspendCloseableObservable import com.monkopedia.ksrpc.channels.ChannelClient import com.monkopedia.ksrpc.channels.ChannelClientProvider import com.monkopedia.ksrpc.channels.ChannelHost import com.monkopedia.ksrpc.channels.ChannelHostProvider import com.monkopedia.ksrpc.channels.Connection +import com.monkopedia.ksrpc.channels.ConnectionInternal import com.monkopedia.ksrpc.channels.ConnectionProvider import com.monkopedia.ksrpc.channels.ContextContainer import com.monkopedia.ksrpc.channels.SerializedService +import com.monkopedia.ksrpc.channels.SingleChannelConnection +import com.monkopedia.ksrpc.channels.SuspendInit +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcChannel import internal.MovableInstance import internal.using import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext import kotlin.native.concurrent.DetachedObjectGraph -import kotlin.native.concurrent.TransferMode +import kotlin.native.concurrent.TransferMode.UNSAFE import kotlin.native.concurrent.attach import kotlin.native.concurrent.ensureNeverFrozen import kotlin.native.concurrent.freeze +import kotlin.reflect.KClass import kotlinx.coroutines.CloseableCoroutineDispatcher import kotlinx.coroutines.newSingleThreadContext import kotlinx.coroutines.withContext @@ -39,10 +48,90 @@ import kotlinx.coroutines.withContext * Tags instances that are handling thread wrapping in native code already and * avoids duplicate wrapping. */ -internal abstract class ThreadSafe( +internal class ThreadSafe( override val context: CoroutineContext, - val reference: DetachedObjectGraph -) : ContextContainer + val dispatcher: CloseableCoroutineDispatcher, + val reference: DetachedObjectGraph, + override val env: KsrpcEnvironment +) : ContextContainer, KsrpcEnvironment.Element { + private val threadSafes = MovableInstance { mutableMapOf, Any>() } + + inline fun getWrapper(): T { + if (T::class == Any::class) return this as T + if (T::class == ConnectionInternal::class) { + return threadSafes.using { + it.getOrPut(Connection::class) { + createThreadSafeConnection() + } + } as T + } + return threadSafes.using { + it.getOrPut(T::class) { + when (T::class) { + Connection::class -> createThreadSafeConnection() + ChannelHost::class -> createThreadSafeChannelHost() + ChannelClient::class -> createThreadSafeChannelClient() + SerializedService::class -> createThreadSafeService() + JsonRpcChannel::class -> createThreadSafeJsonRpcChannel() + SingleChannelConnection::class -> createSingleChannelConnection() + else -> error("${T::class} is unsupported for threadSafe operation") + } + } + } as T + } + + fun createThreadSafeService(): SerializedService { + return ThreadSafeService(this as ThreadSafe, env).freeze() + } + + fun createThreadSafeChannelClient(): ChannelClient { + return ThreadSafeChannelClient(this as ThreadSafe, env).freeze() + } + + fun createThreadSafeChannelHost(): ChannelHost { + return ThreadSafeChannelHost(this as ThreadSafe, env).freeze() + } + + fun createThreadSafeConnection(): Connection { + return ThreadSafeConnection(this as ThreadSafe, env).freeze() + } + + fun createThreadSafeJsonRpcChannel(): ThreadSafeJsonRpcChannel { + return ThreadSafeJsonRpcChannel(this as ThreadSafe, env).freeze() + } + + fun createSingleChannelConnection(): ThreadSafeSingleChannelConnection { + return ThreadSafeSingleChannelConnection(this as ThreadSafe, env) + .freeze() + } +} + +internal open class ThreadSafeUser( + val threadSafe: ThreadSafe +) : SuspendCloseable, SuspendInit, SuspendCloseableObservable, ContextContainer by threadSafe { + final override suspend fun init(): Unit = useSafe { + (it as? SuspendInit)?.init() + userCount++ + } + + final override suspend fun close() { + val needsClose = useSafe { + (it as? SuspendCloseable)?.close() + (--userCount == 0) + } + if (needsClose) { + try { + threadSafe.dispatcher.close() + } catch (t: Throwable) { + // Don't mind, just doing best to clean up. + } + } + } + + final override suspend fun onClose(onClose: suspend () -> Unit): Unit = useSafe { + (it as? SuspendCloseableObservable)?.onClose(onClose) + } +} @ThreadLocal private var threadSafeCache: MutableMap = mutableMapOf() @@ -84,21 +173,26 @@ internal actual object ThreadSafeManager { globalMap[key] = this return this } - val thread = newSingleThreadContext("thread-safe-${T::class.qualifiedName}") - val threadSafe = when (instance) { - is Connection -> createThreadSafeConnection(thread, instance) - is ChannelHost -> createThreadSafeChannelHost(thread, instance) - is ChannelClient -> createThreadSafeChannelClient(thread, instance) - is SerializedService -> createThreadSafeService(thread, instance) - else -> error("$instance is unsupported for threadSafe operation") + if (this is ThreadSafeUser<*>) { + threadSafeCache[key] = this.threadSafe + globalMap[key] = this.threadSafe + return this.threadSafe.getWrapper() } + val thread = newSingleThreadContext("thread-safe-${T::class.qualifiedName}") + val context = + ((instance as? ContextContainer)?.context ?: EmptyCoroutineContext) + thread + val env = (instance as KsrpcEnvironment.Element).env + val threadSafe = + ThreadSafe(context, thread, DetachedObjectGraph(UNSAFE) { instance }, env) threadSafeCache[key] = threadSafe globalMap[key] = threadSafe - return threadSafe as T + return threadSafe.getWrapper() } } - actual inline fun threadSafe(creator: (CoroutineContext) -> T): T { + actual inline fun threadSafe( + creator: (CoroutineContext) -> T + ): T { if (!threadSafeCacheInitialized) { threadSafeCache = mutableMapOf() threadSafeCacheInitialized = true @@ -108,71 +202,17 @@ internal actual object ThreadSafeManager { val instance = creator(thread) instance.ensureNeverFrozen() val key = (instance as? ThreadSafeKeyed)?.key ?: instance - val threadSafe = when (instance) { - is Connection -> createThreadSafeConnection(thread, instance) - is ChannelHost -> createThreadSafeChannelHost(thread, instance) - is ChannelClient -> createThreadSafeChannelClient(thread, instance) - is SerializedService -> createThreadSafeService(thread, instance) - else -> error("$instance is unsupported for threadSafe operation") - } + val context = + ((instance as? ContextContainer)?.context ?: EmptyCoroutineContext) + thread + val env = instance.env + val threadSafe = + ThreadSafe(context, thread, DetachedObjectGraph(UNSAFE) { instance }, env) globalMap[key] = threadSafe threadSafeCache[key] = threadSafe - return threadSafe as T + return threadSafe.getWrapper() } } - fun createThreadSafeService( - thread: CloseableCoroutineDispatcher, - instance: SerializedService - ): SerializedService { - val env = instance.env - val context = instance.context + thread - return ThreadSafeService( - context, - DetachedObjectGraph(TransferMode.UNSAFE) { instance }, - env - ).freeze() - } - - fun createThreadSafeChannelClient( - thread: CloseableCoroutineDispatcher, - instance: ChannelClient - ): ChannelClient { - val env = instance.env - val context = instance.context + thread - return ThreadSafeChannelClient( - context, - DetachedObjectGraph(TransferMode.UNSAFE) { instance }, - env - ).freeze() - } - - fun createThreadSafeChannelHost( - thread: CloseableCoroutineDispatcher, - instance: ChannelHost - ): ChannelHost { - val env = instance.env - val context = instance.context + thread - return ThreadSafeChannelHost( - context, - DetachedObjectGraph(TransferMode.UNSAFE) { instance }, - env - ).freeze() - } - - fun createThreadSafeConnection( - thread: CloseableCoroutineDispatcher, - instance: Connection - ): Connection { - val env = instance.env - val context = instance.context + thread - return ThreadSafeConnection( - context, - DetachedObjectGraph(TransferMode.UNSAFE) { instance }, - env - ).freeze() - } - actual inline fun ThreadSafeKeyedConnection.threadSafeProvider(): ConnectionProvider { if (!allThreadSafes.holdsLock) { throw IllegalStateException( @@ -210,7 +250,14 @@ private var instance: Any? = null @ThreadLocal private var initialized: Boolean = false -internal suspend inline fun ThreadSafe.useSafe( +@ThreadLocal +private var userCount: Int = 0 + +internal suspend inline fun ThreadSafeUser.useSafe( + crossinline usage: suspend (T) -> R +): R = threadSafe.useSafe(usage) + +internal suspend inline fun ThreadSafe.useSafe( crossinline usage: suspend (T) -> R ): R { return withContext(context) { @@ -219,9 +266,17 @@ internal suspend inline fun ThreadSafe.useSafe( } } -private inline fun ThreadSafe.ensureInitialized() { +private suspend inline fun ThreadSafe.ensureInitialized() { if (!initialized) { - instance = reference.attach() + instance = (reference as DetachedObjectGraph).attach() + userCount = 0 initialized = true + (instance as? SuspendCloseableObservable)?.onClose { + try { + dispatcher.close() + } catch (t: Throwable) { + // Don't mind, just doing best to clean up. + } + } } } diff --git a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelClient.kt b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelClient.kt index c32dcb93..aa0dd09e 100644 --- a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelClient.kt +++ b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelClient.kt @@ -22,14 +22,11 @@ import com.monkopedia.ksrpc.channels.ChannelClientProvider import com.monkopedia.ksrpc.channels.ChannelId import com.monkopedia.ksrpc.channels.SerializedService import com.monkopedia.ksrpc.internal.ThreadSafeManager.threadSafe -import kotlin.coroutines.CoroutineContext -import kotlin.native.concurrent.DetachedObjectGraph internal class ThreadSafeChannelClient( - context: CoroutineContext, - reference: DetachedObjectGraph, + threadSafe: ThreadSafe, override val env: KsrpcEnvironment -) : ThreadSafe(context, reference), ChannelClient { +) : ThreadSafeUser(threadSafe), ChannelClient { override suspend fun wrapChannel(channelId: ChannelId): SerializedService { return useSafe { @@ -48,18 +45,6 @@ internal class ThreadSafeChannelClient( it.close(id) } } - - override suspend fun close() { - return useSafe { - it.close() - } - } - - override suspend fun onClose(onClose: suspend () -> Unit) { - return useSafe { - it.onClose(onClose) - } - } } internal class ThreadSafeClientProvider(private val key: Any) : ChannelClientProvider { override val client: ChannelClient? diff --git a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelHost.kt b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelHost.kt index 7f9bfa0f..35796f74 100644 --- a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelHost.kt +++ b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeChannelHost.kt @@ -22,14 +22,11 @@ import com.monkopedia.ksrpc.channels.ChannelHostProvider import com.monkopedia.ksrpc.channels.ChannelId import com.monkopedia.ksrpc.channels.SerializedService import com.monkopedia.ksrpc.internal.ThreadSafeManager.threadSafe -import kotlin.coroutines.CoroutineContext -import kotlin.native.concurrent.DetachedObjectGraph internal class ThreadSafeChannelHost( - context: CoroutineContext, - reference: DetachedObjectGraph, + threadSafe: ThreadSafe, override val env: KsrpcEnvironment -) : ThreadSafe(context, reference), ChannelHost { +) : ThreadSafeUser(threadSafe), ChannelHost { override suspend fun registerHost(service: SerializedService): ChannelId { val threadSafeService = service.threadSafe() return useSafe { @@ -59,18 +56,6 @@ internal class ThreadSafeChannelHost( it.call(channelId, endpoint, data) } } - - override suspend fun close() { - return useSafe { - it.close() - } - } - - override suspend fun onClose(onClose: suspend () -> Unit) { - return useSafe { - it.onClose(onClose) - } - } } internal class ThreadSafeHostProvider(private val key: Any) : ChannelHostProvider { override val host: ChannelHost? diff --git a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeConnection.kt b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeConnection.kt index 29366669..a6034cc4 100644 --- a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeConnection.kt +++ b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeConnection.kt @@ -18,9 +18,7 @@ package com.monkopedia.ksrpc.internal import com.monkopedia.ksrpc.KsrpcEnvironment import com.monkopedia.ksrpc.channels.CallData import com.monkopedia.ksrpc.channels.ChannelClient -import com.monkopedia.ksrpc.channels.ChannelClientProvider import com.monkopedia.ksrpc.channels.ChannelHost -import com.monkopedia.ksrpc.channels.ChannelHostProvider import com.monkopedia.ksrpc.channels.ChannelId import com.monkopedia.ksrpc.channels.Connection import com.monkopedia.ksrpc.channels.ConnectionInternal @@ -28,20 +26,11 @@ import com.monkopedia.ksrpc.channels.ConnectionProvider import com.monkopedia.ksrpc.channels.SerializedService import com.monkopedia.ksrpc.channels.SuspendInit import com.monkopedia.ksrpc.internal.ThreadSafeManager.threadSafe -import kotlin.coroutines.CoroutineContext -import kotlin.native.concurrent.DetachedObjectGraph internal class ThreadSafeConnection( - context: CoroutineContext, - reference: DetachedObjectGraph, + threadSafe: ThreadSafe, override val env: KsrpcEnvironment -) : ThreadSafe(context, reference), ConnectionInternal, SuspendInit { - - override suspend fun init() { - return useSafe { - (it as SuspendInit).init() - } - } +) : ThreadSafeUser(threadSafe), ConnectionInternal, SuspendInit { override suspend fun registerHost(service: SerializedService): ChannelId { val threadSafeService = service.threadSafe() @@ -63,12 +52,6 @@ internal class ThreadSafeConnection( } } - override suspend fun close() { - return useSafe { - it.close() - } - } - override suspend fun wrapChannel(channelId: ChannelId): SerializedService { return useSafe { it.wrapChannel(channelId).threadSafe() @@ -84,17 +67,11 @@ internal class ThreadSafeConnection( it.call(channelId, endpoint, data) } } - - override suspend fun onClose(onClose: suspend () -> Unit) { - return useSafe { - it.onClose(onClose) - } - } } internal class ThreadSafeConnectionProvider(private val key: Any) : ConnectionProvider { override val host: ChannelHost? - get() = (key.threadSafe() as? ChannelHostProvider)?.host + get() = (key.threadSafe() as ThreadSafe).getWrapper().host override val client: ChannelClient? - get() = (key.threadSafe() as? ChannelClientProvider)?.client + get() = (key.threadSafe() as ThreadSafe).getWrapper().client } diff --git a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeJsonRpcChannel.kt b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeJsonRpcChannel.kt new file mode 100644 index 00000000..081c5d20 --- /dev/null +++ b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeJsonRpcChannel.kt @@ -0,0 +1,37 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal + +import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.channels.SuspendInit +import com.monkopedia.ksrpc.internal.jsonrpc.JsonRpcChannel +import kotlinx.serialization.json.JsonElement + +internal class ThreadSafeJsonRpcChannel( + threadSafe: ThreadSafe, + override val env: KsrpcEnvironment +) : ThreadSafeUser(threadSafe), JsonRpcChannel, SuspendInit { + + override suspend fun execute( + method: String, + message: JsonElement?, + isNotify: Boolean + ): JsonElement? { + return useSafe { + it.execute(method, message, isNotify) + } + } +} diff --git a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeService.kt b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeService.kt index 8bce0139..a6939103 100644 --- a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeService.kt +++ b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeService.kt @@ -18,30 +18,16 @@ package com.monkopedia.ksrpc.internal import com.monkopedia.ksrpc.KsrpcEnvironment import com.monkopedia.ksrpc.channels.CallData import com.monkopedia.ksrpc.channels.SerializedService -import kotlin.coroutines.CoroutineContext -import kotlin.native.concurrent.DetachedObjectGraph +import com.monkopedia.ksrpc.channels.SuspendInit internal class ThreadSafeService( - context: CoroutineContext, - reference: DetachedObjectGraph, + threadSafe: ThreadSafe, override val env: KsrpcEnvironment -) : ThreadSafe(context, reference), SerializedService { +) : ThreadSafeUser(threadSafe), SerializedService, SuspendInit { override suspend fun call(endpoint: String, input: CallData): CallData { return useSafe { it.call(endpoint, input) } } - - override suspend fun close() { - useSafe { - it.close() - } - } - - override suspend fun onClose(onClose: suspend () -> Unit) { - return useSafe { - it.onClose(onClose) - } - } } diff --git a/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeSingleChannelConnection.kt b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeSingleChannelConnection.kt new file mode 100644 index 00000000..e886d3f5 --- /dev/null +++ b/ksrpc/src/nativeMain/kotlin/internal/ThreadSafeSingleChannelConnection.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Jason Monk + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.monkopedia.ksrpc.internal + +import com.monkopedia.ksrpc.KsrpcEnvironment +import com.monkopedia.ksrpc.channels.SerializedService +import com.monkopedia.ksrpc.channels.SingleChannelConnection +import com.monkopedia.ksrpc.channels.SuspendInit +import com.monkopedia.ksrpc.internal.ThreadSafeManager.threadSafe + +internal class ThreadSafeSingleChannelConnection( + threadSafe: ThreadSafe, + override val env: KsrpcEnvironment +) : ThreadSafeUser(threadSafe), SingleChannelConnection, SuspendInit { + + override suspend fun registerDefault(service: SerializedService) { + val threadSafeService = service.threadSafe() + useSafe { + it.registerDefault(threadSafeService) + } + } + + override suspend fun defaultChannel(): SerializedService { + return useSafe { + it.defaultChannel().threadSafe() + } + } +}