diff --git a/src/Infrastructure/stubs/mpiuni/mpi.c b/src/Infrastructure/stubs/mpiuni/mpi.c index 3903ed236a..7269183d4a 100644 --- a/src/Infrastructure/stubs/mpiuni/mpi.c +++ b/src/Infrastructure/stubs/mpiuni/mpi.c @@ -55,8 +55,28 @@ static int num_attr = 1,mpi_tag_ub = 100000000; /* To avoid problems with prototypes to the system memcpy() it is duplicated here + + This version also supports checking for MPI_IN_PLACE */ -int MPIUNI_Memcpy(void *a,const void* b,int n) { +int MPIUNI_Memcpy(void *a,const void* b,int n,enum CheckForMPIInPlace_Flag check_flag) { + switch(check_flag) { + case CHECK_FOR_MPI_IN_PLACE_NONE: + // No pre-check in this case; proceed to the actual memcpy + break; + case CHECK_FOR_MPI_IN_PLACE_SOURCE: + if (b == MPI_IN_PLACE) { + // If the source is MPI_IN_PLACE, do nothing + return 0; + } + break; + case CHECK_FOR_MPI_IN_PLACE_DEST: + if (a == MPI_IN_PLACE) { + // If the dest is MPI_IN_PLACE, do nothing + return 0; + } + break; + } + int i; char *aa= (char*)a; char *bb= (char*)b; @@ -403,7 +423,7 @@ void MPIUNI_STDCALL mpi_allreduce(void *sendbuf,void *recvbuf,int *count,int *da *ierr = MPI_ERR_OP; return; } - MPIUNI_Memcpy(recvbuf,sendbuf,(*count)*MPIUNI_DATASIZE[*datatype]); + MPIUNI_Memcpy(recvbuf,sendbuf,(*count)*MPIUNI_DATASIZE[*datatype],CHECK_FOR_MPI_IN_PLACE_SOURCE); *ierr = MPI_SUCCESS; } void MPIUNI_STDCALL mpi_allreduce_(void *sendbuf,void *recvbuf,int *count,int *datatype,int *op,int *comm,int *ierr) diff --git a/src/Infrastructure/stubs/mpiuni/mpi.h b/src/Infrastructure/stubs/mpiuni/mpi.h index d4580637a4..f1b22cdeeb 100644 --- a/src/Infrastructure/stubs/mpiuni/mpi.h +++ b/src/Infrastructure/stubs/mpiuni/mpi.h @@ -113,8 +113,13 @@ typedef int MPI_Info; /* handle */ #define MPI_INFO_NULL (0) #define MPI_IN_PLACE (void *)(-1) +enum CheckForMPIInPlace_Flag { + CHECK_FOR_MPI_IN_PLACE_NONE, + CHECK_FOR_MPI_IN_PLACE_SOURCE, + CHECK_FOR_MPI_IN_PLACE_DEST +}; -extern int MPIUNI_Memcpy(void*,const void*,int); +extern int MPIUNI_Memcpy(void*,const void*,int,enum CheckForMPIInPlace_Flag); /* In order to handle datatypes, we make them into "sizeof(raw-type)"; this allows us to do the MPIUNI_Memcpy's easily */ @@ -463,7 +468,7 @@ extern double ESMC_MPI_Wtime(void); dest,sendtag,recvbuf,recvcount,\ recvtype,source,recvtag,\ comm,status) \ - MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount) * (sendtype)) + MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount) * (sendtype),CHECK_FOR_MPI_IN_PLACE_NONE) #define MPI_Sendrecv_replace(buf,count, datatype,dest,sendtag,\ source,recvtag,comm,status) MPI_SUCCESS #define MPI_Type_contiguous(count, oldtype,newtype) \ @@ -520,7 +525,7 @@ extern double ESMC_MPI_Wtime(void); MPIUNI_TMP = (void*)(long) (root),\ MPIUNI_TMP = (void*)(long) (recvtype),\ MPIUNI_TMP = (void*)(long) (comm),\ - MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\ + MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \ MPI_SUCCESS) #define MPI_Gatherv(sendbuf,sendcount, sendtype,\ recvbuf,recvcounts,displs,\ @@ -530,7 +535,7 @@ extern double ESMC_MPI_Wtime(void); MPIUNI_TMP = (void*)(long) (recvtype),\ MPIUNI_TMP = (void*)(long) (root),\ MPIUNI_TMP = (void*)(long) (comm),\ - MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\ + MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \ MPI_SUCCESS) #define MPI_Scatter(sendbuf,sendcount, sendtype,\ recvbuf,recvcount, recvtype,\ @@ -560,7 +565,7 @@ extern double ESMC_MPI_Wtime(void); (MPIUNI_TMP = (void*)(long) (recvcount),\ MPIUNI_TMP = (void*)(long) (recvtype),\ MPIUNI_TMP = (void*)(long) (comm),\ - MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\ + MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \ MPI_SUCCESS) #define MPI_Allgatherv(sendbuf,sendcount, sendtype,\ recvbuf,recvcounts,displs,recvtype,comm) \ @@ -568,7 +573,7 @@ extern double ESMC_MPI_Wtime(void); MPIUNI_TMP = (void*)(long) (displs),\ MPIUNI_TMP = (void*)(long) (recvtype),\ MPIUNI_TMP = (void*)(long) (comm),\ - MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype)),\ + MPIUNI_Memcpy(recvbuf,sendbuf,(sendcount)* (sendtype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \ MPI_SUCCESS) #define MPI_Alltoall(sendbuf,sendcount, sendtype,\ recvbuf,recvcount, recvtype,\ @@ -581,13 +586,13 @@ extern double ESMC_MPI_Wtime(void); rdispls, recvtypes,comm) MPI_Abort(MPI_COMM_WORLD,0) #define MPI_Reduce(sendbuf, recvbuf,count,\ datatype,op,root,comm) \ - (MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype)),\ + (MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \ MPIUNI_TMP = (void*)(long) (comm),MPI_SUCCESS) #define MPI_Allreduce(sendbuf, recvbuf,count,datatype,op,comm) \ - (MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype)),\ + (MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \ MPIUNI_TMP = (void*)(long) (comm),MPI_SUCCESS) #define MPI_Scan(sendbuf, recvbuf,count,datatype,op,comm) \ - (MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype)),\ + (MPIUNI_Memcpy(recvbuf,sendbuf,(count)*(datatype),CHECK_FOR_MPI_IN_PLACE_SOURCE), \ MPIUNI_TMP = (void*)(long) (comm),MPI_SUCCESS) #define MPI_Reduce_scatter(sendbuf, recvbuf,recvcounts,\ datatype,op,comm) \ @@ -668,7 +673,7 @@ extern double ESMC_MPI_Wtime(void); #define MPI_Cart_map(comm,ndims,dims,periods,newrank) MPI_Abort(MPI_COMM_WORLD,0) #define MPI_Graph_map(comm,a,b,c,d) MPI_Abort(MPI_COMM_WORLD,0) #define MPI_Get_processor_name(name,result_len) \ - (MPIUNI_Memcpy(name,"localhost",9*sizeof(char)),name[10] = 0,*(result_len) = 10) + (MPIUNI_Memcpy(name,"localhost",9*sizeof(char),CHECK_FOR_MPI_IN_PLACE_NONE),name[10] = 0,*(result_len) = 10) #define MPI_Errhandler_create(function,errhandler) \ (MPIUNI_TMP = (void*)(long) (errhandler),\ MPI_SUCCESS)