Skip to content

Commit

Permalink
Merge pull request #1742 from pantasystem/feature/#1741/websocket-pin…
Browse files Browse the repository at this point in the history
…g-pong

MisskeyのWebSocketの疎通確認の破壊的変更に対応
  • Loading branch information
pantasystem authored Jul 2, 2023
2 parents b74ad67 + c6acc73 commit 2067166
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ChannelAPITest {
val wssURL = "wss://misskey.io/streaming"
val logger = TestLogger.Factory()
val socket =
SocketImpl(wssURL, logger, DefaultOkHttpClientProvider())
SocketImpl(wssURL, {false}, logger, DefaultOkHttpClientProvider())
socket.blockingConnect()

var count = 0
Expand All @@ -51,7 +51,7 @@ class ChannelAPITest {
val wssURL = "wss://misskey.io/streaming"
val logger = TestLogger.Factory()
val socket =
SocketImpl(wssURL, logger, DefaultOkHttpClientProvider())
SocketImpl(wssURL, {false}, logger, DefaultOkHttpClientProvider())
val channelAPI = ChannelAPI(socket, logger)
runBlocking {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class SocketImplTest {
fun testBlockingConnect() {
val wssURL = "wss://misskey.io/streaming"
val logger = TestLogger.Factory()
val socket = SocketImpl(wssURL, logger, DefaultOkHttpClientProvider())
val socket = SocketImpl(wssURL, {false} ,logger, DefaultOkHttpClientProvider())
runBlocking {
socket.blockingConnect()
assertEquals(socket.state(), Socket.State.Connected)
Expand All @@ -33,7 +33,7 @@ class SocketImplTest {

val wssURL = "wss://misskey.io/streaming"
val logger = TestLogger.Factory()
val socket = SocketImpl(wssURL, logger, DefaultOkHttpClientProvider())
val socket = SocketImpl(wssURL, {false}, logger, DefaultOkHttpClientProvider())

runBlocking {

Expand All @@ -55,7 +55,7 @@ class SocketImplTest {
val wssURL = "wss://misskey.io/streaming"
val logger = TestLogger.Factory()
val socket =
SocketImpl(wssURL, logger, DefaultOkHttpClientProvider())
SocketImpl(wssURL, {false}, logger, DefaultOkHttpClientProvider())

runBlocking {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,28 @@ package net.pantasystem.milktea.api_streaming.network
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
import net.pantasystem.milktea.api.misskey.OkHttpClientProvider
import net.pantasystem.milktea.api_streaming.*
import net.pantasystem.milktea.api_streaming.PollingJob
import net.pantasystem.milktea.api_streaming.Socket
import net.pantasystem.milktea.api_streaming.SocketMessageEventListener
import net.pantasystem.milktea.api_streaming.SocketStateEventListener
import net.pantasystem.milktea.api_streaming.StreamingEvent
import net.pantasystem.milktea.common.Logger
import net.pantasystem.milktea.common.runCancellableCatching
import okhttp3.*
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import java.util.concurrent.TimeUnit
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine

class SocketImpl(
val url: String,
var isRequirePingPong: () -> Boolean,
loggerFactory: Logger.Factory,
okHttpClientProvider: OkHttpClientProvider
) : Socket {
okHttpClientProvider: OkHttpClientProvider,
) : Socket {
val logger = loggerFactory.create("SocketImpl")

private val okHttpClient: OkHttpClient = okHttpClientProvider
Expand Down Expand Up @@ -269,8 +278,6 @@ class SocketImpl(
super.onMessage(webSocket, text)
runCancellableCatching {
pollingJob.onReceive(text)
}.onSuccess {
return
}
val e = runCancellableCatching { json.decodeFromString<StreamingEvent>(text) }.onFailure { t ->
logger.error("デコードエラー msg:$text", e = t)
Expand Down Expand Up @@ -305,7 +312,9 @@ class SocketImpl(
synchronized(this@SocketImpl) {
pollingJob.cancel()
pollingJob = PollingJob(this@SocketImpl).also {
it.startPolling(4000, 900, 12000)
if (isRequirePingPong()) {
it.startPolling(4000, 900, 12000)
}
}
mState = Socket.State.Connected
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package net.pantasystem.milktea.api_streaming

import android.util.Log
import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import kotlinx.datetime.Clock

const val TTL_COUNT = 3
Expand Down Expand Up @@ -40,7 +47,7 @@ internal class PollingJob(
try {
val pong = withTimeout(timeout) {
pongs.first {
it == "pong"
it.isNotBlank()
}
}
val resTime = Clock.System.now()
Expand All @@ -64,11 +71,7 @@ internal class PollingJob(
}

fun onReceive(msg: String) {
if (msg.lowercase() == "pong") {
pongs.tryEmit(msg)
} else {
throw IllegalArgumentException()
}
pongs.tryEmit(msg)
}

fun cancel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,26 @@ import net.pantasystem.milktea.api_streaming.Socket
import net.pantasystem.milktea.api_streaming.network.SocketImpl
import net.pantasystem.milktea.common.Logger
import net.pantasystem.milktea.model.account.Account
import net.pantasystem.milktea.model.account.AccountRepository
import net.pantasystem.milktea.model.account.UnauthorizedException
import net.pantasystem.milktea.model.instance.Version
import net.pantasystem.milktea.model.nodeinfo.NodeInfo
import net.pantasystem.milktea.model.nodeinfo.NodeInfoRepository
import net.pantasystem.milktea.model.nodeinfo.getVersion
import javax.inject.Inject
import net.pantasystem.milktea.data.streaming.SocketWithAccountProvider as ISocketWithAccountProvider

/**
* SocketをAccountに基づきいい感じにリソースを取得できるようにする
*/
class SocketWithAccountProviderImpl @Inject constructor(
val accountRepository: AccountRepository,
val loggerFactory: Logger.Factory,
val okHttpClientProvider: OkHttpClientProvider
val okHttpClientProvider: OkHttpClientProvider,
val nodeInfoRepository: NodeInfoRepository,
) : ISocketWithAccountProvider{

private val logger = loggerFactory.create("SocketProvider")

private val accountIdWithSocket = mutableMapOf<Long, Socket>()
private val accountIdWithSocket = mutableMapOf<Long, SocketImpl>()

/**
* accountIdとそのTokenを管理している。
Expand All @@ -47,10 +50,7 @@ class SocketWithAccountProviderImpl @Inject constructor(
logger.debug { "すでにインスタンス化済み" }
return socket
} else {
if (socket is SocketImpl) {
socket.destroy()

}
socket.destroy()
}
}

Expand All @@ -66,6 +66,11 @@ class SocketWithAccountProviderImpl @Inject constructor(

socket = SocketImpl(
url = uri,
isRequirePingPong = {
nodeInfoRepository.get(account.getHost())?.let {
!(it.type is NodeInfo.SoftwareType.Misskey.Normal && it.type.getVersion() >= Version("13.13.2"))
} ?: true
},
okHttpClientProvider = okHttpClientProvider,
loggerFactory = loggerFactory,
)
Expand Down

0 comments on commit 2067166

Please sign in to comment.