diff --git a/net/basic_socket.cpp b/net/basic_socket.cpp index b5b2ec83..bd54f8fd 100644 --- a/net/basic_socket.cpp +++ b/net/basic_socket.cpp @@ -262,6 +262,33 @@ bool ISocketStream::skip_read(size_t count) { return true; } +ssize_t ISocketStream::recv_at_least(void* buf, size_t count, size_t least, int flags) { + size_t n = 0; + do { + ssize_t ret = this->recv(buf, count, flags); + if (ret < 0) return ret; + if (ret == 0) break; // EOF + if ((n += ret) >= least) break; + count -= ret; + } while (count); + return n; +} + +ssize_t ISocketStream::recv_at_least_mutable(struct iovec *iov, int iovcnt, + size_t least, int flags /*=0*/) { + size_t n = 0; + iovector_view v(iov, iovcnt); + do { + ssize_t ret = this->recv(v.iov, v.iovcnt, flags); + if (ret < 0) return ret; + if (ret == 0) break; // EOF + if ((n += ret) >= least) break; + auto r = v.extract_front(ret); + assert(r == ret); (void)r; + } while (v.iovcnt && v.iov->iov_len); + return n; +} + int do_get_name(int fd, Getter getter, EndPoint& addr) { sockaddr_storage storage; socklen_t len = storage.get_max_socklen(); diff --git a/net/socket.h b/net/socket.h index 15a89c6f..719e2436 100644 --- a/net/socket.h +++ b/net/socket.h @@ -217,6 +217,10 @@ namespace net { virtual ssize_t recv(void *buf, size_t count, int flags = 0) = 0; virtual ssize_t recv(const struct iovec *iov, int iovcnt, int flags = 0) = 0; + // recv at `least` bytes to buffer (`buf`, `count`) + ssize_t recv_at_least(void* buf, size_t count, size_t least, int flags = 0); + ssize_t recv_at_least_mutable(struct iovec *iov, int iovcnt, size_t least, int flags = 0); + // read count bytes and drop them // return true/false for success/failure bool skip_read(size_t count);