Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(major): [sc-12609] Handle properly multiple services with the same name. #76

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions Sources/DistributedSystem/DiscoveryManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import PackageConcurrencyHelpers
final class DiscoveryManager {
private final class ProcessInfo {
var channel: (UInt32, Channel)?
var pendingServices = [(NodeService, DistributedSystem.ConnectionHandler)]()
var pendingServices = [(UUID, NodeService, DistributedSystem.ConnectionHandler)]()
}

private enum ServiceAddress: Equatable {
Expand Down Expand Up @@ -97,7 +97,7 @@ final class DiscoveryManager {
}
var discover: Bool
var addresses: [SocketAddress] = []
var services = [(ConsulServiceDiscovery.Instance, DistributedSystem.ChannelOrFactory)]()
var services = [(UUID, ConsulServiceDiscovery.Instance, DistributedSystem.ChannelOrFactory)]()

if cancellationToken.cancelled {
return (true, false, addresses, services)
Expand All @@ -107,21 +107,21 @@ final class DiscoveryManager {

var discoveryInfo = self.discoveries[serviceName]
if let discoveryInfo {
for serviceInfo in discoveryInfo.services.values {
for (serviceID, serviceInfo) in discoveryInfo.services {
if serviceFilter(serviceInfo.service) {
switch serviceInfo.address {
case let .local(factory):
services.append((serviceInfo.service, .factory(factory)))
services.append((serviceID, serviceInfo.service, .factory(factory)))
case let .remote(address):
if let processInfo = self.processes[address] {
if let (channelID, channel) = processInfo.channel {
services.append((serviceInfo.service, .channel(channelID, channel)))
services.append((serviceID, serviceInfo.service, .channel(channelID, channel)))
} else {
processInfo.pendingServices.append((serviceInfo.service, connectionHandler))
processInfo.pendingServices.append((serviceID, serviceInfo.service, connectionHandler))
}
} else {
let processInfo = ProcessInfo()
processInfo.pendingServices.append((serviceInfo.service, connectionHandler))
processInfo.pendingServices.append((serviceID, serviceInfo.service, connectionHandler))
self.processes[address] = processInfo
addresses.append(address)
}
Expand All @@ -145,8 +145,8 @@ final class DiscoveryManager {
return .cancelled
} else {
logger.debug("discoverService[\(serviceName)]: \(discover) \(addresses) \(services), cancellation token \(cancellationToken.ptr)")
for (service, addr) in services {
connectionHandler(service, addr)
for (serviceID, service, addr) in services {
connectionHandler(serviceID, service, addr)
}
return .started(discover, addresses)
}
Expand Down Expand Up @@ -175,14 +175,14 @@ final class DiscoveryManager {
}
}

func factoryFor(_ serviceName: String) -> DistributedSystem.ServiceFactory? {
func factoryFor(_ serviceName: String, _ serviceID: UUID) -> DistributedSystem.ServiceFactory? {
lock.withLock {
guard let discoveryInfo = self.discoveries[serviceName] else {
return nil
}

for entry in discoveryInfo.services {
if case let .local(factory) = entry.value.address {
if let serviceInfo = discoveryInfo.services[serviceID] {
if case let .local(factory) = serviceInfo.address {
return factory
}
}
Expand Down Expand Up @@ -226,7 +226,7 @@ final class DiscoveryManager {
}

for (service, connectionHandler) in services {
_ = connectionHandler(service, .factory(factory))
_ = connectionHandler(serviceID, service, .factory(factory))
}

return updateHealthStatus
Expand Down Expand Up @@ -259,10 +259,10 @@ final class DiscoveryManager {
return (false, nil)
}
} else {
var pendingServices = [(NodeService, DistributedSystem.ConnectionHandler)]()
var pendingServices = [(UUID, NodeService, DistributedSystem.ConnectionHandler)]()
for filterInfo in discoveryInfo.filters.values {
if filterInfo.filter(service) {
pendingServices.append((service, filterInfo.connectionHandler))
pendingServices.append((serviceID, service, filterInfo.connectionHandler))
}
}
if pendingServices.isEmpty {
Expand All @@ -279,7 +279,7 @@ final class DiscoveryManager {

if let process {
for connectionHandler in process.connectionHandlers {
connectionHandler(service, .channel(process.channel.0, process.channel.1))
connectionHandler(serviceID, service, .channel(process.channel.0, process.channel.1))
}
}

Expand All @@ -293,14 +293,14 @@ final class DiscoveryManager {
}
processInfo.channel = (channelID, channel)

var pendingServices = [(NodeService, DistributedSystem.ConnectionHandler)]()
var pendingServices = [(UUID, NodeService, DistributedSystem.ConnectionHandler)]()
swap(&processInfo.pendingServices, &pendingServices)

return pendingServices
}

for (service, connectionHandler) in services {
connectionHandler(service, .channel(channelID, channel))
for (serviceID, service, connectionHandler) in services {
connectionHandler(serviceID, service, .channel(channelID, channel))
}
}

Expand Down
29 changes: 20 additions & 9 deletions Sources/DistributedSystem/DistributedSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public class DistributedSystem: DistributedActorSystem, @unchecked Sendable {
static let pingInterval = TimeAmount.seconds(2)
static let serviceDiscoveryTimeout = TimeAmount.seconds(5)

static let protocolVersionMajor: UInt16 = 3
static let protocolVersionMajor: UInt16 = 4
static let protocolVersionMinor: UInt16 = 0

enum SessionMessage: UInt16 {
Expand Down Expand Up @@ -255,7 +255,7 @@ public class DistributedSystem: DistributedActorSystem, @unchecked Sendable {
case factory(ServiceFactory)
}

typealias ConnectionHandler = (ConsulServiceDiscovery.Instance, ChannelOrFactory) -> Void
typealias ConnectionHandler = (UUID, ConsulServiceDiscovery.Instance, ChannelOrFactory) -> Void

@TaskLocal
private static var actorID: ActorID? // supposed to be private, but need to make it internal for tests
Expand Down Expand Up @@ -311,17 +311,23 @@ public class DistributedSystem: DistributedActorSystem, @unchecked Sendable {
}
}

private func sendCreateService(_ serviceName: String, _ instanceID: UInt32, to channel: Channel) {
private func sendCreateService(_ serviceName: String, _ serviceID: UUID, _ instanceID: UInt32, to channel: Channel) {
let payloadSize =
MemoryLayout<SessionMessage.RawValue>.size
+ ULEB128.size(UInt(serviceName.count))
+ serviceName.count
+ MemoryLayout<uuid_t>.size
+ ULEB128.size(instanceID)
var buffer = ByteBufferAllocator().buffer(capacity: MemoryLayout<UInt32>.size + payloadSize)
buffer.writeInteger(UInt32(payloadSize))
buffer.writeInteger(SessionMessage.createServiceInstance.rawValue)
buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { ptr in ULEB128.encode(UInt(serviceName.count), to: ptr.baseAddress!) }
buffer.writeString(serviceName)
buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { ptr in
var uuid = serviceID.uuid
withUnsafeBytes(of: &uuid) { ptr.copyMemory(from: $0) }
return MemoryLayout<uuid_t>.size
}
buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { ptr in ULEB128.encode(instanceID, to: ptr.baseAddress!) }
logger.debug("\(channel.addressDescription): send create \(serviceName) \(EndpointIdentifier.instanceIdentifierDescription(instanceID))")
_ = channel.writeAndFlush(buffer, promise: nil)
Expand Down Expand Up @@ -476,8 +482,7 @@ public class DistributedSystem: DistributedActorSystem, @unchecked Sendable {
where S.ID == EndpointIdentifier, S.ActorSystem == DistributedSystem {
let serviceName = S.serviceName
logger.debug("connectTo: \(serviceName)")

let connectionHandler = { (service: ConsulServiceDiscovery.Instance, channelOrFactory: ChannelOrFactory) -> Void in
let connectionHandler = { (serviceID: UUID, service: ConsulServiceDiscovery.Instance, channelOrFactory: ChannelOrFactory) -> Void in
let serviceEndpointID = {
switch channelOrFactory {
case let .channel(channelID, channel):
Expand All @@ -487,7 +492,7 @@ public class DistributedSystem: DistributedActorSystem, @unchecked Sendable {
self.actors[serviceEndpointID] = .remoteService(.init())
return serviceEndpointID
}
self.sendCreateService(serviceName, serviceEndpointID.instanceID, to: channel)
self.sendCreateService(serviceName, serviceID, serviceEndpointID.instanceID, to: channel)
return serviceEndpointID
case let .factory(factory):
let serviceEndpointID = self.makeServiceEndpoint(0)
Expand Down Expand Up @@ -1175,8 +1180,14 @@ public class DistributedSystem: DistributedActorSystem, @unchecked Sendable {
guard let serviceName = buffer.readString(length: Int(serviceNameLength)) else {
throw DecodeError.error("failed to decode service name")
}
let serviceID = buffer.withUnsafeReadableBytes { ptr in
var uuid = uuid_t(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
withUnsafeMutableBytes(of: &uuid) { _ = ptr.copyBytes(to: $0) }
return UUID(uuid: uuid)
}
buffer.moveReaderIndex(forwardBy: MemoryLayout<uuid_t>.size)
let instanceID = try Self.readULEB128(from: &buffer, as: EndpointIdentifier.InstanceIdentifier.self)
createService(serviceName, instanceID, for: channelID, channel)
createService(serviceName, serviceID, instanceID, for: channelID, channel)
case .invocationEnvelope:
let instanceID = try Self.readULEB128(from: &buffer, as: EndpointIdentifier.InstanceIdentifier.self)
let endpointID = EndpointIdentifier(channelID, instanceID)
Expand Down Expand Up @@ -1218,8 +1229,8 @@ public class DistributedSystem: DistributedActorSystem, @unchecked Sendable {
}
}

private func createService(_ serviceName: String, _ instanceID: EndpointIdentifier.InstanceIdentifier, for channelID: UInt32, _ channel: Channel) {
let serviceFactory = discoveryManager.factoryFor(serviceName)
private func createService(_ serviceName: String, _ serviceID: UUID, _ instanceID: EndpointIdentifier.InstanceIdentifier, for channelID: UInt32, _ channel: Channel) {
let serviceFactory = discoveryManager.factoryFor(serviceName, serviceID)
guard let serviceFactory else {
logger.error("\(channel.addressDescription): service \(serviceName) for \(instanceID) not registered")
return
Expand Down
100 changes: 100 additions & 0 deletions Tests/DistributedSystemTests/DistributedSystemTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ final class DistributedSystemTests: XCTestCase {
let moduleID = DistributedSystem.ModuleIdentifier(UInt64(processInfo.processIdentifier))
let actorSystem = DistributedSystemServer(name: systemName)
try await actorSystem.start()

try await actorSystem.addService(ofType: TestServiceEndpoint.self, toModule: moduleID) { actorSystem in
let service = ServiceWithLeakCheckImpl(flags)
let serviceEndpoint = try TestServiceEndpoint(service, in: actorSystem)
Expand Down Expand Up @@ -1047,4 +1048,103 @@ final class DistributedSystemTests: XCTestCase {

try await Task.sleep(for: .seconds(180))
}

func testMultipleServices() async throws {
class TestServiceImpl: TestableService {
let id: Int
var clientEndpoint: TestClientEndpoint?

init(_ id: Int) {
self.id = id
}

func openStream(byRequest request: TestMessages.OpenRequest) async {
var reply = _StreamOpenedStruct(requestIdentifier: request.requestIdentifier)
reply.streamIdentifier = StreamIdentifier(id)
do {
try await clientEndpoint?.streamOpened(StreamOpened(reply))
} catch {
logger.error("\(error)")
}
}

func getMonster() async -> TestMessages.Monster {
fatalError("should not be called")
}

func doNothing() async {}
func handleMonsters(_ monsters: [TestMessages.Monster]) async {}
func handleConnectionState(_ state: ConnectionState) async {}
}

class TestClientImpl: TestableClient {
var stream: AsyncStream<StreamOpened>
var continuation: AsyncStream<StreamOpened>.Continuation

init() {
(stream, continuation) = AsyncStream<StreamOpened>.makeStream()
}

func streamOpened(_ reply: TestMessages.StreamOpened) async {
continuation.yield(reply)
}

func snapshotDone(for: TestMessages.Stream) async {}
func handleMonster(_ monster: TestMessages.Monster, for stream: TestMessages.Stream) async {}
func handleConnectionState(_ state: ConnectionState) async {}
}

let processInfo = ProcessInfo.processInfo
let systemName = "\(processInfo.hostName)-ts-\(processInfo.processIdentifier)-\(#line)"

let moduleID = DistributedSystem.ModuleIdentifier(1)
let serverSystem = DistributedSystemServer(name: systemName, compressionMode: .disabled)
try await serverSystem.start()

try await serverSystem.addService(ofType: TestServiceEndpoint.self, toModule: moduleID, metadata: ["opt": "1"]) { actorSystem in
let service = TestServiceImpl(1)
let serviceEndpoint = try TestServiceEndpoint(service, in: actorSystem)
let clientEndpointID = serviceEndpoint.id.makeClientEndpoint()
service.clientEndpoint = try TestClientEndpoint.resolve(id: clientEndpointID, using: actorSystem)
return serviceEndpoint
}

try await serverSystem.addService(ofType: TestServiceEndpoint.self, toModule: moduleID, metadata: ["opt": "2"]) { actorSystem in
let service = TestServiceImpl(2)
let serviceEndpoint = try TestServiceEndpoint(service, in: actorSystem)
let clientEndpointID = serviceEndpoint.id.makeClientEndpoint()
service.clientEndpoint = try TestClientEndpoint.resolve(id: clientEndpointID, using: actorSystem)
return serviceEndpoint
}

let clientSystem = DistributedSystem(name: systemName, compressionMode: .disabled)
try clientSystem.start()

let client = TestClientImpl()
let serviceEndpoint = try await clientSystem.connectToService(
TestServiceEndpoint.self,
withFilter: {
if let serviceMeta = $0.serviceMeta {
if let opt = serviceMeta["opt"], opt == "2" {
return true
}
}
return false
},
clientFactory: { actorSystem in
TestClientEndpoint(client, in: actorSystem)
}
)

let openRequest = _OpenRequestStruct(requestIdentifier: 2)
try await serviceEndpoint.openStream(byRequest: OpenRequest(openRequest))

for await reply in client.stream {
XCTAssertEqual(reply.streamIdentifier, 2)
break
}

clientSystem.stop()
serverSystem.stop()
}
}
Loading