Skip to content

Commit

Permalink
gfx11: don't use LL for sendrecv (#853) (#854)
Browse files Browse the repository at this point in the history
* gfx11: don't use LL for sendrecv (#853)

* gfx11: don't use LL for sendrecv

* Use builtin instead of inline asm

(cherry picked from commit f70e3e5)

* Use relaxed atomics for LL on GFX11 (#859)

(cherry picked from commit 6a0a6a3)
  • Loading branch information
wenkaidu authored Sep 22, 2023
1 parent fafbd40 commit 0491ffb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 23 deletions.
1 change: 1 addition & 0 deletions src/collectives/device/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
} else { \
const int w = threadIdx.x/WARP_SIZE; \
const int wid = threadIdx.x%WARP_SIZE; \
__threadfence(); \
if (wid == 0) { \
barrier_next[w] += nthreads/WARP_SIZE; \
atomicAdd((unsigned long long *)barriers, 1); \
Expand Down
44 changes: 28 additions & 16 deletions src/collectives/device/prims_ll.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@
#include "npkit/npkit.h"
#endif

#ifdef __GFX11__
#define LL_STORE(SRC, DST) \
__atomic_store_n((DST), (SRC), __ATOMIC_RELAXED)
#define LL_LOAD(SRC) \
__atomic_load_n(SRC, __ATOMIC_RELAXED)
#else
#define LL_STORE(SRC, DST) \
__builtin_nontemporal_store((SRC), (DST))
#define LL_LOAD(SRC) \
__builtin_nontemporal_load(SRC)
#endif

template<typename T, typename RedOp, typename Fan, int Direct, int P2p>
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>> {
Expand Down Expand Up @@ -151,8 +163,8 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
union ncclLLFifoLine i4;
do {
i4.v[0] = __builtin_nontemporal_load(src->v);
i4.v[1] = __builtin_nontemporal_load(src->v+1);
i4.v[0] = LL_LOAD(src->v);
i4.v[1] = LL_LOAD(src->v+1);
#if defined(ENABLE_NPKIT) && (defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_ENTRY) && defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_EXIT) || defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME))
npkitWaitRecvSpins++;
#endif
Expand Down Expand Up @@ -187,8 +199,8 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
if (i < fan.nrecv()) {
union ncclLLFifoLine* src = recvPtr(i) + offset;
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
line[i].v[0] = __builtin_nontemporal_load(src->v);
line[i].v[1] = __builtin_nontemporal_load(src->v+1);
line[i].v[0] = LL_LOAD(src->v);
line[i].v[1] = LL_LOAD(src->v+1);
#else
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4));
#endif
Expand All @@ -209,8 +221,8 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:

do {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
line[i].v[0] = __builtin_nontemporal_load(src->v);
line[i].v[1] = __builtin_nontemporal_load(src->v+1);
line[i].v[0] = LL_LOAD(src->v);
line[i].v[1] = LL_LOAD(src->v+1);
#else
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4));
#endif
Expand Down Expand Up @@ -238,8 +250,8 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
i4.flag1 = flag;
i4.data2 = (val >> 32);
i4.flag2 = flag;
__builtin_nontemporal_store(i4.v[0], dst->v);
__builtin_nontemporal_store(i4.v[1], dst->v+1);
LL_STORE(i4.v[0], dst->v);
LL_STORE(i4.v[1], dst->v+1);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
#endif
Expand All @@ -258,13 +270,13 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
};
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
if(sizeof(U) == 1)
u1 = __builtin_nontemporal_load((uint8_t*)src);
u1 = LL_LOAD((uint8_t*)src);
else if(sizeof(U) == 2)
u2 = __builtin_nontemporal_load((uint16_t*)src);
u2 = LL_LOAD((uint16_t*)src);
else if(sizeof(U) == 4)
u4 = __builtin_nontemporal_load((uint32_t*)src);
u4 = LL_LOAD((uint32_t*)src);
else
u8 = __builtin_nontemporal_load((uint64_t*)src);
u8 = LL_LOAD((uint64_t*)src);
#else
if(sizeof(U) == 1)
asm("ld.volatile.global.b8 %0,[%1];" : "=r"(u4) : "l"(src));
Expand All @@ -290,13 +302,13 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
elt = val;
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
if(sizeof(U) == 1)
__builtin_nontemporal_store(u1, (uint8_t*)dst);
LL_STORE(u1, (uint8_t*)dst);
else if(sizeof(U) == 2)
__builtin_nontemporal_store(u2, (uint16_t*)dst);
LL_STORE(u2, (uint16_t*)dst);
else if(sizeof(U) == 4)
__builtin_nontemporal_store(u4, (uint32_t*)dst);
LL_STORE(u4, (uint32_t*)dst);
else
__builtin_nontemporal_store(u8, (uint64_t*)dst);
LL_STORE(u8, (uint64_t*)dst);
#else
if(sizeof(U) == 1)
asm("st.volatile.global.b8 [%0],%1;" :: "l"(dst), "r"(u4));
Expand Down
13 changes: 7 additions & 6 deletions src/collectives/device/prims_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,18 @@ class Primitives<

template<int Recv, int Send>
inline __device__ void postPeer(bool dataStored) {
if (Send && (flags & RolePostSend) && dataStored)
#ifdef __GFX9__
__builtin_amdgcn_buffer_wbinvl1();
#else
__threadfence_system();
#endif

if ((flags & Send*RolePostSend) && next_hdp_reg)
STORE((unsigned int *)next_hdp_reg, 0x1);

if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
step += StepPerSlice;
if (Send && (flags & RolePostSend) && dataStored)
#ifdef __GFX9__
__asm__ __volatile__("buffer_wbinvl1_vol");
#else
__threadfence_system();
#endif
STORE(connStepPtr, step);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/graph/tuning.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom

for (int c=0; c<NCCL_NUM_FUNCTIONS; c++) for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
// Disable LL protocol on gfx11xx
int pEnable = (p == NCCL_PROTO_LL && comm->topo->nodes[GPU].nodes[0].gpu.gcn/100 == 11) ? 0 : protoEnable[p];
int pEnable = protoEnable[p];
if (pEnable == 2 && p == NCCL_PROTO_LL128) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
#if defined(ENABLE_LL128)
Expand Down

0 comments on commit 0491ffb

Please sign in to comment.