Skip to content

Commit

Permalink
Add pytest for vfu_sgl_{read,write}
Browse files Browse the repository at this point in the history
The new tests verify behavior of vfu_sgl_{read,write} for success and
error cases.
  • Loading branch information
mnissler-rivos committed Aug 15, 2023
1 parent 937f09a commit 9f3b8d4
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 26 deletions.
123 changes: 97 additions & 26 deletions test/py/libvfio_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ class vfio_user_dma_unmap(Structure):
]


class vfio_user_dma_region_access(Structure):
_pack_ = 1
_fields_ = [
("addr", c.c_uint64),
("count", c.c_uint64),
]


class vfu_dma_info_t(Structure):
_fields_ = [
("iova", iovec_t),
Expand Down Expand Up @@ -632,6 +640,10 @@ class vfio_user_migration_info(Structure):
c.POINTER(iovec_t), c.c_size_t, c.c_int)
lib.vfu_sgl_put.argtypes = (c.c_void_p, c.POINTER(dma_sg_t),
c.POINTER(iovec_t), c.c_size_t)
lib.vfu_sgl_read.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t,
c.c_void_p)
lib.vfu_sgl_write.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t,
c.c_void_p)

lib.vfu_create_ioeventfd.argtypes = (c.c_void_p, c.c_uint32, c.c_int,
c.c_size_t, c.c_uint32, c.c_uint32,
Expand Down Expand Up @@ -683,30 +695,46 @@ class Client:

def __init__(self, sock=None):
self.sock = sock
self.cmd_sock = None

def connect(self, ctx):
def connect(self, ctx,
max_data_xfer_size=VFIO_USER_DEFAULT_MAX_DATA_XFER_SIZE,
cmd_conn=False):
self.sock = connect_sock()

json = b'{ "capabilities": { "max_msg_fds": 8 } }'
json = f'''
{{
"capabilities": {{
"max_data_xfer_size": {max_data_xfer_size},
"max_msg_fds": 8,
"cmd_conn": {str(cmd_conn).lower()}
}}
}}
'''
# struct vfio_user_version
payload = struct.pack("HH%dsc" % len(json), LIBVFIO_USER_MAJOR,
LIBVFIO_USER_MINOR, json, b'\0')
LIBVFIO_USER_MINOR, json.encode(), b'\0')
hdr = vfio_user_header(VFIO_USER_VERSION, size=len(payload))
self.sock.send(hdr + payload)
vfu_attach_ctx(ctx, expect=0)
payload = get_reply(self.sock, expect=0)
fds, payload = get_reply_fds(self.sock, expect=0)
self.cmd_sock = socket.socket(fileno=fds[0]) if fds else None
return self.sock

def disconnect(self, ctx):
self.sock.close()
self.sock = None
if self.cmd_sock is not None:
self.cmd_sock.close()
self.cmd_sock = None

# notice client closed connection
vfu_run_ctx(ctx, errno.ENOTCONN)


def connect_client(ctx):
def connect_client(*args, **kwargs):
client = Client()
client.connect(ctx)
client.connect(*args, **kwargs)
return client


Expand All @@ -718,6 +746,23 @@ def get_reply(sock, expect=0):
return buf[16:]


def send_msg(sock, cmd, msg_type, payload=bytearray(), fds=None, msg_id=None,
error_no=0):
"""
Sends a message on the given socket. Can be used on either end of the
socket to send commands and replies.
"""
hdr = vfio_user_header(cmd, size=len(payload), msg_type=msg_type,
msg_id=msg_id, error=error_no != 0,
error_no=error_no)

if fds:
sock.sendmsg([hdr + payload], [(socket.SOL_SOCKET, socket.SCM_RIGHTS,
struct.pack("I" * len(fds), *fds))])
else:
sock.send(hdr + payload)


def msg(ctx, sock, cmd, payload=bytearray(), expect=0, fds=None,
rsp=True, busy=False):
"""
Expand All @@ -730,13 +775,7 @@ def msg(ctx, sock, cmd, payload=bytearray(), expect=0, fds=None,
response: it can later be retrieved, post vfu_device_quiesced(), with
get_reply().
"""
hdr = vfio_user_header(cmd, size=len(payload))

if fds:
sock.sendmsg([hdr + payload], [(socket.SOL_SOCKET, socket.SCM_RIGHTS,
struct.pack("I" * len(fds), *fds))])
else:
sock.send(hdr + payload)
send_msg(sock, cmd, VFIO_USER_F_TYPE_COMMAND, payload, fds)

if busy:
vfu_run_ctx(ctx, errno.EBUSY)
Expand All @@ -749,12 +788,14 @@ def msg(ctx, sock, cmd, payload=bytearray(), expect=0, fds=None,
return get_reply(sock, expect=expect)


def get_reply_fds(sock, expect=0):
"""Receives a message from a socket and pulls the returned file descriptors
out of the message."""
def get_msg_fds(sock, expect_type, expect=0):
"""
Receives a message from a socket and pulls the returned file descriptors
out of the message.
"""
fds = array.array("i")
data, ancillary, flags, addr = sock.recvmsg(4096,
socket.CMSG_LEN(64 * fds.itemsize))
data, ancillary, flags, addr = sock.recvmsg(SERVER_MAX_MSG_SIZE,
socket.CMSG_LEN(64 * fds.itemsize))
(msg_id, cmd, msg_size, msg_flags, errno) = struct.unpack("HHIII",
data[0:16])
assert errno == expect
Expand All @@ -766,8 +807,18 @@ def get_reply_fds(sock, expect=0):
[unpacked_fd] = struct.unpack_from("i", packed_fd, offset=i)
unpacked_fds.append(unpacked_fd)
assert len(packed_fd)/4 == len(unpacked_fds)
assert (msg_flags & VFIO_USER_F_TYPE_REPLY) != 0
return (unpacked_fds, data[16:])
assert (msg_flags & 0xf) == expect_type
return (unpacked_fds, msg_id, cmd, data[16:])


def get_reply_fds(sock, expect=0):
"""
Receives a reply from a socket and returns the included file descriptors
and message payload data.
"""
(unpacked_fds, _, _, data) = get_msg_fds(sock, VFIO_USER_F_TYPE_REPLY,
expect)
return (unpacked_fds, data)


def msg_fds(ctx, sock, cmd, payload, expect=0, fds=None):
Expand Down Expand Up @@ -966,7 +1017,7 @@ def prepare_ctx_for_dma(dma_register=__dma_register,
#


msg_id = 1
next_msg_id = 1


@c.CFUNCTYPE(None, c.c_void_p, c.c_int, c.c_char_p)
Expand All @@ -982,13 +1033,21 @@ def log(ctx, level, msg):
print(lvl2str[level] + ": " + msg.decode("utf-8"))


def vfio_user_header(cmd, size, no_reply=False, error=False, error_no=0):
global msg_id
def vfio_user_header(cmd,
size,
msg_type=VFIO_USER_F_TYPE_COMMAND,
msg_id=None,
no_reply=False,
error=False,
error_no=0):
global next_msg_id

buf = struct.pack("HHIII", msg_id, cmd, SIZEOF_VFIO_USER_HEADER + size,
VFIO_USER_F_TYPE_COMMAND, error_no)
if msg_id is None:
msg_id = next_msg_id
next_msg_id += 1

msg_id += 1
buf = struct.pack("HHIII", msg_id, cmd, SIZEOF_VFIO_USER_HEADER + size,
msg_type | (no_reply << 4) | (error << 5), error_no)

return buf

Expand Down Expand Up @@ -1230,6 +1289,18 @@ def vfu_sgl_put(ctx, sg, iovec, cnt=1):
return lib.vfu_sgl_put(ctx, sg, iovec, cnt)


def vfu_sgl_read(ctx, sg, cnt=1):
data = bytearray(sum([sge.length for sge in sg]))
buf = (c.c_byte * len(data)).from_buffer(data)
return lib.vfu_sgl_read(ctx, sg, cnt, buf), data


def vfu_sgl_write(ctx, sg, cnt=1, data=bytearray()):
assert len(data) == sum([sge.length for sge in sg])
buf = (c.c_byte * len(data)).from_buffer(data)
return lib.vfu_sgl_write(ctx, sg, cnt, buf)


def vfu_create_ioeventfd(ctx, region_idx, fd, gpa_offset, size, flags,
datamatch, shadow_fd=-1, shadow_offset=0):
assert ctx is not None
Expand Down
1 change: 1 addition & 0 deletions test/py/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ python_tests = [
'test_request_errors.py',
'test_setup_region.py',
'test_sgl_get_put.py',
'test_sgl_read_write.py',
'test_vfu_create_ctx.py',
'test_vfu_realize_ctx.py',
]
Expand Down
Loading

0 comments on commit 9f3b8d4

Please sign in to comment.