diff --git a/dns.go b/dns.go index 53c82df..3439165 100644 --- a/dns.go +++ b/dns.go @@ -9,11 +9,19 @@ import ( // Extracted from source of truth for multicodec codes: https://github.com/multiformats/multicodec const ( + P_DNS = 0x0035 P_DNS4 = 0x0036 P_DNS6 = 0x0037 P_DNSADDR = 0x0038 ) +var DnsProtocol = ma.Protocol{ + Code: P_DNS, + Size: ma.LengthPrefixedVarSize, + Name: "dns", + VCode: ma.CodeToVarint(P_DNS), + Transcoder: DnsTranscoder, +} var Dns4Protocol = ma.Protocol{ Code: P_DNS4, Size: ma.LengthPrefixedVarSize, @@ -37,7 +45,11 @@ var DnsaddrProtocol = ma.Protocol{ } func init() { - err := ma.AddProtocol(Dns4Protocol) + err := ma.AddProtocol(DnsProtocol) + if err != nil { + panic(fmt.Errorf("error registering dns protocol: %s", err)) + } + err = ma.AddProtocol(Dns4Protocol) if err != nil { panic(fmt.Errorf("error registering dns4 protocol: %s", err)) } diff --git a/resolve.go b/resolve.go index a85490c..e2f5d2e 100644 --- a/resolve.go +++ b/resolve.go @@ -8,7 +8,7 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -var ResolvableProtocols = []ma.Protocol{DnsaddrProtocol, Dns4Protocol, Dns6Protocol} +var ResolvableProtocols = []ma.Protocol{DnsaddrProtocol, Dns4Protocol, Dns6Protocol, DnsProtocol} var DefaultResolver = &Resolver{Backend: net.DefaultResolver} const dnsaddrTXTPrefix = "dnsaddr=" @@ -48,7 +48,7 @@ func (r *MockBackend) LookupTXT(ctx context.Context, name string) ([]string, err func Matches(maddr ma.Multiaddr) (matches bool) { ma.ForEach(maddr, func(c ma.Component) bool { switch c.Protocol().Code { - case Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: + case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: matches = true } return !matches @@ -66,7 +66,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia var keep ma.Multiaddr keep, maddr = ma.SplitFunc(maddr, func(c ma.Component) bool { switch c.Protocol().Code { - case Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: + case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: return true default: return false @@ -97,8 +97,9 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia var resolved []ma.Multiaddr switch proto.Code { - case Dns4Protocol.Code, Dns6Protocol.Code: - v4 := proto.Code == Dns4Protocol.Code + case Dns4Protocol.Code, Dns6Protocol.Code, DnsProtocol.Code: + v4only := proto.Code == Dns4Protocol.Code + v6only := proto.Code == Dns6Protocol.Code // XXX: Unfortunately, go does a pretty terrible job of // differentiating between IPv6 and IPv4. A v4-in-v6 @@ -115,16 +116,16 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia err error ) ip4 := r.IP.To4() - if v4 { - if ip4 == nil { + if ip4 == nil { + if v4only { continue } - rmaddr, err = ma.NewMultiaddr("/ip4/" + ip4.String()) + rmaddr, err = ma.NewMultiaddr("/ip6/" + r.IP.String()) } else { - if ip4 != nil { + if v6only { continue } - rmaddr, err = ma.NewMultiaddr("/ip6/" + r.IP.String()) + rmaddr, err = ma.NewMultiaddr("/ip4/" + ip4.String()) } if err != nil { return nil, err diff --git a/resolve_test.go b/resolve_test.go index 12066df..1334611 100644 --- a/resolve_test.go +++ b/resolve_test.go @@ -48,6 +48,9 @@ func TestMatches(t *testing.T) { // need to depend on the circuit package to parse it. t.Fatalf("expected match, didn't: /tcp/1234/dns6/example.com") } + if !Matches(ma.StringCast("/dns/example.com")) { + t.Fatalf("expected match, didn't: /dns/example.com") + } if !Matches(ma.StringCast("/dns4/example.com")) { t.Fatalf("expected match, didn't: /dns4/example.com") } @@ -81,6 +84,16 @@ func TestSimpleIPResolve(t *testing.T) { if len(addrs6) != 2 || !addrs6[0].Equal(ip6ma) || addrs6[0].Equal(ip6mb) { t.Fatalf("expected [%s %s], got %+v", ip6ma, ip6mb, addrs6) } + + addrs, err := resolver.Resolve(ctx, ma.StringCast("/dns/example.com")) + if err != nil { + t.Error(err) + } + for i, expected := range []ma.Multiaddr{ip4ma, ip4mb, ip6ma, ip6mb} { + if !expected.Equal(addrs[i]) { + t.Fatalf("%d: expected %s, got %s", i, expected, addrs[i]) + } + } } func TestResolveMultiple(t *testing.T) {