adapt to ds3.2
This commit is contained in:
@@ -60,7 +60,7 @@ template <typename T>
|
||||
__device__ float convert_to_float(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return __half2float(x);
|
||||
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return __bfloat162float(x);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
return x;
|
||||
@@ -575,8 +575,8 @@ void topk_softmax(
|
||||
renormalize,
|
||||
stream);
|
||||
} else if (dtype == at::ScalarType::BFloat16) {
|
||||
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
||||
topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
|
||||
reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
|
||||
Reference in New Issue
Block a user