Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass server->client command over a separate socket pair #762

Merged
merged 15 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/libvfio-user.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ extern "C" {
#endif

#define LIB_VFIO_USER_MAJOR 0
#define LIB_VFIO_USER_MINOR 1
#define LIB_VFIO_USER_MINOR 2

/* DMA addresses cannot be directly de-referenced. */
typedef void *vfu_dma_addr_t;
Expand Down
95 changes: 85 additions & 10 deletions lib/tran.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>

#include <json.h>

Expand All @@ -52,9 +53,13 @@
* {
* "capabilities": {
* "max_msg_fds": 32,
* "max_data_xfer_size": 1048576
* "max_data_xfer_size": 1048576,
* "migration": {
* "pgsize": 4096
* },
* "twin_socket": {
* "supported": true,
* "fd_index": 0
* }
* }
* }
Expand All @@ -64,7 +69,8 @@
*/
mnissler-rivos marked this conversation as resolved.
Show resolved Hide resolved
int
tran_parse_version_json(const char *json_str, int *client_max_fdsp,
size_t *client_max_data_xfer_sizep, size_t *pgsizep)
size_t *client_max_data_xfer_sizep, size_t *pgsizep,
bool *twin_socket_supportedp)
{
struct json_object *jo_caps = NULL;
struct json_object *jo_top = NULL;
Expand Down Expand Up @@ -130,6 +136,27 @@ tran_parse_version_json(const char *json_str, int *client_max_fdsp,
}
}

if (json_object_object_get_ex(jo_caps, "twin_socket", &jo)) {
struct json_object *jo2 = NULL;

if (json_object_get_type(jo) != json_type_object) {
goto out;
}

if (json_object_object_get_ex(jo, "supported", &jo2)) {
if (json_object_get_type(jo2) != json_type_boolean) {
goto out;
}

errno = 0;
*twin_socket_supportedp = json_object_get_boolean(jo2);

if (errno != 0) {
goto out;
}
}
}

ret = 0;

out:
Expand All @@ -143,7 +170,7 @@ tran_parse_version_json(const char *json_str, int *client_max_fdsp,

static int
recv_version(vfu_ctx_t *vfu_ctx, uint16_t *msg_idp,
struct vfio_user_version **versionp)
struct vfio_user_version **versionp, bool *twin_socket_supportedp)
{
struct vfio_user_version *cversion = NULL;
vfu_msg_t msg = { { 0 } };
Expand Down Expand Up @@ -208,7 +235,7 @@ recv_version(vfu_ctx_t *vfu_ctx, uint16_t *msg_idp,

ret = tran_parse_version_json(json_str, &vfu_ctx->client_max_fds,
&vfu_ctx->client_max_data_xfer_size,
&pgsize);
&pgsize, twin_socket_supportedp);

if (ret < 0) {
/* No client-supplied strings in the log for release build. */
Expand Down Expand Up @@ -312,8 +339,9 @@ json_add_uint64(struct json_object *jso, const char *key, uint64_t value)
* be freed by the caller.
*/
static char *
format_server_capabilities(vfu_ctx_t *vfu_ctx)
format_server_capabilities(vfu_ctx_t *vfu_ctx, int twin_socket_fd_index)
{
struct json_object *jo_twin_socket = NULL;
struct json_object *jo_migration = NULL;
struct json_object *jo_caps = NULL;
struct json_object *jo_top = NULL;
Expand Down Expand Up @@ -347,6 +375,25 @@ format_server_capabilities(vfu_ctx_t *vfu_ctx)
}
}

if (twin_socket_fd_index >= 0) {
struct json_object *jo_supported = NULL;

if ((jo_twin_socket = json_object_new_object()) == NULL) {
goto out;
}

if ((jo_supported = json_object_new_boolean(true)) == NULL ||
json_add(jo_twin_socket, "supported", &jo_supported) < 0 ||
json_add_uint64(jo_twin_socket, "fd_index",
twin_socket_fd_index) < 0) {
goto out;
}

if (json_add(jo_caps, "twin_socket", &jo_twin_socket) < 0) {
goto out;
}
}

if ((jo_top = json_object_new_object()) == NULL ||
json_add(jo_top, "capabilities", &jo_caps) < 0) {
goto out;
Expand All @@ -355,6 +402,7 @@ format_server_capabilities(vfu_ctx_t *vfu_ctx)
caps_str = strdup(json_object_to_json_string(jo_top));

out:
json_object_put(jo_twin_socket);
json_object_put(jo_migration);
json_object_put(jo_caps);
json_object_put(jo_top);
Expand All @@ -363,15 +411,17 @@ format_server_capabilities(vfu_ctx_t *vfu_ctx)

static int
send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
struct vfio_user_version *cversion)
struct vfio_user_version *cversion, int client_cmd_socket_fd)
{
int twin_socket_fd_index = client_cmd_socket_fd >= 0 ? 0 : -1;
struct vfio_user_version sversion = { 0 };
struct iovec iovecs[2] = { { 0 } };
vfu_msg_t msg = { { 0 } };
char *server_caps = NULL;
int ret;

if ((server_caps = format_server_capabilities(vfu_ctx)) == NULL) {
server_caps = format_server_capabilities(vfu_ctx, twin_socket_fd_index);
if (server_caps == NULL) {
errno = ENOMEM;
return -1;
}
Expand All @@ -391,32 +441,57 @@ send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
msg.hdr.msg_id = msg_id;
msg.out_iovecs = iovecs;
mnissler-rivos marked this conversation as resolved.
Show resolved Hide resolved
msg.nr_out_iovecs = 2;
if (client_cmd_socket_fd >= 0) {
msg.out.fds = &client_cmd_socket_fd;
msg.out.nr_fds = 1;
assert(msg.out.fds[twin_socket_fd_index] == client_cmd_socket_fd);
}

ret = vfu_ctx->tran->reply(vfu_ctx, &msg, 0);
free(server_caps);
return ret;
mnissler-rivos marked this conversation as resolved.
Show resolved Hide resolved
}

int
tran_negotiate(vfu_ctx_t *vfu_ctx)
tran_negotiate(vfu_ctx_t *vfu_ctx, int *client_cmd_socket_fdp)
{
struct vfio_user_version *client_version = NULL;
int client_cmd_socket_fds[2] = { -1, -1 };
bool twin_socket_supported = false;
uint16_t msg_id = 0x0bad;
int ret;

ret = recv_version(vfu_ctx, &msg_id, &client_version);
ret = recv_version(vfu_ctx, &msg_id, &client_version,
&twin_socket_supported);

if (ret < 0) {
vfu_log(vfu_ctx, LOG_ERR, "failed to recv version: %m");
return ret;
}

ret = send_version(vfu_ctx, msg_id, client_version);
if (twin_socket_supported && client_cmd_socket_fdp != NULL &&
vfu_ctx->client_max_fds > 0) {
if (socketpair(AF_UNIX, SOCK_STREAM, 0, client_cmd_socket_fds) == -1) {
vfu_log(vfu_ctx, LOG_ERR, "failed to create cmd socket: %m");
return -1;
mnissler-rivos marked this conversation as resolved.
Show resolved Hide resolved
}
}

ret = send_version(vfu_ctx, msg_id, client_version,
client_cmd_socket_fds[0]);

free(client_version);

/*
* The remote end of the client command socket pair is no longer needed.
* The local end is kept only if passed to the caller on successful return.
*/
close_safely(&client_cmd_socket_fds[0]);
if (ret < 0) {
vfu_log(vfu_ctx, LOG_ERR, "failed to send version: %m");
close_safely(&client_cmd_socket_fds[1]);
} else if (client_cmd_socket_fdp != NULL) {
*client_cmd_socket_fdp = client_cmd_socket_fds[1];
}

return ret;
Expand Down
5 changes: 3 additions & 2 deletions lib/tran.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ struct transport_ops {
*/
int
tran_parse_version_json(const char *json_str, int *client_max_fdsp,
size_t *client_max_data_xfer_sizep, size_t *pgsizep);
size_t *client_max_data_xfer_sizep, size_t *pgsizep,
bool *twin_socket_supportedp);

int
tran_negotiate(vfu_ctx_t *vfu_ctx);
tran_negotiate(vfu_ctx_t *vfu_ctx, int *client_cmd_socket_fdp);

#endif /* LIB_VFIO_USER_TRAN_H */

Expand Down
2 changes: 1 addition & 1 deletion lib/tran_pipe.c
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ tran_pipe_attach(vfu_ctx_t *vfu_ctx)
tp->in_fd = STDIN_FILENO;
tp->out_fd = STDOUT_FILENO;

ret = tran_negotiate(vfu_ctx);
ret = tran_negotiate(vfu_ctx, NULL);
if (ret < 0) {
ret = errno;
tp->in_fd = -1;
Expand Down
31 changes: 28 additions & 3 deletions lib/tran_sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
typedef struct {
int listen_fd;
int conn_fd;
int client_cmd_socket_fd;
} tran_sock_t;

int
Expand Down Expand Up @@ -380,6 +381,7 @@ tran_sock_init(vfu_ctx_t *vfu_ctx)

ts->listen_fd = -1;
ts->conn_fd = -1;
ts->client_cmd_socket_fd = -1;

if ((ts->listen_fd = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) {
ret = errno;
Expand Down Expand Up @@ -464,7 +466,7 @@ tran_sock_attach(vfu_ctx_t *vfu_ctx)
return -1;
}

ret = tran_negotiate(vfu_ctx);
ret = tran_negotiate(vfu_ctx, &ts->client_cmd_socket_fd);
if (ret < 0) {
close_safely(&ts->conn_fd);
return -1;
Expand Down Expand Up @@ -607,6 +609,21 @@ tran_sock_reply(vfu_ctx_t *vfu_ctx, vfu_msg_t *msg, int err)
return ret;
}

static void maybe_print_cmd_collision_warning(vfu_ctx_t *vfu_ctx) {
static bool warning_printed = false;
static const char *warning_msg =
"You are using libvfio-user in a configuration that issues "
"client-to-server commands, but without the twin_socket feature "
"enabled. This is known to break when client and server send a command "
"at the same time. See "
"https://github.com/nutanix/libvfio-user/issues/279 for details.";

if (!warning_printed) {
vfu_log(vfu_ctx, LOG_WARNING, "%s", warning_msg);
warning_printed = true;
}
}

static int
tran_sock_send_msg(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
enum vfio_user_command cmd,
Expand All @@ -615,14 +632,21 @@ tran_sock_send_msg(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
void *recv_data, size_t recv_len)
{
tran_sock_t *ts;
int fd;

assert(vfu_ctx != NULL);
assert(vfu_ctx->tran_data != NULL);

ts = vfu_ctx->tran_data;

return tran_sock_msg(ts->conn_fd, msg_id, cmd, send_data, send_len,
hdr, recv_data, recv_len);
fd = ts->client_cmd_socket_fd;
if (fd == -1) {
maybe_print_cmd_collision_warning(vfu_ctx);
fd = ts->conn_fd;
}

return tran_sock_msg(fd, msg_id, cmd, send_data, send_len, hdr, recv_data,
recv_len);
}

static void
Expand All @@ -636,6 +660,7 @@ tran_sock_detach(vfu_ctx_t *vfu_ctx)

if (ts != NULL) {
close_safely(&ts->conn_fd);
close_safely(&ts->client_cmd_socket_fd);
}
}

Expand Down
2 changes: 1 addition & 1 deletion samples/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ recv_version(int sock, int *server_max_fds, size_t *server_max_data_xfer_size,
}

ret = tran_parse_version_json(json_str, server_max_fds,
server_max_data_xfer_size, pgsize);
server_max_data_xfer_size, pgsize, NULL);

if (ret < 0) {
err(EXIT_FAILURE, "failed to parse server JSON \"%s\"", json_str);
Expand Down
Loading