-
Notifications
You must be signed in to change notification settings - Fork 3
/
client.go
194 lines (174 loc) · 4.79 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
package graceful
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"syscall"
)
// Errors reported while parsing a received message that does not carry the
// expected file descriptors. They are returned by Receive, ReceiveFrom,
// ReceiveAllFrom and the corresponding Client methods.
var (
	// ErrEmptyControlMessage means the message arrived without any socket
	// control message attached.
	ErrEmptyControlMessage = fmt.Errorf("empty control message")

	// ErrEmptyFileDescriptors means the control message carried no file
	// descriptors.
	ErrEmptyFileDescriptors = fmt.Errorf("empty file descriptors")
)
var (
	// ErrNotUnixConn is returned by the Receive* functions and Client methods
	// when the connection they are given is not a *net.UnixConn.
	ErrNotUnixConn = errors.New("not a unix connection")
)
// ReceiveCallback is invoked once for every file descriptor found while
// parsing a control message.
//
// fd is the received descriptor. meta yields the optional meta information
// the server attached to that descriptor; it is nil when the server sent
// none for it.
//
// meta is only valid until the callback returns: any bytes the callback
// leaves unread are discarded afterwards. The receive functions do not close
// fd after handing it to the callback.
//
// A non-nil error returned by the callback stops processing immediately and
// is returned by the function the callback was passed to.
type ReceiveCallback func(fd int, meta io.Reader) error
// Receive dials the "unix" network address addr and calls cb for each
// descriptor received from it until EOF. It uses a zero-value Client, so
// the default buffer sizes apply.
func Receive(addr string, cb ReceiveCallback) error {
	return new(Client).Receive(addr, cb)
}
// ReceiveFrom reads exactly one control message from conn and calls cb for
// each descriptor inside it, using the default buffer sizes.
func ReceiveFrom(conn net.Conn, cb ReceiveCallback) error {
	var client Client
	return client.ReceiveFrom(conn, cb)
}
// ReceiveAllFrom reads control messages from conn until EOF and calls cb for
// each descriptor inside them, using the default buffer sizes.
func ReceiveAllFrom(conn net.Conn, cb ReceiveCallback) error {
	return (&Client{}).ReceiveAllFrom(conn, cb)
}
// Client parses control messages that carry file descriptors. The zero value
// is ready to use and selects the default buffer sizes.
//
// A Client reuses its internal buffers for every message it reads, so it
// must not be used from multiple goroutines at once. It contains a sync.Once
// and must not be copied after first use.
type Client struct {
	// MsgBufferSize is the size of the buffer that receives the meta fields.
	// If it is zero, the default size is used.
	MsgBufferSize int

	// OOBBufferSize is the size of the buffer that receives the serialized
	// descriptors. If it is zero, the default size is used.
	//
	// The client and the server using this package MUST select the same
	// buffer sizes. Alternatively, use the package-level functions, which
	// always use the default sizes.
	OOBBufferSize int

	once sync.Once // guards the lazy allocation of msg and oob
	msg  []byte    // payload buffer: meta headers and meta data
	oob  []byte    // out-of-band buffer: serialized control messages
}
// Receive dials the "unix" network address addr and calls cb for each
// received descriptor until EOF or the first error. The connection is
// closed before Receive returns.
func (c *Client) Receive(addr string, cb ReceiveCallback) error {
	var (
		conn net.Conn
		err  error
	)
	if conn, err = net.Dial("unix", addr); err != nil {
		return err
	}
	defer conn.Close()
	return c.ReceiveAllFrom(conn, cb)
}
// ReceiveFrom reads a single control message from conn, which must be a
// *net.UnixConn, and calls cb for each descriptor it carries.
func (c *Client) ReceiveFrom(conn net.Conn, cb ReceiveCallback) error {
	c.initOnce()
	msgBuf, oobBuf := c.msg, c.oob
	return receive(conn, msgBuf, oobBuf, cb)
}
// ReceiveAllFrom reads control messages from conn, which must be a
// *net.UnixConn, one after another and calls cb for each descriptor inside
// them.
//
// It returns nil once the peer closes the connection (io.EOF). Otherwise it
// returns the first error encountered, including any error returned by cb.
func (c *Client) ReceiveAllFrom(conn net.Conn, cb ReceiveCallback) error {
	c.initOnce()
	for {
		if err := receive(conn, c.msg, c.oob, cb); err != nil {
			if err == io.EOF {
				// A clean close by the peer ends the stream normally.
				return nil
			}
			return err
		}
	}
	// The loop only exits via return, so no trailing return is needed
	// (the previous one was unreachable code).
}
// initOnce allocates the message and OOB buffers on first use. Sizes that
// are zero (or negative) fall back to the package defaults.
func (c *Client) initOnce() {
	c.once.Do(func() {
		msgn := c.MsgBufferSize
		if msgn <= 0 {
			// make would panic on a negative size; treat it like "unset".
			msgn = msgDefaultBufferSize
		}
		oobn := c.OOBBufferSize
		if oobn <= 0 {
			oobn = oobDefaultBufferSize
		}
		c.msg = make([]byte, msgn)
		c.oob = make([]byte, oobn)
	})
}
// receive reads one message from c and calls cb for every file descriptor
// carried by the message's first control message.
//
// c must be a *net.UnixConn; otherwise ErrNotUnixConn is returned. msg
// receives the message payload and oob the serialized control messages.
// Both are used as scratch space. The payload holds, for each descriptor in
// order, a 4-byte little-endian length followed by that many bytes of meta
// data.
//
// receive returns:
//   - io.EOF when the peer has closed the connection;
//   - ErrEmptyControlMessage or ErrEmptyFileDescriptors when no descriptors
//     were sent;
//   - io.ErrUnexpectedEOF when the payload is shorter than its meta headers
//     announce;
//   - an error when oob was too small to hold all control data;
//   - otherwise, the first error returned by cb.
//
// Descriptors handed to cb are never closed here. Descriptors that were
// received but never reached cb, because an error occurred first, are closed,
// since the caller has no other way to reach them.
func receive(c net.Conn, msg, oob []byte, cb ReceiveCallback) error {
	conn, ok := c.(*net.UnixConn)
	if !ok {
		return ErrNotUnixConn
	}
	msgn, oobn, flags, _, err := conn.ReadMsgUnix(msg, oob)
	if err != nil {
		if isEOF(err) {
			// ReadMsgUnix reports EOF wrapped in a *net.OpError; normalize
			// it so callers can compare against io.EOF.
			err = io.EOF
		}
		return err
	}
	cmsg, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return err
	}
	if len(cmsg) == 0 {
		return ErrEmptyControlMessage
	}
	fds, err := syscall.ParseUnixRights(&cmsg[0])
	if err != nil {
		return err
	}
	if len(fds) == 0 {
		return ErrEmptyFileDescriptors
	}

	// closeFrom closes fds[i:], the descriptors the callback never saw.
	// Close errors are ignored: this is best-effort cleanup on an error path.
	closeFrom := func(i int) {
		for _, fd := range fds[i:] {
			_ = syscall.Close(fd)
		}
	}

	if flags&syscall.MSG_CTRUNC != 0 {
		// The kernel dropped control data that did not fit into oob, so
		// the descriptors no longer line up with the meta headers.
		closeFrom(0)
		return fmt.Errorf("control message truncated: OOB buffer of %d bytes is too small", len(oob))
	}

	r := bytes.NewReader(msg[:msgn])
	hdr := make([]byte, 4)
	for i, fd := range fds {
		// Read the meta header. io.ReadFull guards against a short read
		// that would otherwise decode a partially filled, stale header.
		if _, err := io.ReadFull(r, hdr); err != nil {
			if err == io.EOF {
				// Descriptors without headers mean a malformed message,
				// not a clean end of the stream.
				err = io.ErrUnexpectedEOF
			}
			closeFrom(i)
			return err
		}
		n := int64(binary.LittleEndian.Uint32(hdr))
		if n > int64(r.Len()) {
			// The header announces more meta data than the payload holds.
			closeFrom(i)
			return io.ErrUnexpectedEOF
		}
		var meta io.Reader
		if n > 0 {
			meta = io.LimitReader(r, n)
		}
		if err := cb(fd, meta); err != nil {
			closeFrom(i + 1)
			return err
		}
		if meta != nil {
			// Skip whatever meta bytes the callback left unread so the
			// next header is read from the right offset.
			if _, err := io.Copy(ioutil.Discard, meta); err != nil {
				closeFrom(i + 1)
				return err
			}
		}
	}
	return nil
}
// isEOF reports whether err is io.EOF, either directly or as the Err field
// of a *net.OpError, which is how ReadMsgUnix reports the end of the stream.
// Only one level of *net.OpError is unwrapped.
func isEOF(err error) bool {
	if opErr, ok := err.(*net.OpError); ok {
		return opErr.Err == io.EOF
	}
	return err == io.EOF
}