Add moe topk softmax templated from vllm (#4302)

This commit is contained in:
Qingquan Song
2025-03-14 12:03:33 -07:00
committed by GitHub
parent 660305c38a
commit 61e4433caf
9 changed files with 716 additions and 6 deletions

View File

@@ -173,6 +173,12 @@ void moe_align_block_size(
torch::Tensor token_cnts_buffer,
torch::Tensor cumsum_buffer);
void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
/*
* From csrc/speculative
*/

View File

@@ -65,6 +65,15 @@ inline int getSMVersion() {
return sm_major * 10 + sm_minor;
}
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
#else
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
#ifndef USE_ROCM
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
@@ -117,11 +126,11 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
}
__device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
return max_value;
}