diff --git a/resolve.go b/resolve.go index 6808ca1..a85490c 100644 --- a/resolve.go +++ b/resolve.go @@ -2,7 +2,6 @@ package madns import ( "context" - "fmt" "net" "strings" @@ -12,6 +11,8 @@ import ( var ResolvableProtocols = []ma.Protocol{DnsaddrProtocol, Dns4Protocol, Dns6Protocol} var DefaultResolver = &Resolver{Backend: net.DefaultResolver} +const dnsaddrTXTPrefix = "dnsaddr=" + type backend interface { LookupIPAddr(context.Context, string) ([]net.IPAddr, error) LookupTXT(context.Context, string) ([]string, error) @@ -44,19 +45,15 @@ func (r *MockBackend) LookupTXT(ctx context.Context, name string) ([]string, err } } -func Matches(maddr ma.Multiaddr) bool { - protos := maddr.Protocols() - if len(protos) == 0 { - return false - } - - for _, p := range ResolvableProtocols { - if protos[0].Code == p.Code { - return true +func Matches(maddr ma.Multiaddr) (matches bool) { + ma.ForEach(maddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: + matches = true } - } - - return false + return !matches + }) + return matches } func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { @@ -64,121 +61,162 @@ func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { } func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { - if !Matches(maddr) { - return []ma.Multiaddr{maddr}, nil - } - - protos := maddr.Protocols() - if protos[0].Code == Dns4Protocol.Code { - return r.resolveDns4(ctx, maddr) - } - if protos[0].Code == Dns6Protocol.Code { - return r.resolveDns6(ctx, maddr) - } - if protos[0].Code == DnsaddrProtocol.Code { - return r.resolveDnsaddr(ctx, maddr) - } - - panic("unreachable") -} - -func (r *Resolver) resolveDns4(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { - value, err := maddr.ValueForProtocol(Dns4Protocol.Code) - if err != nil { - return nil, fmt.Errorf("error resolving %s: %s", maddr.String(), err) - } - - encap := ma.Split(maddr)[1:] - - result := []ma.Multiaddr{} - records, err := r.Backend.LookupIPAddr(ctx, value) - if err != nil { - return result, err - } - - for _, r := range records { - ip4 := r.IP.To4() - if ip4 == nil { - continue - } - ip4maddr, err := ma.NewMultiaddr("/ip4/" + ip4.String()) - if err != nil { - return result, err + var results []ma.Multiaddr + for i := 0; maddr != nil; i++ { + var keep ma.Multiaddr + keep, maddr = ma.SplitFunc(maddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: + return true + default: + return false + } + }) + + // Append the part we're keeping. + if keep != nil { + if results == nil { + results = append(results, keep) + } else { + for i, r := range results { + results[i] = r.Encapsulate(keep) + } + } } - parts := append([]ma.Multiaddr{ip4maddr}, encap...) - result = append(result, ma.Join(parts...)) - } - return result, nil -} -func (r *Resolver) resolveDns6(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { - value, err := maddr.ValueForProtocol(Dns6Protocol.Code) - if err != nil { - return nil, fmt.Errorf("error resolving %s: %s", maddr.String(), err) - } - - encap := ma.Split(maddr)[1:] - - result := []ma.Multiaddr{} - records, err := r.Backend.LookupIPAddr(ctx, value) - if err != nil { - return result, err - } + // Check to see if we're done. + if maddr == nil { + break + } - for _, r := range records { - if r.IP.To4() != nil { - continue + var resolve *ma.Component + resolve, maddr = ma.SplitFirst(maddr) + + proto := resolve.Protocol() + value := resolve.Value() + + var resolved []ma.Multiaddr + switch proto.Code { + case Dns4Protocol.Code, Dns6Protocol.Code: + v4 := proto.Code == Dns4Protocol.Code + + // XXX: Unfortunately, go does a pretty terrible job of + // differentiating between IPv6 and IPv4. A v4-in-v6 + // AAAA record will _look_ like an A record to us and + // there's nothing we can do about that. + records, err := r.Backend.LookupIPAddr(ctx, value) + if err != nil { + return nil, err + } + + for _, r := range records { + var ( + rmaddr ma.Multiaddr + err error + ) + ip4 := r.IP.To4() + if v4 { + if ip4 == nil { + continue + } + rmaddr, err = ma.NewMultiaddr("/ip4/" + ip4.String()) + } else { + if ip4 != nil { + continue + } + rmaddr, err = ma.NewMultiaddr("/ip6/" + r.IP.String()) + } + if err != nil { + return nil, err + } + resolved = append(resolved, rmaddr) + } + case DnsaddrProtocol.Code: + records, err := r.Backend.LookupTXT(ctx, "_dnsaddr."+value) + if err != nil { + return nil, err + } + + length := 0 + if maddr != nil { + length = addrLen(maddr) + } + for _, r := range records { + if !strings.HasPrefix(r, dnsaddrTXTPrefix) { + continue + } + rmaddr, err := ma.NewMultiaddr(r[len(dnsaddrTXTPrefix):]) + if err != nil { + // discard multiaddrs we don't understand. + // XXX: Is this right? + continue + } + + if maddr != nil { + rmlen := addrLen(rmaddr) + if rmlen < length { + // not long enough. + continue + } + + // Matches everything after the /dnsaddr/... with the end of the + // dnsaddr record: + // + // v----------rmlen-----------------v + // /ip4/1.2.3.4/tcp/1234/p2p/QmFoobar + // /p2p/QmFoobar + // ^--(rmlen - length)--^---length--^ + if !maddr.Equal(offset(rmaddr, rmlen-length)) { + continue + } + } + + resolved = append(resolved, rmaddr) + } + + // consumes the rest of the multiaddr as part of the "match" process. + maddr = nil + default: + panic("unreachable") } - ip6maddr, err := ma.NewMultiaddr("/ip6/" + r.IP.To16().String()) - if err != nil { - return result, err + + if len(resolved) == 0 { + return nil, nil + } else if len(results) == 0 { + results = resolved + } else { + results = cross(results, resolved) } - parts := append([]ma.Multiaddr{ip6maddr}, encap...) - result = append(result, ma.Join(parts...)) } - return result, nil + return results, nil } -func (r *Resolver) resolveDnsaddr(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { - value, err := maddr.ValueForProtocol(DnsaddrProtocol.Code) - if err != nil { - return nil, fmt.Errorf("error resolving %s: %s", maddr.String(), err) - } - - trailer := ma.Split(maddr)[1:] - - result := []ma.Multiaddr{} - records, err := r.Backend.LookupTXT(ctx, "_dnsaddr."+value) - if err != nil { - return result, err - } - - for _, r := range records { - rv := strings.Split(r, "dnsaddr=") - if len(rv) != 2 { - continue - } - - rmaddr, err := ma.NewMultiaddr(rv[1]) - if err != nil { - return result, err - } +func addrLen(maddr ma.Multiaddr) int { + length := 0 + ma.ForEach(maddr, func(_ ma.Component) bool { + length++ + return true + }) + return length +} - if matchDnsaddr(rmaddr, trailer) { - result = append(result, rmaddr) +func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr { + _, after := ma.SplitFunc(maddr, func(c ma.Component) bool { + if offset == 0 { + return true } - } - return result, nil + offset-- + return false + }) + return after } -// XXX probably insecure -func matchDnsaddr(maddr ma.Multiaddr, trailer []ma.Multiaddr) bool { - parts := ma.Split(maddr) - if len(trailer) > len(parts) { - return false - } - if ma.Join(parts[len(parts)-len(trailer):]...).Equal(ma.Join(trailer...)) { - return true +func cross(a, b []ma.Multiaddr) []ma.Multiaddr { + res := make([]ma.Multiaddr, 0, len(a)*len(b)) + for _, x := range a { + for _, y := range b { + res = append(res, x.Encapsulate(y)) + } } - return false + return res } diff --git a/resolve_test.go b/resolve_test.go index e0d1baa..12066df 100644 --- a/resolve_test.go +++ b/resolve_test.go @@ -35,7 +35,7 @@ func makeResolver() *Resolver { }, TXT: map[string][]string{ "_dnsaddr.example.com": []string{txta, txtb}, - "_dnsaddr.matching.com": []string{txtc, txtd, txte}, + "_dnsaddr.matching.com": []string{txtc, txtd, txte, "not a dnsaddr", "dnsaddr=/foobar"}, }, } resolver := &Resolver{Backend: mock} @@ -43,6 +43,11 @@ func makeResolver() *Resolver { } func TestMatches(t *testing.T) { + if !Matches(ma.StringCast("/tcp/1234/dns6/example.com")) { + // Pretend this is a p2p-circuit address. Unfortunately, we'd + // need to depend on the circuit package to parse it. + t.Fatalf("expected match, didn't: /tcp/1234/dns6/example.com") + } if !Matches(ma.StringCast("/dns4/example.com")) { t.Fatalf("expected match, didn't: /dns4/example.com") } @@ -78,6 +83,63 @@ func TestSimpleIPResolve(t *testing.T) { } } +func TestResolveMultiple(t *testing.T) { + ctx := context.Background() + resolver := makeResolver() + + addrs, err := resolver.Resolve(ctx, ma.StringCast("/dns4/example.com/quic/dns6/example.com")) + if err != nil { + t.Error(err) + } + for i, x := range []ma.Multiaddr{ip4ma, ip4mb} { + for j, y := range []ma.Multiaddr{ip6ma, ip6mb} { + expected := ma.Join(x, ma.StringCast("/quic"), y) + actual := addrs[i*2+j] + if !expected.Equal(actual) { + t.Fatalf("expected %s, got %s", expected, actual) + } + } + } +} + +func TestResolveMultipleAdjacent(t *testing.T) { + ctx := context.Background() + resolver := makeResolver() + + addrs, err := resolver.Resolve(ctx, ma.StringCast("/dns4/example.com/dns6/example.com")) + if err != nil { + t.Error(err) + } + for i, x := range []ma.Multiaddr{ip4ma, ip4mb} { + for j, y := range []ma.Multiaddr{ip6ma, ip6mb} { + expected := ma.Join(x, y) + actual := addrs[i*2+j] + if !expected.Equal(actual) { + t.Fatalf("expected %s, got %s", expected, actual) + } + } + } +} + +func TestResolveMultipleSandwitch(t *testing.T) { + ctx := context.Background() + resolver := makeResolver() + + addrs, err := resolver.Resolve(ctx, ma.StringCast("/quic/dns4/example.com/dns6/example.com/http")) + if err != nil { + t.Error(err) + } + for i, x := range []ma.Multiaddr{ip4ma, ip4mb} { + for j, y := range []ma.Multiaddr{ip6ma, ip6mb} { + expected := ma.Join(ma.StringCast("/quic"), x, y, ma.StringCast("/http")) + actual := addrs[i*2+j] + if !expected.Equal(actual) { + t.Fatalf("expected %s, got %s", expected, actual) + } + } + } +} + func TestSimpleTXTResolve(t *testing.T) { ctx := context.Background() resolver := makeResolver()