forked from folbricht/routedns
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cache-redis.go
139 lines (122 loc) · 3.61 KB
/
cache-redis.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package rdns
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/miekg/dns"
"github.com/redis/go-redis/v9"
)
// redisBackend is a cache backend that keeps cached DNS answers in Redis.
type redisBackend struct {
	// client is the Redis connection used for all cache operations.
	client *redis.Client
	// opt holds the options the backend was created with.
	opt RedisBackendOptions
}
// RedisBackendOptions configures a Redis cache backend.
type RedisBackendOptions struct {
	// RedisOptions is handed to redis.NewClient when the backend is created.
	RedisOptions redis.Options
	// KeyPrefix is prepended to every key the backend builds for a query.
	KeyPrefix string
}
// Compile-time check that *redisBackend implements CacheBackend.
var _ CacheBackend = (*redisBackend)(nil)
// NewRedisBackend returns a cache backend that talks to Redis using a client
// created from opt.RedisOptions. The options are kept on the backend so that
// the key prefix is available when building keys.
func NewRedisBackend(opt RedisBackendOptions) *redisBackend {
	return &redisBackend{
		client: redis.NewClient(&opt.RedisOptions),
		opt:    opt,
	}
}
// Store serializes item as JSON and writes it to Redis under the key derived
// from query. The Redis key expires when the item does. Failures are logged
// and otherwise ignored, since a failed cache write must not fail the query.
//
// Nothing is stored when query has no question, item is nil, or the item has
// already expired.
func (b *redisBackend) Store(query *dns.Msg, item *cacheAnswer) {
	if query == nil || len(query.Question) == 0 || item == nil {
		return
	}

	// go-redis treats a non-positive expiration as "no expiry", which would
	// leave an already-expired answer in Redis forever. Skip such items.
	ttl := time.Until(item.Expiry)
	if ttl <= 0 {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	key := b.keyFromQuery(query)
	value, err := json.Marshal(item)
	if err != nil {
		Log.WithError(err).Error("failed to marshal cache record")
		return
	}
	if err := b.client.Set(ctx, key, value, ttl).Err(); err != nil {
		Log.WithError(err).Error("failed to write to redis")
	}
}
// Lookup fetches the cached answer for q from Redis.
//
// It returns the answer (with its ID set to that of q and all TTLs reduced by
// the time spent in the cache), whether the record is eligible for prefetch,
// and whether there was a cache hit. Any failure to read or decode the record,
// and any record that has outlived one of its TTLs, is reported as a miss.
func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
	if q == nil || len(q.Question) == 0 {
		return nil, false, false
	}

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	key := b.keyFromQuery(q)
	value, err := b.client.Get(ctx, key).Result()
	if err != nil {
		// redis.Nil only means there's no such key, a plain cache-miss.
		if !errors.Is(err, redis.Nil) {
			Log.WithError(err).Error("failed to read from redis")
		}
		return nil, false, false
	}

	var a *cacheAnswer
	if err := json.Unmarshal([]byte(value), &a); err != nil {
		Log.WithError(err).Error("failed to unmarshal cache record from redis")
		return nil, false, false
	}
	// A JSON "null" (or a record without a message) decodes without error
	// but leaves nothing usable behind.
	if a == nil || a.Msg == nil {
		Log.WithError(errors.New("empty cache record")).Error("failed to unmarshal cache record from redis")
		return nil, false, false
	}

	answer := a.Msg
	prefetchEligible := a.PrefetchEligible
	answer.Id = q.Id

	// Calculate the time the record spent in the cache. We need to
	// subtract that from the TTL of each answer record. A timestamp in the
	// future (clock skew between writers) counts as age 0, and an age that
	// doesn't fit a TTL means the record can't possibly still be valid.
	const maxAge = 1<<32 - 1
	elapsed := time.Since(a.Timestamp)
	if elapsed < 0 {
		elapsed = 0
	}
	seconds := int64(elapsed / time.Second)
	if seconds > maxAge {
		return nil, false, false
	}
	age := uint32(seconds)

	// Go through all the answers, NS, and Extra and adjust the TTL (subtract the time
	// it's spent in the cache). If any record is too old, return a cache-miss; the
	// key is not deleted here. OPT records have a TTL of 0 and are ignored.
	for _, section := range [][]dns.RR{answer.Answer, answer.Ns, answer.Extra} {
		for _, rr := range section {
			if _, ok := rr.(*dns.OPT); ok {
				continue
			}
			h := rr.Header()
			if age >= h.Ttl {
				return nil, false, false
			}
			h.Ttl -= age
		}
	}
	return answer, prefetchEligible, true
}
// Flush removes every key that starts with the configured key prefix.
//
// DEL does not understand glob patterns, so the keys are discovered with
// SCAN MATCH and deleted in batches. With an empty prefix this matches (and
// deletes) every key in the selected database. Failures are logged and the
// flush is abandoned.
func (b *redisBackend) Flush() {
	// A flush needs several round-trips, so it gets a more generous
	// deadline than the single-command operations.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Escape glob metacharacters so the prefix is matched literally.
	escape := strings.NewReplacer(`\`, `\\`, `*`, `\*`, `?`, `\?`, `[`, `\[`, `]`, `\]`)
	pattern := escape.Replace(b.opt.KeyPrefix) + "*"

	const batchSize = 256
	batch := make([]string, 0, batchSize)

	// deleteBatch removes the collected keys and reports whether it succeeded.
	deleteBatch := func() bool {
		if len(batch) == 0 {
			return true
		}
		if err := b.client.Del(ctx, batch...).Err(); err != nil {
			Log.WithError(err).Error("failed to delete keys in redis")
			return false
		}
		batch = batch[:0]
		return true
	}

	iter := b.client.Scan(ctx, 0, pattern, batchSize).Iterator()
	for iter.Next(ctx) {
		batch = append(batch, iter.Val())
		if len(batch) >= batchSize && !deleteBatch() {
			return
		}
	}
	if err := iter.Err(); err != nil {
		Log.WithError(err).Error("failed to scan keys in redis")
		return
	}
	deleteBatch()
}
// Size reports the number of keys in the selected Redis database as returned
// by DBSIZE. The count is for the whole database, not only keys carrying the
// configured prefix. An error is logged and the command's value is returned
// as-is.
func (b *redisBackend) Size() int {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	cmd := b.client.DBSize(ctx)
	if err := cmd.Err(); err != nil {
		Log.WithError(err).Error("failed to run dbsize command on redis")
	}
	return int(cmd.Val())
}
// Close closes the underlying Redis client and returns its error, if any.
func (b *redisBackend) Close() error {
	err := b.client.Close()
	return err
}
// keyFromQuery builds the Redis key for a query. The key is made up of the
// configured prefix followed by the name, class and type of the first
// question, each terminated by ':'. If the query carries an OPT record, the
// DO flag and the address of every client-subnet option are appended too.
func (b *redisBackend) keyFromQuery(q *dns.Msg) string {
	var sb strings.Builder
	sb.WriteString(b.opt.KeyPrefix)

	question := q.Question[0]
	sb.WriteString(question.Name)
	sb.WriteString(":" + dns.Class(question.Qclass).String() + ":")
	sb.WriteString(dns.Type(question.Qtype).String() + ":")

	opt := q.IsEdns0()
	if opt == nil {
		return sb.String()
	}
	fmt.Fprintf(&sb, "%t:", opt.Do())

	// Append the address of any subnet option present.
	for _, o := range opt.Option {
		subnet, isSubnet := o.(*dns.EDNS0_SUBNET)
		if !isSubnet {
			continue
		}
		sb.WriteString(subnet.Address.String())
	}
	return sb.String()
}