diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 455670bfc7..4589d82e8d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -34,7 +34,7 @@ import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Features, InitFeature, Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee} +import fr.acinq.eclair.{EncodedNodeId, Features, InitFeature, Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee} import java.util.UUID import java.util.concurrent.TimeUnit @@ -142,7 +142,8 @@ class ChannelRelay private(nodeParams: NodeParams, } private val (requestedShortChannelId_opt, walletNodeId_opt) = r.payload.outgoing match { - case Left(walletNodeId) => (None, Some(walletNodeId)) + case Left(EncodedNodeId.WithPublicKey.Wallet(walletNodeId)) => (None, Some(walletNodeId)) + case Left(_) => (None, None) case Right(shortChannelId) => (Some(shortChannelId), None) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index 8c635df706..bf94efef55 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -71,7 +71,7 @@ object ChannelRelayer { case Relay(channelRelayPacket, originNode) => val relayId = UUID.randomUUID() val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match { - case Left(walletNodeId) => Some(walletNodeId) + case Left(outgoingNodeId) => Some(outgoingNodeId.publicKey) case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match { case Some(channelId) => channels.get(channelId).map(_.nextNodeId) case None => None diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala index 6b1e8e361c..663c92e543 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala @@ -112,8 +112,7 @@ private class BlindedPathsResolver(nodeParams: NodeParams, ) paymentRelayData.outgoing match { case Left(outgoingNodeId) => - // The next node seems to be a wallet node directly connected to us. - validateRelay(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextPaymentInfo, paymentRelayData, nextPathKey, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + validateRelay(outgoingNodeId, nextPaymentInfo, paymentRelayData, nextPathKey, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) case Right(outgoingChannelId) => register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId) waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextPathKey, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 6e22a25948..7188902b02 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ -import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} /** @@ -228,7 +228,7 @@ object PaymentOnion { sealed trait ChannelRelay extends IntermediatePayload { // @formatter:off /** The outgoing channel, or the nodeId of one of our peers. */ - def outgoing: Either[PublicKey, ShortChannelId] + def outgoing: Either[EncodedNodeId.WithPublicKey, ShortChannelId] def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry // @formatter:on diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 460353f234..0ee4626a1e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -98,10 +98,9 @@ object BlindedRouteData { } case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) { - // This is usually a channel, unless the next node is a mobile wallet connected to our node. - val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match { + val outgoing: Either[EncodedNodeId.WithPublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match { case Some(r) => Right(r.shortChannelId) - case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey].publicKey) + case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey]) } val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get @@ -114,7 +113,6 @@ object BlindedRouteData { } def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = { - // Note that the BOLTs require using an OutgoingChannelId, but we optionally support using a node_id. if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey]) return Left(ForbiddenTlv(UInt64(4))) if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index 8e98bdcff1..0ec96b7b15 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -128,7 +128,8 @@ class PaymentOnionSpec extends AnyFunSuite { RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), ) val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(hex"deadbeef")), blindedTlvs, randomKey().publicKey) - assert(payload.outgoing == Left(nextNodeId)) + val Left(nodeId) = payload.outgoing + assert(nodeId.publicKey == nextNodeId) assert(payload.amountToForward(10_000 msat) == 9990.msat) assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) assert(payload.paymentRelayData.allowedFeatures.isEmpty)