diff --git a/test/py/libvfio_user.py b/test/py/libvfio_user.py index e75f4b90..b62c5aca 100644 --- a/test/py/libvfio_user.py +++ b/test/py/libvfio_user.py @@ -474,6 +474,14 @@ class vfio_user_dma_unmap(Structure): ] +class vfio_user_dma_region_access(Structure): + _pack_ = 1 + _fields_ = [ + ("addr", c.c_uint64), + ("count", c.c_uint64), + ] + + class vfu_dma_info_t(Structure): _fields_ = [ ("iova", iovec_t), @@ -632,6 +640,10 @@ class vfio_user_migration_info(Structure): c.POINTER(iovec_t), c.c_size_t, c.c_int) lib.vfu_sgl_put.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.POINTER(iovec_t), c.c_size_t) +lib.vfu_sgl_read.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t, + c.c_void_p) +lib.vfu_sgl_write.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t, + c.c_void_p) lib.vfu_create_ioeventfd.argtypes = (c.c_void_p, c.c_uint32, c.c_int, c.c_size_t, c.c_uint32, c.c_uint32, @@ -683,30 +695,46 @@ class Client: def __init__(self, sock=None): self.sock = sock + self.cmd_sock = None - def connect(self, ctx): + def connect(self, ctx, + max_data_xfer_size=VFIO_USER_DEFAULT_MAX_DATA_XFER_SIZE, + cmd_conn=False): self.sock = connect_sock() - json = b'{ "capabilities": { "max_msg_fds": 8 } }' + json = f''' + {{ + "capabilities": {{ + "max_data_xfer_size": {max_data_xfer_size}, + "max_msg_fds": 8, + "cmd_conn": {str(cmd_conn).lower()} + }} + }} + ''' # struct vfio_user_version payload = struct.pack("HH%dsc" % len(json), LIBVFIO_USER_MAJOR, - LIBVFIO_USER_MINOR, json, b'\0') + LIBVFIO_USER_MINOR, json.encode(), b'\0') hdr = vfio_user_header(VFIO_USER_VERSION, size=len(payload)) self.sock.send(hdr + payload) vfu_attach_ctx(ctx, expect=0) - payload = get_reply(self.sock, expect=0) + fds, payload = get_reply_fds(self.sock, expect=0) + self.cmd_sock = socket.socket(fileno=fds[0]) if fds else None + return self.sock def disconnect(self, ctx): self.sock.close() self.sock = None + if self.cmd_sock is not None: + self.cmd_sock.close() + self.cmd_sock = None # notice client closed connection vfu_run_ctx(ctx, errno.ENOTCONN) -def connect_client(ctx): +def connect_client(*args, **kwargs): client = Client() - client.connect(ctx) + client.connect(*args, **kwargs) return client @@ -718,6 +746,23 @@ def get_reply(sock, expect=0): return buf[16:] +def send_msg(sock, cmd, msg_type, payload=bytearray(), fds=None, msg_id=None, + error_no=0): + """ + Sends a message on the given socket. Can be used on either end of the + socket to send commands and replies. + """ + hdr = vfio_user_header(cmd, size=len(payload), msg_type=msg_type, + msg_id=msg_id, error=error_no != 0, + error_no=error_no) + + if fds: + sock.sendmsg([hdr + payload], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, + struct.pack("I" * len(fds), *fds))]) + else: + sock.send(hdr + payload) + + def msg(ctx, sock, cmd, payload=bytearray(), expect=0, fds=None, rsp=True, busy=False): """ @@ -730,13 +775,7 @@ def msg(ctx, sock, cmd, payload=bytearray(), expect=0, fds=None, response: it can later be retrieved, post vfu_device_quiesced(), with get_reply(). """ - hdr = vfio_user_header(cmd, size=len(payload)) - - if fds: - sock.sendmsg([hdr + payload], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, - struct.pack("I" * len(fds), *fds))]) - else: - sock.send(hdr + payload) + send_msg(sock, cmd, VFIO_USER_F_TYPE_COMMAND, payload, fds) if busy: vfu_run_ctx(ctx, errno.EBUSY) @@ -749,12 +788,14 @@ def msg(ctx, sock, cmd, payload=bytearray(), expect=0, fds=None, return get_reply(sock, expect=expect) -def get_reply_fds(sock, expect=0): - """Receives a message from a socket and pulls the returned file descriptors - out of the message.""" +def get_msg_fds(sock, expect_type, expect=0): + """ + Receives a message from a socket and pulls the returned file descriptors + out of the message. + """ fds = array.array("i") - data, ancillary, flags, addr = sock.recvmsg(4096, - socket.CMSG_LEN(64 * fds.itemsize)) + data, ancillary, flags, addr = sock.recvmsg(SERVER_MAX_MSG_SIZE, + socket.CMSG_LEN(64 * fds.itemsize)) (msg_id, cmd, msg_size, msg_flags, errno) = struct.unpack("HHIII", data[0:16]) assert errno == expect @@ -766,8 +807,18 @@ def get_reply_fds(sock, expect=0): [unpacked_fd] = struct.unpack_from("i", packed_fd, offset=i) unpacked_fds.append(unpacked_fd) assert len(packed_fd)/4 == len(unpacked_fds) - assert (msg_flags & VFIO_USER_F_TYPE_REPLY) != 0 - return (unpacked_fds, data[16:]) + assert (msg_flags & 0xf) == expect_type + return (unpacked_fds, msg_id, cmd, data[16:]) + + +def get_reply_fds(sock, expect=0): + """ + Receives a reply from a socket and returns the included file descriptors + and message payload data. + """ + (unpacked_fds, _, _, data) = get_msg_fds(sock, VFIO_USER_F_TYPE_REPLY, + expect) + return (unpacked_fds, data) def msg_fds(ctx, sock, cmd, payload, expect=0, fds=None): @@ -966,7 +1017,7 @@ def prepare_ctx_for_dma(dma_register=__dma_register, # -msg_id = 1 +next_msg_id = 1 @c.CFUNCTYPE(None, c.c_void_p, c.c_int, c.c_char_p) @@ -982,13 +1033,21 @@ def log(ctx, level, msg): print(lvl2str[level] + ": " + msg.decode("utf-8")) -def vfio_user_header(cmd, size, no_reply=False, error=False, error_no=0): - global msg_id +def vfio_user_header(cmd, + size, + msg_type=VFIO_USER_F_TYPE_COMMAND, + msg_id=None, + no_reply=False, + error=False, + error_no=0): + global next_msg_id - buf = struct.pack("HHIII", msg_id, cmd, SIZEOF_VFIO_USER_HEADER + size, - VFIO_USER_F_TYPE_COMMAND, error_no) + if msg_id is None: + msg_id = next_msg_id + next_msg_id += 1 - msg_id += 1 + buf = struct.pack("HHIII", msg_id, cmd, SIZEOF_VFIO_USER_HEADER + size, + msg_type | (no_reply << 4) | (error << 5), error_no) return buf @@ -1230,6 +1289,18 @@ def vfu_sgl_put(ctx, sg, iovec, cnt=1): return lib.vfu_sgl_put(ctx, sg, iovec, cnt) +def vfu_sgl_read(ctx, sg, cnt=1): + data = bytearray(sum([sge.length for sge in sg])) + buf = (c.c_byte * len(data)).from_buffer(data) + return lib.vfu_sgl_read(ctx, sg, cnt, buf), data + + +def vfu_sgl_write(ctx, sg, cnt=1, data=bytearray()): + assert len(data) == sum([sge.length for sge in sg]) + buf = (c.c_byte * len(data)).from_buffer(data) + return lib.vfu_sgl_write(ctx, sg, cnt, buf) + + def vfu_create_ioeventfd(ctx, region_idx, fd, gpa_offset, size, flags, datamatch, shadow_fd=-1, shadow_offset=0): assert ctx is not None diff --git a/test/py/meson.build b/test/py/meson.build index 0ea9f08b..ecd2fe2c 100644 --- a/test/py/meson.build +++ b/test/py/meson.build @@ -45,6 +45,7 @@ python_tests = [ 'test_request_errors.py', 'test_setup_region.py', 'test_sgl_get_put.py', + 'test_sgl_read_write.py', 'test_vfu_create_ctx.py', 'test_vfu_realize_ctx.py', ] diff --git a/test/py/test_sgl_read_write.py b/test/py/test_sgl_read_write.py new file mode 100644 index 00000000..8334ea44 --- /dev/null +++ b/test/py/test_sgl_read_write.py @@ -0,0 +1,182 @@ +# +# Copyright (c) 2023 Nutanix Inc. All rights reserved. +# Copyright (c) 2023 Rivos Inc. All rights reserved. +# +# Authors: Mattias Nissler +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Nutanix nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. +# + +from libvfio_user import * +import select +import threading + +MAP_ADDR = 0x10000000 +MAP_SIZE = 16 << PAGE_SHIFT + +ctx = None +client = None + + +class DMARegionHandler: + """ + A helper to service DMA region accesses arriving over a socket. Accesses + are performed against an internal bytearray buffer. DMA request processing + takes place on a separate thread so as to not block the test code. + """ + + def handle_requests(sock, pipe, buf, lock, addr, error_no): + while True: + (ready, _, _) = select.select([sock, pipe], [], []) + if pipe in ready: + break + + # Read a command from the socket and service it. + _, msg_id, cmd, payload = get_msg_fds(sock, + VFIO_USER_F_TYPE_COMMAND) + assert cmd in [VFIO_USER_DMA_READ, VFIO_USER_DMA_WRITE] + access, data = vfio_user_dma_region_access.pop_from_buffer(payload) + + assert access.addr >= addr + assert access.addr + access.count <= addr + len(buf) + + offset = access.addr - addr + with lock: + if cmd == VFIO_USER_DMA_READ: + data = buf[offset:offset + access.count] + else: + buf[offset:offset + access.count] = data + data = bytearray() + + send_msg(sock, + cmd, + VFIO_USER_F_TYPE_REPLY, + payload=payload[:c.sizeof(access)] + data, + msg_id=msg_id, + error_no=error_no) + + os.close(pipe) + sock.close() + + def __init__(self, sock, addr, size, error_no=0): + self.data = bytearray(size) + self.data_lock = threading.Lock() + self.addr = addr + (pipe_r, self.pipe_w) = os.pipe() + # Duplicate the socket file descriptor so the thread can own it and + # make sure it gets closed only when terminating the thread. + sock = socket.socket(fileno=os.dup(sock.fileno())) + thread = threading.Thread( + target=DMARegionHandler.handle_requests, + args=[sock, pipe_r, self.data, self.data_lock, addr, error_no]) + thread.start() + + def shutdown(self): + # Closing the pipe's write end will signal the thread to terminate. + os.close(self.pipe_w) + + def read(self, addr, size): + offset = addr - self.addr + with self.data_lock: + return self.data[offset:offset + size] + + +def setup_function(function): + global ctx, client, dma_handler + ctx = prepare_ctx_for_dma() + assert ctx is not None + client = connect_client(ctx, max_data_xfer_size=PAGE_SIZE, cmd_conn=True) + + payload = vfio_user_dma_map(argsz=len(vfio_user_dma_map()), + flags=(VFIO_USER_F_DMA_REGION_READ + | VFIO_USER_F_DMA_REGION_WRITE), + offset=0, + addr=MAP_ADDR, + size=MAP_SIZE) + + msg(ctx, client.sock, VFIO_USER_DMA_MAP, payload) + + dma_handler = DMARegionHandler(client.cmd_sock, payload.addr, payload.size) + + +def teardown_function(function): + dma_handler.shutdown() + client.disconnect(ctx) + vfu_destroy_ctx(ctx) + + +def test_dma_read_write(): + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=64, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + data = bytearray([x & 0xff for x in range(0, sg[0].length)]) + assert vfu_sgl_write(ctx, sg, 1, data) == 0 + + assert vfu_sgl_read(ctx, sg, 1) == (0, data) + + assert dma_handler.read(sg[0].dma_addr + sg[0].offset, + sg[0].length) == data + + +def test_dma_read_write_large(): + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=2 * PAGE_SIZE, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + data = bytearray([x & 0xff for x in range(0, sg[0].length)]) + assert vfu_sgl_write(ctx, sg, 1, data) == 0 + + assert vfu_sgl_read(ctx, sg, 1) == (0, data) + + assert dma_handler.read(sg[0].dma_addr + sg[0].offset, + sg[0].length) == data + + +def test_dma_read_write_error(): + # Reinitialize the handler to return EIO. + global dma_handler + dma_handler.shutdown() + dma_handler = DMARegionHandler(client.cmd_sock, MAP_ADDR, MAP_SIZE, + error_no=errno.EIO) + + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=64, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + ret, _ = vfu_sgl_read(ctx, sg, 1) + assert ret == -1 + assert c.get_errno() == errno.EIO + + +# ex: set tabstop=4 shiftwidth=4 softtabstop=4 expandtab: #