adapt to ds3.2
This commit is contained in:
@@ -165,10 +165,10 @@ DINLINE void start_sync(
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(
|
||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
|
||||
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
|
||||
flag)
|
||||
;
|
||||
}
|
||||
@@ -211,16 +211,16 @@ DINLINE void end_sync(
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||
flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||
__MEMORY_SCOPE_SYSTEM);
|
||||
__HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(
|
||||
while (__hip_atomic_load(
|
||||
&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||
__MEMORY_SCOPE_DEVICE) < flag)
|
||||
__HIP_MEMORY_SCOPE_AGENT) < flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define WARP_SIZE 64
|
||||
#define VEC_SIZE 4
|
||||
using Vec = int4;
|
||||
|
||||
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
|
||||
int original = v;
|
||||
#pragma unroll
|
||||
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
||||
int n = __shfl_up_sync(mask, v, offset);
|
||||
int n = __shfl_up(v, offset);
|
||||
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
|
||||
}
|
||||
return v - original;
|
||||
|
||||
@@ -60,7 +60,7 @@ template <typename T>
|
||||
__device__ float convert_to_float(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return __half2float(x);
|
||||
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return __bfloat162float(x);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
return x;
|
||||
@@ -575,8 +575,8 @@ void topk_softmax(
|
||||
renormalize,
|
||||
stream);
|
||||
} else if (dtype == at::ScalarType::BFloat16) {
|
||||
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
||||
topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
|
||||
reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
|
||||
Reference in New Issue
Block a user