Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coll: remove errflag propagation #7045

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions maint/gen_coll.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,8 +651,6 @@ def get_algo_args(args, algo, kind):
algo_args += ", *sched_p"
elif algo['func-commkind'].startswith('i'):
algo_args += ", *sched_p"
elif not algo['func-commkind'].startswith('neighbor_'):
algo_args += ", errflag"

return algo_args

Expand All @@ -665,8 +663,6 @@ def get_algo_params(params, algo):
algo_params += ", MPIR_TSP_sched_t sched"
elif algo['func-commkind'].startswith('i'):
algo_params += ", MPIR_Sched_t s"
elif not algo['func-commkind'].startswith('neighbor_'):
algo_params += ", MPIR_Errflag_t errflag"

return algo_params

Expand All @@ -682,8 +678,7 @@ def get_algo_name(algo):
def get_func_params(params, name, kind):
func_params = params
if kind == "blocking":
if not name.startswith('neighbor_'):
func_params += ", MPIR_Errflag_t errflag"
pass
elif kind == "nonblocking":
func_params += ", MPIR_Request ** request"
elif kind == "persistent":
Expand All @@ -702,8 +697,7 @@ def get_func_params(params, name, kind):
def get_func_args(args, name, kind):
func_args = args
if kind == "blocking":
if not name.startswith('neighbor_'):
func_args += ", errflag"
pass
elif kind == "nonblocking":
func_args += ", request"
elif kind == "persistent":
Expand Down
6 changes: 1 addition & 5 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,10 +1688,6 @@ def push_impl_decl(func, impl_name=None):

if func['_impl_param_list']:
params = ', '.join(func['_impl_param_list'])
if func['dir'] == 'coll':
# block collective use an extra errflag
if not RE.match(r'MPI_(I.*|Neighbor.*|.*_init)$', func['name']):
params = params + ", MPIR_Errflag_t errflag"
else:
params="void"

Expand Down Expand Up @@ -1744,7 +1740,7 @@ def dump_body_coll(func):
dump_error_check("")
else:
# blocking collectives
dump_line_with_break("mpi_errno = %s(%s, MPIR_ERR_NONE);" % (mpir_name, args))
dump_line_with_break("mpi_errno = %s(%s);" % (mpir_name, args))
dump_error_check("")

def dump_coll_v_swap(func):
Expand Down
2 changes: 1 addition & 1 deletion src/binding/c/comm_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ MPI_Intercomm_merge:
* error to make */
acthigh = high ? 1 : 0; /* Clamp high into 1 or 0 */
mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &acthigh, 1, MPI_INT,
MPI_SUM, intercomm_ptr->local_comm, MPIR_ERR_NONE);
MPI_SUM, intercomm_ptr->local_comm);
MPIR_ERR_CHECK(mpi_errno);
/* acthigh must either == 0 or the size of the local comm */
if (acthigh != 0 && acthigh != intercomm_ptr->local_size) {
Expand Down
15 changes: 7 additions & 8 deletions src/include/mpir_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,30 @@ int MPIC_Wait(MPIR_Request * request_ptr);
int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status);

int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr);
int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag,
MPIR_Comm * comm_ptr, MPI_Status * status);
int MPIC_Ssend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr);
int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
int dest, int sendtag, void *recvbuf, MPI_Aint recvcount,
MPI_Datatype recvtype, int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPI_Status * status);
int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int sendtag,
int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
int source, int recvtag, MPIR_Comm * comm_ptr, MPI_Status * status);
int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPIR_Request ** request);
int MPIC_Issend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPIR_Request ** request);
int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source,
int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request);
int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status * statuses);

int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op);

int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr);

/* TSP auto */
int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
Expand Down
26 changes: 2 additions & 24 deletions src/include/mpir_err.h
Original file line number Diff line number Diff line change
Expand Up @@ -907,23 +907,10 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in
(err_) = MPIR_Err_combine_codes((err_), (newerr_)); \
} while (0)

/* For collective communication error, update errflag_ and err_ret_, do not abort */
#define MPIR_ERR_COLL_CHECKANDCONT(err_, errflag_, err_ret_) \
#define MPIR_ERR_COLL_CHECK_SIZE(recv_sz, expect_sz, err_) \
do { \
if (err_) { \
errflag_ |= (MPIX_ERR_PROC_FAILED == MPIR_ERR_GET_CLASS(err_)) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER; \
MPIR_ERR_ADD(err_ret_, err_); \
} \
} while (0)

/* Propagate the size mismatch error */
#define MPIR_ERR_COLL_CHECK_SIZE(recv_sz, expect_sz, errflag_, err_ret_) \
do { \
if (recv_sz != expect_sz) { \
int err = MPI_SUCCESS; \
MPIR_ERR_SET2(err, MPI_ERR_OTHER, "**collective_size_mismatch", "**collective_size_mismatch %d %d", recv_sz, expect_sz); \
MPIR_ERR_ADD(err_ret_, err); \
errflag_ |= MPIR_ERR_OTHER; \
MPIR_ERR_SETANDJUMP2(err_, MPI_ERR_OTHER, "**collective_size_mismatch", "**collective_size_mismatch %d %d", recv_sz, expect_sz); \
} \
} while (0)

Expand Down Expand Up @@ -992,15 +979,6 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in
err_ = newerr_; \
} while (0)

#define MPIR_ERR_COLL_CHECKANDCONT(err_, errflag_, err_ret_) \
do { \
if (err_) { \
errflag_ = (MPIX_ERR_PROC_FAILED == MPIR_ERR_GET_CLASS(err_)) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER; \
} \
} while (0)

#define MPIR_ERR_COLL_CHECK_SIZE(recv_sz, expect_sz, errflag_, err_ret_) do { } while (0)

#endif

/* The following definitions are the same independent of the choice of
Expand Down
16 changes: 1 addition & 15 deletions src/include/mpir_pt2pt.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,7 @@
#define MPIR_PT2PT_ATTR_CONTEXT_OFFSET(attr) ((attr) & 0x01)
#define MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, context_offset) (attr) |= (context_offset)

/* bit 1-2: errflag */
#define MPIR_PT2PT_ATTR_GET_ERRFLAG(attr) \
((!((attr) & 0x6)) ? MPIR_ERR_NONE : \
(((attr) & 0x2) ? MPIX_ERR_PROC_FAILED : MPI_ERR_OTHER))

#define MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag) \
do { \
if (errflag) { \
if (errflag == MPIR_ERR_PROC_FAILED) { \
(attr) |= 0x2; \
} else { \
(attr) |= 0x4; \
} \
} \
} while (0)
/* bit 1-2: errflag (removed) */

/* bit 3: syncflag */
#define MPIR_PT2PT_ATTR_GET_SYNCFLAG(attr) (((attr) & 0x8) ? 1 : 0)
Expand Down
3 changes: 1 addition & 2 deletions src/mpi/coll/algorithms/treealgo/treeutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,8 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root,
} else {
/* rank level - build a tree on the ranks */
/* Do an allgather to know the current num_children on each rank */
MPIR_Errflag_t errflag = MPIR_ERR_NONE;
MPIR_Allgather_impl(&(ct->num_children), 1, MPI_INT, num_childrens, 1, MPI_INT,
comm, errflag);
comm);
if (mpi_errno) {
goto fn_fail;
}
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/allgather/allgather_allcomm_nb.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Request *req_ptr = NULL;
Expand Down
28 changes: 12 additions & 16 deletions src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount,
MPI_Datatype sendtype, void *recvbuf,
MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr)
{
int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS, root;
int mpi_errno_ret = MPI_SUCCESS;
MPI_Aint sendtype_sz;
void *tmp_buf = NULL;
MPIR_Comm *newcomm_ptr = NULL;
Expand Down Expand Up @@ -48,8 +47,8 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint

if (sendcount != 0) {
mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz,
MPI_BYTE, 0, newcomm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPI_BYTE, 0, newcomm_ptr);
MPIR_ERR_CHECK(mpi_errno);
}

/* first broadcast from left to right group, then from right to
Expand All @@ -59,39 +58,36 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPI_BYTE, root, comm_ptr);
MPIR_ERR_CHECK(mpi_errno);
}

/* receive bcast from right */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size, recvtype, root, comm_ptr);
MPIR_ERR_CHECK(mpi_errno);
}
} else {
/* receive bcast from left */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size, recvtype, root, comm_ptr);
MPIR_ERR_CHECK(mpi_errno);
}

/* bcast to left */
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPI_BYTE, root, comm_ptr);
MPIR_ERR_CHECK(mpi_errno);
}
}

fn_exit:
MPIR_CHKLMEM_FREEALL();
return mpi_errno_ret;
return mpi_errno;
fn_fail:
mpi_errno_ret = mpi_errno;
goto fn_exit;
}
17 changes: 7 additions & 10 deletions src/mpi/coll/allgather/allgather_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
MPI_Aint sendcount,
MPI_Datatype sendtype,
void *recvbuf,
MPI_Aint recvcount,
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr)
{
int comm_size, rank;
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
MPI_Aint recvtype_extent, recvtype_sz;
int pof2, src, rem;
void *tmp_buf = NULL;
Expand Down Expand Up @@ -67,8 +65,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
MPIR_ALLGATHER_TAG,
((char *) tmp_buf + curr_cnt * recvtype_sz),
curr_cnt * recvtype_sz, MPI_BYTE,
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE);
MPIR_ERR_CHECK(mpi_errno);
curr_cnt *= 2;
pof2 *= 2;
}
Expand All @@ -84,8 +82,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
dst, MPIR_ALLGATHER_TAG,
((char *) tmp_buf + curr_cnt * recvtype_sz),
rem * recvcount * recvtype_sz, MPI_BYTE,
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE);
MPIR_ERR_CHECK(mpi_errno);
}

/* Rotate blocks in tmp_buf down by (rank) blocks and store
Expand All @@ -101,13 +99,12 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
(comm_size - rank) * recvcount * recvtype_sz,
rank * recvcount * recvtype_sz, MPI_BYTE, recvbuf,
rank * recvcount, recvtype);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPIR_ERR_CHECK(mpi_errno);
}

fn_exit:
MPIR_CHKLMEM_FREEALL();
return mpi_errno_ret;
return mpi_errno;
fn_fail:
mpi_errno_ret = mpi_errno;
goto fn_exit;
}
15 changes: 6 additions & 9 deletions src/mpi/coll/allgather/allgather_intra_k_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
int
MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,
MPI_Datatype sendtype, void *recvbuf,
MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, int k,
MPIR_Errflag_t errflag)
MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, int k)
{
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
int i, j;
int nphases = 0;
int src, dst, p_of_k = 0; /* Largest power of k that is smaller than 'size' */
Expand Down Expand Up @@ -143,7 +141,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,
mpi_errno = MPIC_Irecv((char *) tmp_recvbuf + j * recvcount * delta * recvtype_extent,
count, recvtype, src, MPIR_ALLGATHER_TAG, comm,
&reqs[num_reqs++]);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPIR_ERR_CHECK(mpi_errno);

MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST,
"Phase#%d:, k:%d Recv at:%p for count:%d", i,
Expand All @@ -154,16 +152,16 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,
/* Send from the start of recv till `count` amount of data. */
mpi_errno =
MPIC_Isend(tmp_recvbuf, count, recvtype, dst, MPIR_ALLGATHER_TAG, comm,
&reqs[num_reqs++], errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
&reqs[num_reqs++]);
MPIR_ERR_CHECK(mpi_errno);

MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST,
"Phase#%d:, k:%d Send from:%p for count:%d",
i, k, tmp_recvbuf, count));

}
mpi_errno = MPIC_Waitall(num_reqs, reqs, MPI_STATUSES_IGNORE);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPIR_ERR_CHECK(mpi_errno);
delta *= k;
}

Expand All @@ -189,8 +187,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,
MPIR_CHKLMEM_FREEALL();

fn_exit:
return mpi_errno_ret;
return mpi_errno;
fn_fail:
mpi_errno_ret = mpi_errno;
goto fn_exit;
}
Loading