Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HG bulk: fix possible erroneous refcount when bulk create/transfer fails #772

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 31 additions & 60 deletions src/mercury_bulk.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,26 @@

/* Check permission flags */
#define HG_BULK_CHECK_FLAGS(op, origin_flags, local_flags, label, ret) \
switch (op) { \
case HG_BULK_PUSH: \
HG_CHECK_SUBSYS_ERROR(bulk, \
!(origin_flags & HG_BULK_WRITE_ONLY) || \
!(local_flags & HG_BULK_READ_ONLY), \
label, ret, HG_PERMISSION, \
"Invalid permission flags for PUSH operation " \
"(origin=0x%x, local=0x%x)", \
origin_flags, local_flags); \
break; \
case HG_BULK_PULL: \
do { \
HG_CHECK_SUBSYS_ERROR(bulk, op > HG_BULK_PULL, label, ret, \
HG_INVALID_ARG, "Unknown bulk operation"); \
if (op & HG_BULK_PULL) \
HG_CHECK_SUBSYS_ERROR(bulk, \
!(origin_flags & HG_BULK_READ_ONLY) || \
!(local_flags & HG_BULK_WRITE_ONLY), \
label, ret, HG_PERMISSION, \
"Invalid permission flags for PULL operation " \
"(origin=%d, local=%d)", \
origin_flags, local_flags); \
break; \
default: \
HG_GOTO_SUBSYS_ERROR( \
bulk, label, ret, HG_INVALID_ARG, "Unknown bulk operation"); \
}
else \
HG_CHECK_SUBSYS_ERROR(bulk, \
!(origin_flags & HG_BULK_WRITE_ONLY) || \
!(local_flags & HG_BULK_READ_ONLY), \
label, ret, HG_PERMISSION, \
"Invalid permission flags for PUSH operation " \
"(origin=0x%x, local=0x%x)", \
origin_flags, local_flags); \
} while (0)

/************************************/
/* Local Type and Struct Definition */
Expand Down Expand Up @@ -382,7 +379,7 @@ hg_bulk_transfer(hg_core_context_t *core_context, hg_cb_t callback, void *arg,
/**
* Bulk transfer to self.
*/
static hg_return_t
static void
hg_bulk_transfer_self(hg_bulk_op_t op,
const struct hg_bulk_segment *origin_segments, uint32_t origin_count,
hg_size_t origin_offset, const struct hg_bulk_segment *local_segments,
Expand Down Expand Up @@ -541,6 +538,7 @@ hg_bulk_create(hg_core_class_t *core_class, uint32_t count, void **bufs,
hg_bulk->desc.info.flags = flags;
hg_bulk->attrs = *attrs;
hg_atomic_init32(&hg_bulk->ref_count, 1);
hg_core_bulk_incr(core_class);

if (count > HG_BULK_STATIC_MAX) {
/* Allocate segments */
Expand Down Expand Up @@ -652,7 +650,6 @@ hg_bulk_create(hg_core_class_t *core_class, uint32_t count, void **bufs,
#endif
}
hg_bulk->registered = true;
hg_core_bulk_incr(core_class);

*hg_bulk_p = hg_bulk;

Expand Down Expand Up @@ -1288,6 +1285,7 @@ hg_bulk_deserialize(hg_core_class_t *core_class, struct hg_bulk **hg_bulk_p,
hg_bulk->na_class = HG_Core_class_get_na(core_class);
hg_bulk->registered = false;
hg_atomic_init32(&hg_bulk->ref_count, 1);
hg_core_bulk_incr(hg_bulk->core_class);

/* Descriptor info */
HG_BULK_DECODE(error, ret, buf_ptr, buf_size_left, &hg_bulk->desc.info,
Expand Down Expand Up @@ -1434,7 +1432,6 @@ hg_bulk_deserialize(hg_core_class_t *core_class, struct hg_bulk **hg_bulk_p,
"Buffer size left for decoding bulk handle is not zero (%" PRIu64 ")",
buf_size_left);

hg_core_bulk_incr(hg_bulk->core_class);
*hg_bulk_p = hg_bulk;

return HG_SUCCESS;
Expand Down Expand Up @@ -1971,15 +1968,13 @@ hg_bulk_transfer(hg_core_context_t *core_context, hg_cb_t callback, void *arg,
/* Complete immediately */
hg_bulk_complete(hg_bulk_op_id, HG_SUCCESS, true);
} else if (HG_Core_addr_is_self(origin_addr) ||
((origin_flags & HG_BULK_EAGER) && (op != HG_BULK_PUSH))) {
((origin_flags & HG_BULK_EAGER) && (op & HG_BULK_PULL))) {
hg_bulk_op_id->na_class = NULL;
hg_bulk_op_id->na_context = NULL;

/* When doing eager transfers, use self code path to copy data locally
*/
ret = hg_bulk_transfer_self(op, origin_segments, origin_count,
origin_offset, local_segments, local_count, local_offset, size,
hg_bulk_op_id);
/* For eager transfers, use self code path to copy data locally */
hg_bulk_transfer_self(op, origin_segments, origin_count, origin_offset,
local_segments, local_count, local_offset, size, hg_bulk_op_id);
} else {
struct hg_bulk_na_mem_desc *origin_mem_descs, *local_mem_descs;
na_mem_handle_t **origin_mem_handles, **local_mem_handles;
Expand Down Expand Up @@ -2018,6 +2013,7 @@ hg_bulk_transfer(hg_core_context_t *core_context, hg_cb_t callback, void *arg,
origin_segments, origin_count, origin_mem_handles, origin_flags,
origin_offset, local_segments, local_count, local_mem_handles,
local_flags, local_offset, size, hg_bulk_op_id);
HG_CHECK_SUBSYS_HG_ERROR(bulk, error, ret, "Could not transfer data");
}

/* Assign op_id */
Expand All @@ -2027,14 +2023,18 @@ hg_bulk_transfer(hg_core_context_t *core_context, hg_cb_t callback, void *arg,
return HG_SUCCESS;

error:
if (hg_bulk_op_id)
if (hg_bulk_op_id != NULL) {
/* decrement ref_count */
(void) hg_bulk_free(hg_bulk_origin);
(void) hg_bulk_free(hg_bulk_local);
hg_bulk_op_destroy(hg_bulk_op_id);
}

return ret;
}

/*---------------------------------------------------------------------------*/
static hg_return_t
static void
hg_bulk_transfer_self(hg_bulk_op_t op,
const struct hg_bulk_segment *origin_segments, uint32_t origin_count,
hg_size_t origin_offset, const struct hg_bulk_segment *local_segments,
Expand All @@ -2043,20 +2043,8 @@ hg_bulk_transfer_self(hg_bulk_op_t op,
{
uint32_t origin_segment_start_index = 0, local_segment_start_index = 0;
hg_size_t origin_segment_start_offset = 0, local_segment_start_offset = 0;
hg_bulk_copy_op_t copy_op;
hg_return_t ret;

switch (op) {
case HG_BULK_PUSH:
copy_op = hg_bulk_memcpy_put;
break;
case HG_BULK_PULL:
copy_op = hg_bulk_memcpy_get;
break;
default:
HG_GOTO_SUBSYS_ERROR(
bulk, error, ret, HG_INVALID_ARG, "Unknown bulk operation");
}
hg_bulk_copy_op_t copy_op =
(op & HG_BULK_PULL) ? hg_bulk_memcpy_get : hg_bulk_memcpy_put;

HG_LOG_SUBSYS_DEBUG(bulk, "Transferring data through self");

Expand All @@ -2078,11 +2066,6 @@ hg_bulk_transfer_self(hg_bulk_op_t op,

/* Complete immediately */
hg_bulk_complete(hg_bulk_op_id, HG_SUCCESS, true);

return HG_SUCCESS;

error:
return ret;
}

/*---------------------------------------------------------------------------*/
Expand Down Expand Up @@ -2149,22 +2132,10 @@ hg_bulk_transfer_na(hg_bulk_op_t op, na_addr_t *na_origin_addr,
hg_size_t local_offset, hg_size_t size, struct hg_bulk_op_id *hg_bulk_op_id)
{
hg_bulk_na_op_id_t *hg_bulk_na_op_ids;
na_bulk_op_t na_bulk_op;
na_bulk_op_t na_bulk_op =
(op & HG_BULK_PULL) ? hg_bulk_na_get : hg_bulk_na_put;
hg_return_t ret;

/* Map op to NA op */
switch (op) {
case HG_BULK_PUSH:
na_bulk_op = hg_bulk_na_put;
break;
case HG_BULK_PULL:
na_bulk_op = hg_bulk_na_get;
break;
default:
HG_GOTO_SUBSYS_ERROR(
bulk, error, ret, HG_INVALID_ARG, "Unknown bulk operation");
}

#ifdef NA_HAS_SM
/* Use NA SM op IDs if needed */
if (origin_flags & HG_BULK_SM)
Expand Down
Loading