Add moe topk softmax templated from vllm (#4302)
This commit is contained in:
@@ -173,6 +173,12 @@ void moe_align_block_size(
|
||||
torch::Tensor token_cnts_buffer,
|
||||
torch::Tensor cumsum_buffer);
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
@@ -65,6 +65,15 @@ inline int getSMVersion() {
|
||||
return sm_major * 10 + sm_minor;
|
||||
}
|
||||
|
||||
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
|
||||
#ifndef USE_ROCM
|
||||
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
|
||||
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
|
||||
#else
|
||||
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
|
||||
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
@@ -117,11 +126,11 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float max_value) {
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
|
||||
return max_value;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user