Skip to content

Commit

Permalink
mpiuni: implement check for MPI_IN_PLACE
Browse files Browse the repository at this point in the history
According to MPI documentation, most MPI routines check sendbuf, but
some (MPI_Scatter, MPI_Scatterv and maybe others; note that these are
not yet implemented yet in mpiuni) check recvbuf, and some (e.g.,
MPI_Sendrecv_replace) don't check for MPI_IN_PLACE at all. To make
mpiuni respect the MPI standard in this respect, I have added an
argument to MPIUNI_Memcpy saying whether to check the source (sendbuf),
dest (recvbuf) or neither for equality with MPI_IN_PLACE.

(We could probably get away with keeping things simpler by always
checking both a and b for equality with MPI_IN_PLACE, but following the
MPI standard in this respect seems marginally safer.)
  • Loading branch information
billsacks committed Oct 20, 2023
1 parent 2242fce commit e33eade
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
24 changes: 22 additions & 2 deletions src/Infrastructure/stubs/mpiuni/mpi.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,28 @@ static int num_attr = 1,mpi_tag_ub = 100000000;

/*
To avoid problems with prototypes to the system memcpy() it is duplicated here
This version also supports checking for MPI_IN_PLACE
*/
int MPIUNI_Memcpy(void *a,const void* b,int n) {
int MPIUNI_Memcpy(void *a,const void* b,int n,enum CheckForMPIInPlace_Flag check_flag) {
switch(check_flag) {
case CHECK_FOR_MPI_IN_PLACE_NONE:
// No pre-check in this case; proceed to the actual memcpy
break;
case CHECK_FOR_MPI_IN_PLACE_SOURCE:
if (b == MPI_IN_PLACE) {
// If the source is MPI_IN_PLACE, do nothing
return 0;
}
break;
case CHECK_FOR_MPI_IN_PLACE_DEST:
if (a == MPI_IN_PLACE) {
// If the dest is MPI_IN_PLACE, do nothing
return 0;
}
break;
}

int i;
char *aa= (char*)a;
char *bb= (char*)b;
Expand Down Expand Up @@ -403,7 +423,7 @@ void MPIUNI_STDCALL mpi_allreduce(void *sendbuf,void *recvbuf,int *count,int *da
*ierr = MPI_ERR_OP;
return;
}
MPIUNI_Memcpy(recvbuf,sendbuf,(*count)*MPIUNI_DATASIZE[*datatype]);
MPIUNI_Memcpy(recvbuf,sendbuf,(*count)*MPIUNI_DATASIZE[*datatype],CHECK_FOR_MPI_IN_PLACE_SOURCE);
*ierr = MPI_SUCCESS;
}
void MPIUNI_STDCALL mpi_allreduce_(void *sendbuf,void *recvbuf,int *count,int *datatype,int *op,int *comm,int *ierr)
Expand Down
25 changes: 15 additions & 10 deletions src/Infrastructure/stubs/mpiuni/mpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,13 @@ typedef int MPI_Info; /* handle */
#define MPI_INFO_NULL (0)

#define MPI_IN_PLACE (void *)(-1)
enum CheckForMPIInPlace_Flag {
CHECK_FOR_MPI_IN_PLACE_NONE,
CHECK_FOR_MPI_IN_PLACE_SOURCE,
CHECK_FOR_MPI_IN_PLACE_DEST
};

extern int MPIUNI_Memcpy(void*,const void*,int);
extern int MPIUNI_Memcpy(void*,const void*,int,enum CheckForMPIInPlace_Flag);

/* In order to handle datatypes, we make them into "sizeof(raw-type)";
this allows us to do the MPIUNI_Memcpy's easily */
Expand Down Expand Up @@ -463,7 +468,7 @@ extern double ESMC_MPI_Wtime(void);
dest,sendtag,recvbuf,recvcount,\
recvtype,source,recvtag,\
comm,status) \
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount) * (sendtype))
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount) * (sendtype),CHECK_FOR_MPI_IN_PLACE_NONE)
#define MPI_Sendrecv_replace(buf,count, datatype,dest,sendtag,\
source,recvtag,comm,status) MPI_SUCCESS
#define MPI_Type_contiguous(count, oldtype,newtype) \
Expand Down Expand Up @@ -520,7 +525,7 @@ extern double ESMC_MPI_Wtime(void);
MPIUNI_TMP = (void*)(long) (root),\
MPIUNI_TMP = (void*)(long) (recvtype),\
MPIUNI_TMP = (void*)(long) (comm),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \
MPI_SUCCESS)
#define MPI_Gatherv(sendbuf,sendcount, sendtype,\
recvbuf,recvcounts,displs,\
Expand All @@ -530,7 +535,7 @@ extern double ESMC_MPI_Wtime(void);
MPIUNI_TMP = (void*)(long) (recvtype),\
MPIUNI_TMP = (void*)(long) (root),\
MPIUNI_TMP = (void*)(long) (comm),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \
MPI_SUCCESS)
#define MPI_Scatter(sendbuf,sendcount, sendtype,\
recvbuf,recvcount, recvtype,\
Expand Down Expand Up @@ -560,15 +565,15 @@ extern double ESMC_MPI_Wtime(void);
(MPIUNI_TMP = (void*)(long) (recvcount),\
MPIUNI_TMP = (void*)(long) (recvtype),\
MPIUNI_TMP = (void*)(long) (comm),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \
MPI_SUCCESS)
#define MPI_Allgatherv(sendbuf,sendcount, sendtype,\
recvbuf,recvcounts,displs,recvtype,comm) \
(MPIUNI_TMP = (void*)(long) (recvcounts),\
MPIUNI_TMP = (void*)(long) (displs),\
MPIUNI_TMP = (void*)(long) (recvtype),\
MPIUNI_TMP = (void*)(long) (comm),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\
MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \
MPI_SUCCESS)
#define MPI_Alltoall(sendbuf,sendcount, sendtype,\
recvbuf,recvcount, recvtype,\
Expand All @@ -581,13 +586,13 @@ extern double ESMC_MPI_Wtime(void);
rdispls, recvtypes,comm) MPI_Abort(MPI_COMM_WORLD,0)
#define MPI_Reduce(sendbuf, recvbuf,count,\
datatype,op,root,comm) \
(MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype)),\
(MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \
MPIUNI_TMP = (void*)(long) (comm),MPI_SUCCESS)
#define MPI_Allreduce(sendbuf, recvbuf,count,datatype,op,comm) \
(MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype)),\
(MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \
MPIUNI_TMP = (void*)(long) (comm),MPI_SUCCESS)
#define MPI_Scan(sendbuf, recvbuf,count,datatype,op,comm) \
(MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype)),\
(MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \
MPIUNI_TMP = (void*)(long) (comm),MPI_SUCCESS)
#define MPI_Reduce_scatter(sendbuf, recvbuf,recvcounts,\
datatype,op,comm) \
Expand Down Expand Up @@ -668,7 +673,7 @@ extern double ESMC_MPI_Wtime(void);
#define MPI_Cart_map(comm,ndims,dims,periods,newrank) MPI_Abort(MPI_COMM_WORLD,0)
#define MPI_Graph_map(comm,a,b,c,d) MPI_Abort(MPI_COMM_WORLD,0)
#define MPI_Get_processor_name(name,result_len) \
(MPIUNI_Memcpy(name,"localhost",9*sizeof(char)),name[10] = 0,*(result_len) = 10)
(MPIUNI_Memcpy(name,"localhost",9*sizeof(char),CHECK_FOR_MPI_IN_PLACE_NONE),name[10] = 0,*(result_len) = 10)
#define MPI_Errhandler_create(function,errhandler) \
(MPIUNI_TMP = (void*)(long) (errhandler),\
MPI_SUCCESS)
Expand Down

0 comments on commit e33eade

Please sign in to comment.