adapt to ds3.2

maxiao
2025-09-30 17:44:54 +08:00
parent 1237aa19ce
commit 8f7453e3af
9 changed files with 199 additions and 49 deletions


@@ -165,10 +165,10 @@ DINLINE void start_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
-    __scoped_atomic_store_n(
-        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __hip_atomic_store(
+        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
+    while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
         flag)
       ;
   }
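For context: `__scoped_atomic_store_n`/`__scoped_atomic_load_n` are GCC-style scoped-atomic builtins; the commit switches to HIP's `__hip_atomic_store`/`__hip_atomic_load` clang builtins, where `__HIP_MEMORY_SCOPE_AGENT` is the device-wide scope and `__HIP_MEMORY_SCOPE_SYSTEM` spans all GPUs and the host. A minimal sketch of the same signal-then-spin pattern; `signal_and_wait` and `flag_t` are illustrative names, not from the patch:

    #include <hip/hip_runtime.h>
    #include <cstdint>

    using flag_t = uint32_t;  // illustrative; the real flag type lives in the patched header

    __device__ void signal_and_wait(flag_t* peer_flag, flag_t* my_flag, flag_t flag) {
      // System scope: the store must be visible to a peer GPU over p2p.
      // Relaxed order: start_sync carries no payload, so no fence is needed.
      __hip_atomic_store(peer_flag, flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
      // Agent (device) scope suffices for polling our own flag buffer,
      // which the peers write into over p2p.
      while (__hip_atomic_load(my_flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) < flag)
        ;
    }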
@@ -211,16 +211,16 @@ DINLINE void end_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
+    __hip_atomic_store(
         &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
         flag,
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
-        __MEMORY_SCOPE_SYSTEM);
+        __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(
+    while (__hip_atomic_load(
         &self_sg->end[blockIdx.x][threadIdx.x],
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
-        __MEMORY_SCOPE_DEVICE) < flag)
+        __HIP_MEMORY_SCOPE_AGENT) < flag)
       ;
   }
   __syncthreads();
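Unlike start_sync, end_sync upgrades the pair to release/acquire on every round except the last (`final_sync`), so the data written before the signal is guaranteed visible to the waiter. A sketch of that handoff with the same builtins, reusing the `flag_t` alias from the sketch above; `publish`/`consume` are illustrative names:

    __device__ void publish(int* data, flag_t* peer_flag, flag_t flag) {
      data[0] = 42;  // payload written before the release store
      __hip_atomic_store(peer_flag, flag, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
    }

    __device__ int consume(const int* data, flag_t* my_flag, flag_t flag) {
      // The acquire load pairs with the peer's release store: once the flag
      // is observed, the payload written before it is visible as well.
      while (__hip_atomic_load(my_flag, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT) < flag)
        ;
      return data[0];
    }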


@@ -21,6 +21,7 @@ limitations under the License.
 #include "utils.h"
 #define WARP_SIZE 64
 #define VEC_SIZE 4
+using Vec = int4;
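`Vec` aliases `int4`, a 16-byte type, so with `VEC_SIZE 4` each thread moves four 32-bit elements per memory transaction. A sketch of the usual pattern; the `copy_vec` kernel is illustrative, not part of the patch:

    __global__ void copy_vec(const int* __restrict__ in, int* __restrict__ out, int n_vec) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n_vec) {
        // One 16-byte load and one 16-byte store instead of four scalar
        // pairs; assumes `in` and `out` are 16-byte aligned.
        Vec v = reinterpret_cast<const Vec*>(in)[i];
        reinterpret_cast<Vec*>(out)[i] = v;
      }
    }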
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
   int original = v;
 #pragma unroll
   for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
-    int n = __shfl_up_sync(mask, v, offset);
+    int n = __shfl_up(v, offset);
     if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
   }
   return v - original;
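HIP on AMD GPUs executes wavefronts (64 lanes wide, matching `WARP_SIZE`) and does not provide the CUDA `*_sync` shuffle variants, so `__shfl_up_sync(mask, v, offset)` becomes `__shfl_up(v, offset)`; the `mask` parameter stays in the signature but is now unused. Usage sketch (illustrative, launched with a multiple of WARP_SIZE threads): with every lane contributing 1, the exclusive scan returns each lane's index:

    __global__ void scan_demo(int* out) {
      // Exclusive prefix sum of ones: lane 0 gets 0, lane 1 gets 1, ..., lane 63 gets 63.
      out[threadIdx.x] = warp_exclusive_scan(1);
    }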


@@ -60,7 +60,7 @@ template <typename T>
 __device__ float convert_to_float(T x) {
   if constexpr (std::is_same_v<T, __half>) {
     return __half2float(x);
-  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+  } else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
     return __bfloat162float(x);
   } else if constexpr (std::is_same_v<T, float>) {
     return x;
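`__hip_bfloat16` (from <hip/hip_bf16.h>) replaces CUDA's `__nv_bfloat16`; HIP ships a compatible `__bfloat162float`, so the branch body is unchanged. A sketch of how the helper is typically exercised; the `to_float_demo` kernel is illustrative:

    #include <hip/hip_bf16.h>

    __global__ void to_float_demo(const __hip_bfloat16* in, float* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = convert_to_float(in[i]);  // resolves to the bf16 branch
    }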
@@ -575,8 +575,8 @@ void topk_softmax(
         renormalize,
         stream);
   } else if (dtype == at::ScalarType::BFloat16) {
-    topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
-        reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
+    topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
+        reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
         topk_weights.data_ptr<float>(),
         topk_indices.data_ptr<int>(),
         softmax_workspace.data_ptr<float>(),
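The host-side change is the same type swap: PyTorch's `at::BFloat16` and `__hip_bfloat16` share the same 16-bit storage layout, which is what makes the `reinterpret_cast` sound. A sketch of the cast in isolation; the `as_hip_bf16` helper is illustrative, not part of the patch:

    #include <hip/hip_bf16.h>
    #include <torch/extension.h>

    const __hip_bfloat16* as_hip_bf16(const at::Tensor& t) {
      TORCH_CHECK(t.scalar_type() == at::ScalarType::BFloat16, "expected a bf16 tensor");
      return reinterpret_cast<const __hip_bfloat16*>(t.data_ptr<at::BFloat16>());
    }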