diff --git a/conn_linux.go b/conn_linux.go index 4af18c9..ed105b0 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -121,28 +121,20 @@ func (c *conn) Send(m Message) error { // Receive receives one or more Messages from netlink. func (c *conn) Receive() ([]Message, error) { - b := make([]byte, os.Getpagesize()) - for { - // Peek at the buffer to see how many bytes are available. - // - // TODO(mdlayher): deal with OOB message data if available, such as - // when PacketInfo ConnOption is true. - n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, unix.MSG_PEEK) - if err != nil { - return nil, err - } - - // Break when we can read all messages - if n < len(b) { - break - } - - // Double in size if not enough bytes - b = make([]byte, len(b)*2) + b := make([]byte, unix.SizeofNlMsghdr) + // Peek at the buffer to see how many bytes are available. + // + // TODO(mdlayher): deal with OOB message data if available, such as + // when PacketInfo ConnOption is true. + n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, unix.MSG_PEEK|unix.MSG_TRUNC) + if err != nil { + return nil, err } + // Resize buffer to the expected size. + b = make([]byte, nlmsgAlign(n)) // Read out all available messages - n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, 0) + n, _, _, _, err = c.s.Recvmsg(context.Background(), b, nil, 0) if err != nil { return nil, err }