适配v0.5.4

This commit is contained in:
maxiao
2025-10-25 12:16:25 +08:00
parent 1053e1be17
commit 251235c229
8 changed files with 213 additions and 49 deletions

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include "utils.h"
#define WARP_SIZE 64
#define VEC_SIZE 4
using Vec = int4;
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
int original = v;
#pragma unroll
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
int n = __shfl_up_sync(mask, v, offset);
int n = __shfl_up(v, offset);
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
}
return v - original;

View File

@@ -60,7 +60,7 @@ template <typename T>
__device__ float convert_to_float(T x) {
if constexpr (std::is_same_v<T, __half>) {
return __half2float(x);
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return __bfloat162float(x);
} else if constexpr (std::is_same_v<T, float>) {
return x;
@@ -575,8 +575,8 @@ void topk_softmax(
renormalize,
stream);
} else if (dtype == at::ScalarType::BFloat16) {
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),