diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/RelayerTests.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/RelayerTests.xcscheme new file mode 100644 index 000000000..f4fc9ed05 --- /dev/null +++ b/.swiftpm/xcode/xcshareddata/xcschemes/RelayerTests.xcscheme @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/Example/RelayIntegrationTests/RelayClientEndToEndTests.swift b/Example/RelayIntegrationTests/RelayClientEndToEndTests.swift index e52094049..4eccebd9b 100644 --- a/Example/RelayIntegrationTests/RelayClientEndToEndTests.swift +++ b/Example/RelayIntegrationTests/RelayClientEndToEndTests.swift @@ -52,14 +52,16 @@ final class RelayClientEndToEndTests: XCTestCase { socketAuthenticator: socketAuthenticator ) - let socketConnectionHandler = AutomaticSocketConnectionHandler(socket: socket, logger: logger) + let socketStatusProvider = SocketStatusProvider(socket: socket, logger: logger) + let socketConnectionHandler = AutomaticSocketConnectionHandler(socket: socket, subscriptionsTracker: SubscriptionsTracker(), logger: logger, socketStatusProvider: socketStatusProvider) let dispatcher = Dispatcher( socketFactory: webSocketFactory, relayUrlFactory: urlFactory, networkMonitor: networkMonitor, socket: socket, logger: logger, - socketConnectionHandler: socketConnectionHandler + socketConnectionHandler: socketConnectionHandler, + socketStatusProvider: socketStatusProvider ) let keychain = KeychainStorageMock() let relayClient = RelayClientFactory.create( diff --git a/Sources/Events/EventsClient.swift b/Sources/Events/EventsClient.swift index da0dd30bd..5d1664c29 100644 --- a/Sources/Events/EventsClient.swift +++ b/Sources/Events/EventsClient.swift @@ -14,25 +14,28 @@ public class EventsClient: EventsClientProtocol { private let logger: ConsoleLogging private var stateStorage: TelemetryStateStorage private let messageEventsStorage: MessageEventsStorage + private let initEventsStorage: InitEventsStorage init( eventsCollector: EventsCollector, eventsDispatcher: EventsDispatcher, logger: ConsoleLogging, stateStorage: TelemetryStateStorage, - messageEventsStorage: MessageEventsStorage + messageEventsStorage: MessageEventsStorage, + initEventsStorage: InitEventsStorage ) { self.eventsCollector = eventsCollector self.eventsDispatcher = eventsDispatcher self.logger = logger self.stateStorage = stateStorage self.messageEventsStorage = messageEventsStorage + self.initEventsStorage = initEventsStorage - if stateStorage.telemetryEnabled { - Task { await sendStoredEvents() } - } else { + if !stateStorage.telemetryEnabled { self.eventsCollector.storage.clearErrorEvents() } + saveInitEvent() + Task { await sendStoredEvents() } } public func setLogging(level: LoggingLevel) { @@ -63,6 +66,30 @@ public class EventsClient: EventsClientProtocol { messageEventsStorage.saveMessageEvent(event) } + public func saveInitEvent() { + logger.debug("Will store an init event") + + let bundleId = Bundle.main.bundleIdentifier ?? "Unknown" + let clientId = (try? Networking.interactor.getClientId()) ?? "Unknown" + let userAgent = EnvironmentInfo.userAgent + + let props = InitEvent.Props( + properties: InitEvent.Properties( + clientId: clientId, + userAgent: userAgent + ) + ) + + let event = InitEvent( + eventId: UUID().uuidString, + bundleId: bundleId, + timestamp: Int64(Date().timeIntervalSince1970 * 1000), + props: props + ) + + initEventsStorage.saveInitEvent(event) + } + // Public method to set telemetry enabled or disabled public func setTelemetryEnabled(_ enabled: Bool) { stateStorage.telemetryEnabled = enabled @@ -78,24 +105,26 @@ public class EventsClient: EventsClientProtocol { let traceEvents = eventsCollector.storage.fetchErrorEvents() let messageEvents = messageEventsStorage.fetchMessageEvents() + let initEvents = initEventsStorage.fetchInitEvents() - guard !traceEvents.isEmpty || !messageEvents.isEmpty else { return } + guard !traceEvents.isEmpty || !messageEvents.isEmpty || !initEvents.isEmpty else { return } var combinedEvents: [AnyCodable] = [] - // Wrap trace events combinedEvents.append(contentsOf: traceEvents.map { AnyCodable($0) }) - // Wrap message events combinedEvents.append(contentsOf: messageEvents.map { AnyCodable($0) }) + combinedEvents.append(contentsOf: initEvents.map { AnyCodable($0) }) + logger.debug("Will send combined events") do { let success: Bool = try await eventsDispatcher.executeWithRetry(events: combinedEvents) if success { logger.debug("Combined events sent successfully") - self.eventsCollector.storage.clearErrorEvents() - self.messageEventsStorage.clearMessageEvents() + eventsCollector.storage.clearErrorEvents() + messageEventsStorage.clearMessageEvents() + initEventsStorage.clearInitEvents() } } catch { logger.debug("Failed to send events after multiple attempts: \(error)") diff --git a/Sources/Events/EventsClientFactory.swift b/Sources/Events/EventsClientFactory.swift index b45df2f26..c7cf66f2d 100644 --- a/Sources/Events/EventsClientFactory.swift +++ b/Sources/Events/EventsClientFactory.swift @@ -19,7 +19,8 @@ public class EventsClientFactory { eventsDispatcher: eventsDispatcher, logger: logger, stateStorage: UserDefaultsTelemetryStateStorage(), - messageEventsStorage: UserDefaultsMessageEventsStorage() + messageEventsStorage: UserDefaultsMessageEventsStorage(), + initEventsStorage: UserDefaultsInitEventsStorage() ) } } diff --git a/Sources/Events/InitEvent.swift b/Sources/Events/InitEvent.swift new file mode 100644 index 000000000..9c1718d29 --- /dev/null +++ b/Sources/Events/InitEvent.swift @@ -0,0 +1,25 @@ +import Foundation + +struct InitEvent: Codable { + struct Props: Codable { + let event: String = "INIT" + let type: String = "None" + let properties: Properties + } + + struct Properties: Codable { + let clientId: String + let userAgent: String + + // Custom CodingKeys to map Swift property names to JSON keys + enum CodingKeys: String, CodingKey { + case clientId = "client_id" + case userAgent = "user_agent" + } + } + + let eventId: String + let bundleId: String + let timestamp: Int64 + let props: Props +} diff --git a/Sources/Events/InitEventsStorage.swift b/Sources/Events/InitEventsStorage.swift new file mode 100644 index 000000000..ae6216554 --- /dev/null +++ b/Sources/Events/InitEventsStorage.swift @@ -0,0 +1,42 @@ +import Foundation + +protocol InitEventsStorage { + func saveInitEvent(_ event: InitEvent) + func fetchInitEvents() -> [InitEvent] + func clearInitEvents() +} + + +class UserDefaultsInitEventsStorage: InitEventsStorage { + private let initEventsKey = "com.walletconnect.sdk.initEvents" + private let maxEvents = 100 + + func saveInitEvent(_ event: InitEvent) { + // Fetch existing events from UserDefaults + var existingEvents = fetchInitEvents() + existingEvents.append(event) + + // Ensure we keep only the last 100 events + if existingEvents.count > maxEvents { + existingEvents = Array(existingEvents.suffix(maxEvents)) + } + + // Save updated events back to UserDefaults + if let encoded = try? JSONEncoder().encode(existingEvents) { + UserDefaults.standard.set(encoded, forKey: initEventsKey) + } + } + + func fetchInitEvents() -> [InitEvent] { + if let data = UserDefaults.standard.data(forKey: initEventsKey), + let events = try? JSONDecoder().decode([InitEvent].self, from: data) { + // Return only the last 100 events + return Array(events.suffix(maxEvents)) + } + return [] + } + + func clearInitEvents() { + UserDefaults.standard.removeObject(forKey: initEventsKey) + } +} diff --git a/Sources/WalletConnectRelay/Dispatching.swift b/Sources/WalletConnectRelay/Dispatching.swift index 3af72ce97..9d2198cee 100644 --- a/Sources/WalletConnectRelay/Dispatching.swift +++ b/Sources/WalletConnectRelay/Dispatching.swift @@ -22,11 +22,10 @@ final class Dispatcher: NSObject, Dispatching { private let relayUrlFactory: RelayUrlFactory private let networkMonitor: NetworkMonitoring private let logger: ConsoleLogging - - private let socketConnectionStatusPublisherSubject = CurrentValueSubject(.disconnected) + private let socketStatusProvider: SocketStatusProviding var socketConnectionStatusPublisher: AnyPublisher { - socketConnectionStatusPublisherSubject.eraseToAnyPublisher() + socketStatusProvider.socketConnectionStatusPublisher } var networkConnectionStatusPublisher: AnyPublisher { @@ -45,18 +44,18 @@ final class Dispatcher: NSObject, Dispatching { networkMonitor: NetworkMonitoring, socket: WebSocketConnecting, logger: ConsoleLogging, - socketConnectionHandler: SocketConnectionHandler + socketConnectionHandler: SocketConnectionHandler, + socketStatusProvider: SocketStatusProviding ) { self.socketConnectionHandler = socketConnectionHandler self.relayUrlFactory = relayUrlFactory self.networkMonitor = networkMonitor self.logger = logger - self.socket = socket + self.socketStatusProvider = socketStatusProvider super.init() setUpWebSocketSession() - setUpSocketConnectionObserving() } func send(_ string: String, completion: @escaping (Error?) -> Void) { @@ -74,12 +73,17 @@ final class Dispatcher: NSObject, Dispatching { return send(string, completion: completion) } + // Always connect when there is a message to be sent + if !socket.isConnected { + socketConnectionHandler.handleInternalConnect() + } + var cancellable: AnyCancellable? cancellable = Publishers.CombineLatest(socketConnectionStatusPublisher, networkConnectionStatusPublisher) .filter { $0.0 == .connected && $0.1 == .connected } .setFailureType(to: NetworkError.self) .timeout(.seconds(defaultTimeout), scheduler: concurrentQueue, customError: { .connectionFailed }) - .sink(receiveCompletion: { [unowned self] result in + .sink(receiveCompletion: { result in switch result { case .failure(let error): cancellable?.cancel() @@ -128,18 +132,5 @@ extension Dispatcher { } } - private func setUpSocketConnectionObserving() { - socket.onConnect = { [unowned self] in - self.socketConnectionStatusPublisherSubject.send(.connected) - } - socket.onDisconnect = { [unowned self] error in - self.socketConnectionStatusPublisherSubject.send(.disconnected) - if error != nil { - self.socket.request.url = relayUrlFactory.create() - } - Task(priority: .high) { - await self.socketConnectionHandler.handleDisconnection() - } - } - } + } diff --git a/Sources/WalletConnectRelay/RelayClient.swift b/Sources/WalletConnectRelay/RelayClient.swift index f51f69c84..a3088bca9 100644 --- a/Sources/WalletConnectRelay/RelayClient.swift +++ b/Sources/WalletConnectRelay/RelayClient.swift @@ -20,8 +20,6 @@ public final class RelayClient { case subscriptionIdNotFound } - var subscriptions: [String: String] = [:] - public var isSocketConnected: Bool { return dispatcher.isSocketConnected } @@ -49,12 +47,14 @@ public final class RelayClient { private var requestAcknowledgePublisher: AnyPublisher { requestAcknowledgePublisherSubject.eraseToAnyPublisher() } + private var publishers = [AnyCancellable]() private let clientIdStorage: ClientIdStoring private var dispatcher: Dispatching private let rpcHistory: RPCHistory private let logger: ConsoleLogging + private let subscriptionsTracker: SubscriptionsTracking private let concurrentQueue = DispatchQueue(label: "com.walletconnect.sdk.relay_client", qos: .utility, attributes: .concurrent) @@ -69,13 +69,16 @@ public final class RelayClient { dispatcher: Dispatching, logger: ConsoleLogging, rpcHistory: RPCHistory, - clientIdStorage: ClientIdStoring + clientIdStorage: ClientIdStoring, + subscriptionsTracker: SubscriptionsTracking ) { self.logger = logger self.dispatcher = dispatcher self.rpcHistory = rpcHistory self.clientIdStorage = clientIdStorage + self.subscriptionsTracker = subscriptionsTracker setUpBindings() + setupConnectionSubscriptions() } private func setUpBindings() { @@ -84,6 +87,18 @@ public final class RelayClient { } } + private func setupConnectionSubscriptions() { + socketConnectionStatusPublisher + .sink { [unowned self] status in + guard status == .connected else { return } + let topics = subscriptionsTracker.getTopics() + Task(priority: .high) { + try await batchSubscribe(topics: topics) + } + } + .store(in: &publishers) + } + public func setLogging(level: LoggingLevel) { logger.setLogging(level: level) } @@ -183,14 +198,13 @@ public final class RelayClient { } public func unsubscribe(topic: String, completion: ((Error?) -> Void)?) { - guard let subscriptionId = subscriptions[topic] else { + guard let subscriptionId = subscriptionsTracker.getSubscription(for: topic) else { completion?(Errors.subscriptionIdNotFound) return } logger.debug("Unsubscribing from topic: \(topic)") let rpc = Unsubscribe(params: .init(id: subscriptionId, topic: topic)) - let request = rpc - .asRPCRequest() + let request = rpc.asRPCRequest() let message = try! request.asJSONEncodedString() rpcHistory.deleteAll(forTopic: topic) dispatcher.protectedSend(message) { [weak self] error in @@ -198,9 +212,7 @@ public final class RelayClient { self?.logger.debug("Failed to unsubscribe from topic") completion?(error) } else { - self?.concurrentQueue.async(flags: .barrier) { - self?.subscriptions[topic] = nil - } + self?.subscriptionsTracker.removeSubscription(for: topic) completion?(nil) } } @@ -213,15 +225,13 @@ public final class RelayClient { .filter { $0.0 == requestId } .sink { [unowned self] (_, subscriptionIds) in cancellable?.cancel() - concurrentQueue.async(flags: .barrier) { [unowned self] in - logger.debug("Subscribed to topics: \(topics)") - guard topics.count == subscriptionIds.count else { - logger.warn("Number of topics in (batch)subscribe does not match number of subscriptions") - return - } - for i in 0..() private let concurrentQueue = DispatchQueue(label: "com.walletconnect.sdk.automatic_socket_connection", qos: .utility, attributes: .concurrent) + var reconnectionAttempts = 0 + let maxImmediateAttempts = 3 + var periodicReconnectionInterval: TimeInterval = 5.0 + var reconnectionTimer: DispatchSourceTimer? + var isConnecting = false + init( socket: WebSocketConnecting, networkMonitor: NetworkMonitoring = NetworkMonitor(), appStateObserver: AppStateObserving = AppStateObserver(), backgroundTaskRegistrar: BackgroundTaskRegistering = BackgroundTaskRegistrar(), - logger: ConsoleLogging + subscriptionsTracker: SubscriptionsTracking, + logger: ConsoleLogging, + socketStatusProvider: SocketStatusProviding ) { self.appStateObserver = appStateObserver self.socket = socket self.networkMonitor = networkMonitor self.backgroundTaskRegistrar = backgroundTaskRegistrar self.logger = logger + self.subscriptionsTracker = subscriptionsTracker + self.socketStatusProvider = socketStatusProvider setUpStateObserving() setUpNetworkMonitoring() - - connect() - + setUpSocketStatusObserving() } func connect() { - // Attempt to handle connection + // Start the connection process + isConnecting = true socket.connect() + } - // Start a timer for the fallback mechanism - let timer = DispatchSource.makeTimerSource(queue: concurrentQueue) - timer.schedule(deadline: .now() + .seconds(defaultTimeout)) - timer.setEventHandler { [weak self] in - guard let self = self else { - timer.cancel() - return - } - if !self.socket.isConnected { - self.logger.debug("Connection timed out, will rety to connect...") - retryToConnect() + private func setUpSocketStatusObserving() { + socketStatusProvider.socketConnectionStatusPublisher + .sink { [unowned self] status in + switch status { + case .connected: + isConnecting = false + reconnectionAttempts = 0 // Reset reconnection attempts on successful connection + stopPeriodicReconnectionTimer() // Stop any ongoing periodic reconnection attempts + case .disconnected: + if isConnecting { + // Handle reconnection logic + handleFailedConnectionAndReconnectIfNeeded() + } else { + Task(priority: .high) { + await handleDisconnection() + } + } + } } - timer.cancel() + .store(in: &publishers) + } + + private func handleFailedConnectionAndReconnectIfNeeded() { + if reconnectionAttempts < maxImmediateAttempts { + reconnectionAttempts += 1 + logger.debug("Immediate reconnection attempt \(reconnectionAttempts) of \(maxImmediateAttempts)") + socket.connect() + } else { + logger.debug("Max immediate reconnection attempts reached. Switching to periodic reconnection every \(periodicReconnectionInterval) seconds.") + startPeriodicReconnectionTimerIfNeeded() + } + } + + private func stopPeriodicReconnectionTimer() { + reconnectionTimer?.cancel() + reconnectionTimer = nil + } + + private func startPeriodicReconnectionTimerIfNeeded() { + guard reconnectionTimer == nil else {return} + + reconnectionTimer = DispatchSource.makeTimerSource(queue: concurrentQueue) + let initialDelay: DispatchTime = .now() + periodicReconnectionInterval + + reconnectionTimer?.schedule(deadline: initialDelay, repeating: periodicReconnectionInterval) + + reconnectionTimer?.setEventHandler { [unowned self] in + logger.debug("Periodic reconnection attempt...") + socket.connect() // Attempt to reconnect + + // The socketConnectionStatusPublisher handler will stop the timer and reset states if connection is successful } - timer.resume() + + reconnectionTimer?.resume() } private func setUpStateObserving() { @@ -72,9 +122,9 @@ class AutomaticSocketConnectionHandler { } private func setUpNetworkMonitoring() { - networkMonitor.networkConnectionStatusPublisher.sink { [weak self] networkConnectionStatus in + networkMonitor.networkConnectionStatusPublisher.sink { [unowned self] networkConnectionStatus in if networkConnectionStatus == .connected { - self?.reconnectIfNeeded() + reconnectIfNeeded() } } .store(in: &publishers) @@ -90,22 +140,21 @@ class AutomaticSocketConnectionHandler { socket.disconnect() } - private func retryToConnect() { - if !socket.isConnected { + func reconnectIfNeeded() { + // Check if client has active subscriptions and only then attempt to reconnect + if !socket.isConnected && subscriptionsTracker.isSubscribed() { connect() } } - - private func reconnectIfNeeded() { - if !socket.isConnected { - socket.connect() - } - } } // MARK: - SocketConnectionHandler extension AutomaticSocketConnectionHandler: SocketConnectionHandler { + func handleInternalConnect() { + connect() + } + func handleConnect() throws { throw Errors.manualSocketConnectionForbidden } diff --git a/Sources/WalletConnectRelay/SocketConnectionHandler/ManualSocketConnectionHandler.swift b/Sources/WalletConnectRelay/SocketConnectionHandler/ManualSocketConnectionHandler.swift index 04152bd21..daf2d1c76 100644 --- a/Sources/WalletConnectRelay/SocketConnectionHandler/ManualSocketConnectionHandler.swift +++ b/Sources/WalletConnectRelay/SocketConnectionHandler/ManualSocketConnectionHandler.swift @@ -1,7 +1,6 @@ import Foundation class ManualSocketConnectionHandler: SocketConnectionHandler { - private let socket: WebSocketConnecting private let logger: ConsoleLogging private let defaultTimeout: Int = 60 @@ -37,6 +36,11 @@ class ManualSocketConnectionHandler: SocketConnectionHandler { socket.disconnect() } + func handleInternalConnect() { + // No operation + } + + func handleDisconnection() async { // No operation // ManualSocketConnectionHandler does not support reconnection logic diff --git a/Sources/WalletConnectRelay/SocketConnectionHandler/SocketConnectionHandler.swift b/Sources/WalletConnectRelay/SocketConnectionHandler/SocketConnectionHandler.swift index 4ac3046dd..808ee43df 100644 --- a/Sources/WalletConnectRelay/SocketConnectionHandler/SocketConnectionHandler.swift +++ b/Sources/WalletConnectRelay/SocketConnectionHandler/SocketConnectionHandler.swift @@ -1,7 +1,10 @@ import Foundation protocol SocketConnectionHandler { + /// handles connection request from the sdk consumes func handleConnect() throws + /// handles connection request from sdk's internal function + func handleInternalConnect() func handleDisconnect(closeCode: URLSessionWebSocketTask.CloseCode) throws func handleDisconnection() async } diff --git a/Sources/WalletConnectRelay/SocketConnectionHandler/WebSocket.swift b/Sources/WalletConnectRelay/SocketConnectionHandler/WebSocket.swift index fd9d96a56..d4042ee19 100644 --- a/Sources/WalletConnectRelay/SocketConnectionHandler/WebSocket.swift +++ b/Sources/WalletConnectRelay/SocketConnectionHandler/WebSocket.swift @@ -14,3 +14,43 @@ public protocol WebSocketConnecting: AnyObject { public protocol WebSocketFactory { func create(with url: URL) -> WebSocketConnecting } + +#if DEBUG +class WebSocketMock: WebSocketConnecting { + var request: URLRequest = URLRequest(url: URL(string: "wss://relay.walletconnect.com")!) + + var onText: ((String) -> Void)? + var onConnect: (() -> Void)? + var onDisconnect: ((Error?) -> Void)? + var sendCallCount: Int = 0 + var isConnected: Bool = false + + func connect() { + isConnected = true + onConnect?() + } + + func disconnect() { + isConnected = false + onDisconnect?(nil) + } + + func write(string: String, completion: (() -> Void)?) { + sendCallCount+=1 + } +} +#endif + +#if DEBUG +class WebSocketFactoryMock: WebSocketFactory { + private let webSocket: WebSocketMock + + init(webSocket: WebSocketMock) { + self.webSocket = webSocket + } + + func create(with url: URL) -> WebSocketConnecting { + return webSocket + } +} +#endif diff --git a/Sources/WalletConnectRelay/SocketStatusProvider.swift b/Sources/WalletConnectRelay/SocketStatusProvider.swift new file mode 100644 index 000000000..1003fe01e --- /dev/null +++ b/Sources/WalletConnectRelay/SocketStatusProvider.swift @@ -0,0 +1,48 @@ + +import Foundation +import Combine + +protocol SocketStatusProviding { + var socketConnectionStatusPublisher: AnyPublisher { get } +} + +class SocketStatusProvider: SocketStatusProviding { + private var socket: WebSocketConnecting + private let logger: ConsoleLogging + private let socketConnectionStatusPublisherSubject = CurrentValueSubject(.disconnected) + + var socketConnectionStatusPublisher: AnyPublisher { + socketConnectionStatusPublisherSubject.eraseToAnyPublisher() + } + + init(socket: WebSocketConnecting, + logger: ConsoleLogging) { + self.socket = socket + self.logger = logger + setUpSocketConnectionObserving() + } + + private func setUpSocketConnectionObserving() { + socket.onConnect = { [unowned self] in + self.socketConnectionStatusPublisherSubject.send(.connected) + } + socket.onDisconnect = { [unowned self] error in + logger.debug("Socket disconnected with error: \(error?.localizedDescription ?? "Unknown error")") + self.socketConnectionStatusPublisherSubject.send(.disconnected) + } + } +} + +#if DEBUG +final class SocketStatusProviderMock: SocketStatusProviding { + private var socketConnectionStatusPublisherSubject = PassthroughSubject() + + var socketConnectionStatusPublisher: AnyPublisher { + socketConnectionStatusPublisherSubject.eraseToAnyPublisher() + } + + func simulateConnectionStatus(_ status: SocketConnectionStatus) { + socketConnectionStatusPublisherSubject.send(status) + } +} +#endif diff --git a/Sources/WalletConnectRelay/SubscriptionsTracker.swift b/Sources/WalletConnectRelay/SubscriptionsTracker.swift new file mode 100644 index 000000000..2684de202 --- /dev/null +++ b/Sources/WalletConnectRelay/SubscriptionsTracker.swift @@ -0,0 +1,82 @@ +import Foundation + +protocol SubscriptionsTracking { + func setSubscription(for topic: String, id: String) + func getSubscription(for topic: String) -> String? + func removeSubscription(for topic: String) + func isSubscribed() -> Bool + func getTopics() -> [String] +} + +public final class SubscriptionsTracker: SubscriptionsTracking { + private var subscriptions: [String: String] = [:] + private let concurrentQueue = DispatchQueue(label: "com.walletconnect.sdk.subscriptions_tracker", attributes: .concurrent) + + func setSubscription(for topic: String, id: String) { + concurrentQueue.async(flags: .barrier) { [unowned self] in + self.subscriptions[topic] = id + } + } + + func getSubscription(for topic: String) -> String? { + var result: String? + concurrentQueue.sync { [unowned self] in + result = subscriptions[topic] + } + return result + } + + func removeSubscription(for topic: String) { + concurrentQueue.async(flags: .barrier) { [unowned self] in + subscriptions[topic] = nil + } + } + + func isSubscribed() -> Bool { + var result = false + concurrentQueue.sync { [unowned self] in + result = !subscriptions.isEmpty + } + return result + } + + func getTopics() -> [String] { + var topics: [String] = [] + concurrentQueue.sync { [unowned self] in + topics = Array(subscriptions.keys) + } + return topics + } +} + +#if DEBUG +final class SubscriptionsTrackerMock: SubscriptionsTracking { + var isSubscribedReturnValue: Bool = false + private var subscriptions: [String: String] = [:] + + func setSubscription(for topic: String, id: String) { + subscriptions[topic] = id + } + + func getSubscription(for topic: String) -> String? { + return subscriptions[topic] + } + + func removeSubscription(for topic: String) { + subscriptions[topic] = nil + } + + func isSubscribed() -> Bool { + return isSubscribedReturnValue + } + + func reset() { + subscriptions.removeAll() + isSubscribedReturnValue = false + } + + func getTopics() -> [String] { + return Array(subscriptions.keys) + } +} +#endif diff --git a/Sources/WalletConnectSign/Auth/Services/AuthResponseTopicResubscriptionService.swift b/Sources/WalletConnectSign/Auth/Services/AuthResponseTopicResubscriptionService.swift index 4d8af4005..e17ef6426 100644 --- a/Sources/WalletConnectSign/Auth/Services/AuthResponseTopicResubscriptionService.swift +++ b/Sources/WalletConnectSign/Auth/Services/AuthResponseTopicResubscriptionService.swift @@ -30,19 +30,14 @@ class AuthResponseTopicResubscriptionService { self.logger = logger self.authResponseTopicRecordsStore = authResponseTopicRecordsStore cleanExpiredRecordsIfNeeded() - setupConnectionSubscriptions() + subscribeResponsTopics() } - func setupConnectionSubscriptions() { - networkingInteractor.socketConnectionStatusPublisher - .sink { [unowned self] status in - guard status == .connected else { return } - let topics = authResponseTopicRecordsStore.getAll().map{$0.topic} - Task(priority: .high) { - try await networkingInteractor.batchSubscribe(topics: topics) - } - } - .store(in: &publishers) + func subscribeResponsTopics() { + let topics = authResponseTopicRecordsStore.getAll().map{$0.topic} + Task(priority: .background) { + try await networkingInteractor.batchSubscribe(topics: topics) + } } func cleanExpiredRecordsIfNeeded() { diff --git a/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift b/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift index a18c922c8..203cd6d16 100644 --- a/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift +++ b/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift @@ -45,7 +45,7 @@ final class SessionEngine { self.sessionRequestsProvider = sessionRequestsProvider self.invalidRequestsSanitiser = invalidRequestsSanitiser - setupConnectionSubscriptions() + subscribeActiveSessions() setupRequestSubscriptions() setupResponseSubscriptions() setupUpdateSubscriptions() @@ -87,17 +87,11 @@ final class SessionEngine { // MARK: - Privates private extension SessionEngine { - - func setupConnectionSubscriptions() { - networkingInteractor.socketConnectionStatusPublisher - .sink { [unowned self] status in - guard status == .connected else { return } - let topics = sessionStore.getAll().map{$0.topic} - Task(priority: .high) { - try await networkingInteractor.batchSubscribe(topics: topics) - } - } - .store(in: &publishers) + func subscribeActiveSessions() { + let topics = sessionStore.getAll().map{$0.topic} + Task(priority: .background) { + try await networkingInteractor.batchSubscribe(topics: topics) + } } func setupRequestSubscriptions() { diff --git a/Tests/RelayerTests/AutomaticSocketConnectionHandlerTests.swift b/Tests/RelayerTests/AutomaticSocketConnectionHandlerTests.swift index 368d25da4..6b1809b35 100644 --- a/Tests/RelayerTests/AutomaticSocketConnectionHandlerTests.swift +++ b/Tests/RelayerTests/AutomaticSocketConnectionHandlerTests.swift @@ -8,37 +8,38 @@ final class AutomaticSocketConnectionHandlerTests: XCTestCase { var networkMonitor: NetworkMonitoringMock! var appStateObserver: AppStateObserverMock! var backgroundTaskRegistrar: BackgroundTaskRegistrarMock! + var subscriptionsTracker: SubscriptionsTrackerMock! + var socketStatusProviderMock: SocketStatusProviderMock! override func setUp() { webSocketSession = WebSocketMock() networkMonitor = NetworkMonitoringMock() appStateObserver = AppStateObserverMock() - let webSocket = WebSocketMock() let defaults = RuntimeKeyValueStorage() let logger = ConsoleLoggerMock() let keychainStorageMock = DispatcherKeychainStorageMock() let clientIdStorage = ClientIdStorage(defaults: defaults, keychain: keychainStorageMock, logger: logger) - - let socketAuthenticator = ClientIdAuthenticator(clientIdStorage: clientIdStorage) - let relayUrlFactory = RelayUrlFactory( - relayHost: "relay.walletconnect.com", - projectId: "1012db890cf3cfb0c1cdc929add657ba", - socketAuthenticator: socketAuthenticator - ) backgroundTaskRegistrar = BackgroundTaskRegistrarMock() + subscriptionsTracker = SubscriptionsTrackerMock() + + socketStatusProviderMock = SocketStatusProviderMock() + sut = AutomaticSocketConnectionHandler( socket: webSocketSession, networkMonitor: networkMonitor, appStateObserver: appStateObserver, backgroundTaskRegistrar: backgroundTaskRegistrar, - logger: ConsoleLoggerMock() + subscriptionsTracker: subscriptionsTracker, + logger: logger, + socketStatusProvider: socketStatusProviderMock ) } func testConnectsOnConnectionSatisfied() { webSocketSession.disconnect() + subscriptionsTracker.isSubscribedReturnValue = true // Simulate that there are active subscriptions XCTAssertFalse(webSocketSession.isConnected) networkMonitor.networkConnectionStatusPublisherSubject.send(.connected) XCTAssertTrue(webSocketSession.isConnected) @@ -53,11 +54,19 @@ final class AutomaticSocketConnectionHandlerTests: XCTestCase { } func testReconnectsOnEnterForeground() { + subscriptionsTracker.isSubscribedReturnValue = true // Simulate that there are active subscriptions webSocketSession.disconnect() appStateObserver.onWillEnterForeground?() XCTAssertTrue(webSocketSession.isConnected) } + func testReconnectsOnEnterForegroundWhenNoSubscriptions() { + subscriptionsTracker.isSubscribedReturnValue = false // Simulate no active subscriptions + webSocketSession.disconnect() + appStateObserver.onWillEnterForeground?() + XCTAssertFalse(webSocketSession.isConnected) // The connection should not be re-established + } + func testRegisterTaskOnEnterBackground() { XCTAssertNil(backgroundTaskRegistrar.completion) appStateObserver.onWillEnterBackground?() @@ -66,12 +75,15 @@ final class AutomaticSocketConnectionHandlerTests: XCTestCase { func testDisconnectOnEndBackgroundTask() { appStateObserver.onWillEnterBackground?() + webSocketSession.connect() XCTAssertTrue(webSocketSession.isConnected) backgroundTaskRegistrar.completion!() XCTAssertFalse(webSocketSession.isConnected) } func testReconnectOnDisconnectForeground() async { + subscriptionsTracker.isSubscribedReturnValue = true // Simulate that there are active subscriptions + webSocketSession.connect() appStateObserver.currentState = .foreground XCTAssertTrue(webSocketSession.isConnected) webSocketSession.disconnect() @@ -79,11 +91,153 @@ final class AutomaticSocketConnectionHandlerTests: XCTestCase { XCTAssertTrue(webSocketSession.isConnected) } + func testNotReconnectOnDisconnectForegroundWhenNoSubscriptions() async { + subscriptionsTracker.isSubscribedReturnValue = false // Simulate no active subscriptions + webSocketSession.connect() + appStateObserver.currentState = .foreground + XCTAssertTrue(webSocketSession.isConnected) + webSocketSession.disconnect() + await sut.handleDisconnection() + XCTAssertFalse(webSocketSession.isConnected) // The connection should not be re-established + } + func testReconnectOnDisconnectBackground() async { + subscriptionsTracker.isSubscribedReturnValue = true // Simulate that there are active subscriptions + webSocketSession.connect() + appStateObserver.currentState = .background + XCTAssertTrue(webSocketSession.isConnected) + webSocketSession.disconnect() + await sut.handleDisconnection() + XCTAssertFalse(webSocketSession.isConnected) + } + + func testNotReconnectOnDisconnectBackgroundWhenNoSubscriptions() async { + subscriptionsTracker.isSubscribedReturnValue = false // Simulate no active subscriptions + webSocketSession.connect() appStateObserver.currentState = .background XCTAssertTrue(webSocketSession.isConnected) webSocketSession.disconnect() await sut.handleDisconnection() + XCTAssertFalse(webSocketSession.isConnected) // The connection should not be re-established + } + + func testReconnectIfNeededWhenSubscribed() { + // Simulate that there are active subscriptions + subscriptionsTracker.isSubscribedReturnValue = true + + // Ensure socket is disconnected initially + webSocketSession.disconnect() + XCTAssertFalse(webSocketSession.isConnected) + + // Trigger reconnect logic + sut.reconnectIfNeeded() + + // Expect the socket to be connected since there are subscriptions + XCTAssertTrue(webSocketSession.isConnected) + } + + func testReconnectIfNeededWhenNotSubscribed() { + // Simulate that there are no active subscriptions + subscriptionsTracker.isSubscribedReturnValue = false + + // Ensure socket is disconnected initially + webSocketSession.disconnect() + XCTAssertFalse(webSocketSession.isConnected) + + // Trigger reconnect logic + sut.reconnectIfNeeded() + + // Expect the socket to remain disconnected since there are no subscriptions XCTAssertFalse(webSocketSession.isConnected) } + + func testReconnectsOnConnectionSatisfiedWhenSubscribed() { + // Simulate that there are active subscriptions + subscriptionsTracker.isSubscribedReturnValue = true + + // Ensure socket is disconnected initially + webSocketSession.disconnect() + XCTAssertFalse(webSocketSession.isConnected) + + // Simulate network connection becomes satisfied + networkMonitor.networkConnectionStatusPublisherSubject.send(.connected) + + // Expect the socket to reconnect since there are subscriptions + XCTAssertTrue(webSocketSession.isConnected) + } + + func testReconnectsOnEnterForegroundWhenSubscribed() { + // Simulate that there are active subscriptions + subscriptionsTracker.isSubscribedReturnValue = true + + // Ensure socket is disconnected initially + webSocketSession.disconnect() + XCTAssertFalse(webSocketSession.isConnected) + + // Simulate entering foreground + appStateObserver.onWillEnterForeground?() + + // Expect the socket to reconnect since there are subscriptions + XCTAssertTrue(webSocketSession.isConnected) + } + + func testSwitchesToPeriodicReconnectionAfterMaxImmediateAttempts() { + sut.connect() // Start connection process + + // Simulate immediate reconnection attempts + for _ in 0...sut.maxImmediateAttempts { + socketStatusProviderMock.simulateConnectionStatus(.disconnected) + } + + // Now we should be switching to periodic reconnection attempts + // Check reconnectionAttempts is set to maxImmediateAttempts + XCTAssertEqual(sut.reconnectionAttempts, sut.maxImmediateAttempts) + XCTAssertNotNil(sut.reconnectionTimer) // Periodic reconnection timer should be started + } + + func testPeriodicReconnectionStopsAfterSuccessfulConnection() { + sut.connect() // Start connection process + + // Simulate immediate reconnection attempts + for _ in 0...sut.maxImmediateAttempts { + socketStatusProviderMock.simulateConnectionStatus(.disconnected) + } + + // Check that periodic reconnection starts + XCTAssertNotNil(sut.reconnectionTimer) + + // Now simulate the connection being successful + socketStatusProviderMock.simulateConnectionStatus(.connected) + + // Periodic reconnection timer should stop + XCTAssertNil(sut.reconnectionTimer) + XCTAssertEqual(sut.reconnectionAttempts, 0) // Attempts should be reset + } + + func testPeriodicReconnectionAttempts() { + subscriptionsTracker.isSubscribedReturnValue = true // Simulate that there are active subscriptions + webSocketSession.disconnect() + sut.periodicReconnectionInterval = 0.0001 + sut.connect() // Start connection process + + // Simulate immediate reconnection attempts to switch to periodic + for _ in 0...sut.maxImmediateAttempts { + socketStatusProviderMock.simulateConnectionStatus(.disconnected) + } + + // Ensure we have switched to periodic reconnection + XCTAssertNotNil(sut.reconnectionTimer) + + // Simulate the periodic timer firing without waiting for real time + let expectation = XCTestExpectation(description: "Periodic reconnection attempt made") + sut.reconnectionTimer?.setEventHandler { + self.socketStatusProviderMock.simulateConnectionStatus(.connected) + expectation.fulfill() + } + + wait(for: [expectation], timeout: 1) + + // Check that the periodic reconnection attempt was made + XCTAssertTrue(webSocketSession.isConnected) // Assume that connection would have been attempted + } } diff --git a/Tests/RelayerTests/DispatcherTests.swift b/Tests/RelayerTests/DispatcherTests.swift index e8b0de168..3b41c4600 100644 --- a/Tests/RelayerTests/DispatcherTests.swift +++ b/Tests/RelayerTests/DispatcherTests.swift @@ -14,48 +14,13 @@ class DispatcherKeychainStorageMock: KeychainStorageProtocol { func deleteAll() throws {} } -class WebSocketMock: WebSocketConnecting { - var request: URLRequest = URLRequest(url: URL(string: "wss://relay.walletconnect.com")!) - - var onText: ((String) -> Void)? - var onConnect: (() -> Void)? - var onDisconnect: ((Error?) -> Void)? - var sendCallCount: Int = 0 - var isConnected: Bool = false - - func connect() { - isConnected = true - onConnect?() - } - - func disconnect() { - isConnected = false - onDisconnect?(nil) - } - - func write(string: String, completion: (() -> Void)?) { - sendCallCount+=1 - } -} - -class WebSocketFactoryMock: WebSocketFactory { - private let webSocket: WebSocketMock - - init(webSocket: WebSocketMock) { - self.webSocket = webSocket - } - - func create(with url: URL) -> WebSocketConnecting { - return webSocket - } -} - final class DispatcherTests: XCTestCase { var publishers = Set() var sut: Dispatcher! var webSocket: WebSocketMock! var networkMonitor: NetworkMonitoringMock! - + var socketStatusProviderMock: SocketStatusProviderMock! + override func setUp() { webSocket = WebSocketMock() let webSocketFactory = WebSocketFactoryMock(webSocket: webSocket) @@ -72,13 +37,15 @@ final class DispatcherTests: XCTestCase { socketAuthenticator: socketAuthenticator ) let socketConnectionHandler = ManualSocketConnectionHandler(socket: webSocket, logger: logger) + socketStatusProviderMock = SocketStatusProviderMock() sut = Dispatcher( socketFactory: webSocketFactory, relayUrlFactory: relayUrlFactory, networkMonitor: networkMonitor, socket: webSocket, logger: ConsoleLoggerMock(), - socketConnectionHandler: socketConnectionHandler + socketConnectionHandler: socketConnectionHandler, + socketStatusProvider: socketStatusProviderMock ) } @@ -88,16 +55,6 @@ final class DispatcherTests: XCTestCase { XCTAssertEqual(webSocket.sendCallCount, 1) } -// func testTextFramesSentAfterReconnectingSocket() { -// try! sut.disconnect(closeCode: .normalClosure) -// sut.send("1"){_ in} -// sut.send("2"){_ in} -// XCTAssertEqual(webSocketSession.sendCallCount, 0) -// try! sut.connect() -// socketConnectionObserver.onConnect?() -// XCTAssertEqual(webSocketSession.sendCallCount, 2) -// } - func testOnMessage() { let expectation = expectation(description: "on message") sut.onMessage = { message in @@ -114,7 +71,7 @@ final class DispatcherTests: XCTestCase { guard status == .connected else { return } expectation.fulfill() }.store(in: &publishers) - webSocket.onConnect?() + socketStatusProviderMock.simulateConnectionStatus(.connected) waitForExpectations(timeout: 0.001) } @@ -125,7 +82,7 @@ final class DispatcherTests: XCTestCase { guard status == .disconnected else { return } expectation.fulfill() }.store(in: &publishers) - webSocket.onDisconnect?(nil) + socketStatusProviderMock.simulateConnectionStatus(.disconnected) waitForExpectations(timeout: 0.001) } } diff --git a/Tests/RelayerTests/RelayClientTests.swift b/Tests/RelayerTests/RelayClientTests.swift index d767623e4..884c8047f 100644 --- a/Tests/RelayerTests/RelayClientTests.swift +++ b/Tests/RelayerTests/RelayClientTests.swift @@ -10,13 +10,15 @@ final class RelayClientTests: XCTestCase { var sut: RelayClient! var dispatcher: DispatcherMock! var publishers = Set() + var subscriptionsTracker: SubscriptionsTrackerMock! override func setUp() { dispatcher = DispatcherMock() let logger = ConsoleLogger() let clientIdStorage = ClientIdStorageMock() let rpcHistory = RPCHistoryFactory.createForRelay(keyValueStorage: RuntimeKeyValueStorage()) - sut = RelayClient(dispatcher: dispatcher, logger: logger, rpcHistory: rpcHistory, clientIdStorage: clientIdStorage) + subscriptionsTracker = SubscriptionsTrackerMock() + sut = RelayClient(dispatcher: dispatcher, logger: logger, rpcHistory: rpcHistory, clientIdStorage: clientIdStorage, subscriptionsTracker: subscriptionsTracker) } override func tearDown() { @@ -50,7 +52,7 @@ final class RelayClientTests: XCTestCase { func testUnsubscribeRequest() { let topic = String.randomTopic() - sut.subscriptions[topic] = "" + subscriptionsTracker.setSubscription(for: topic, id: "") sut.unsubscribe(topic: topic) { error in XCTAssertNil(error) } @@ -78,7 +80,7 @@ final class RelayClientTests: XCTestCase { func testSendOnUnsubscribe() { let topic = "123" - sut.subscriptions[topic] = "" + subscriptionsTracker.setSubscription(for: topic, id: "") sut.unsubscribe(topic: topic) {_ in } XCTAssertTrue(dispatcher.sent) }