adapt to ds3.2

This commit is contained in:
maxiao
2025-09-30 17:44:54 +08:00
parent 1237aa19ce
commit 8f7453e3af
9 changed files with 199 additions and 49 deletions

View File

@@ -165,10 +165,10 @@ DINLINE void start_sync(
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
__hip_atomic_store(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
flag)
;
}
@@ -211,16 +211,16 @@ DINLINE void end_sync(
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
__hip_atomic_store(
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
__MEMORY_SCOPE_SYSTEM);
__HIP_MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(
while (__hip_atomic_load(
&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
__MEMORY_SCOPE_DEVICE) < flag)
__HIP_MEMORY_SCOPE_AGENT) < flag)
;
}
__syncthreads();