// Reconstructed post-patch section of sgl-kernel/csrc/moe/prepare_moe_input.cu.
// NOTE(review): this section arrived as a git diff with every angle-bracket
// span stripped (include targets, template arguments, <<<...>>> launch
// chevrons) and all newlines collapsed. Everything marked "reconstructed"
// below must be confirmed against the upstream commit before merging.
//
// Likely include additions for this section (targets stripped in transit):
//   #include <flashinfer/vec_dtypes.cuh>  // flashinfer::vec_t -- TODO confirm
//   #include "utils.h"                    // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16

// Gathers each token's topk expert outputs (expert-major layout), scales each
// by the optional per-(token, expert) factor, and sums into the token-major
// output row.
//
// Launch layout: grid = (m), one block per output token; threads of the block
// stride over the hidden dimension. Main loop uses 16-byte vectorized
// loads/stores with fp32 accumulation, followed by a scalar tail.
//
//   input_tensor  [m * topk, row_stride]  expert-major layout
//   output_tensor [m, row_stride]         token-major layout
//   permutation   [m * topk]              c_map: token-major idx -> expert-major row
//   factors       [m * topk] or nullptr   topk_weights, token-major layout
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const scalar_t* __restrict__ input_tensor,
    scalar_t* __restrict__ output_tensor,
    const int32_t* __restrict__ permutation,
    int m,
    int topk,
    int row_stride,
    const scalar_t* __restrict__ factors) {
  int i = blockIdx.x;  // output token index
  if (i >= m) {
    return;
  }

  // vec_size scalar_t lanes per 16-byte load (8 for fp16/bf16, 4 for fp32).
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using t = float;  // accumulate in fp32 to limit rounding error
  // NOTE(review): template arguments reconstructed; the paste showed only
  // "flashinfer::vec_t" with the <...> stripped.
  using vec_t = flashinfer::vec_t<t, vec_size>;
  int thread_idx = threadIdx.x;
  int stride = blockDim.x;

  // Vectorized main loop over the hidden dimension.
  for (int d_vec_idx = thread_idx; d_vec_idx < row_stride / vec_size; d_vec_idx += stride) {
    int d = d_vec_idx * vec_size;
    vec_t sum_vec;
    sum_vec.fill(0.0f);

    for (int j = 0; j < topk; ++j) {
      int token_major_idx = i * topk + j;
      // NOTE(review): unlike the pre-patch kernel there is no src_row bounds
      // guard here; every permutation entry must be a valid input row.
      int src_row = permutation[token_major_idx];

      vec_t val_vec;
      val_vec.cast_load(input_tensor + src_row * row_stride + d);

      t factor = 1.0f;
      if (factors != nullptr) {
        factor = static_cast<t>(factors[token_major_idx]);
      }

#pragma unroll
      for (int k = 0; k < vec_size; ++k) {
        sum_vec[k] += factor * val_vec[k];
      }
    }
    sum_vec.cast_store(output_tensor + i * row_stride + d);
  }

  // Scalar tail for row_stride not divisible by vec_size.
  int remainder_start = (row_stride / vec_size) * vec_size;
  for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
    t sum_val = 0.0f;
    for (int j = 0; j < topk; ++j) {
      int token_major_idx = i * topk + j;
      int src_row = permutation[token_major_idx];
      t val = static_cast<t>(input_tensor[src_row * row_stride + d]);

      t factor = 1.0f;
      if (factors != nullptr) {
        factor = static_cast<t>(factors[token_major_idx]);
      }
      sum_val += factor * val;
    }
    output_tensor[i * row_stride + d] = static_cast<scalar_t>(sum_val);
  }
}

// Host launcher for apply_shuffle_mul_sum_kernel.
// NOTE(review): the function signature and the derivation of m / topk /
// row_stride are not visible in the diff hunks; they are reconstructed from
// call-site usage and must be confirmed against the full file.
void get_apply_shuffle_mul_sum_caller(
    const torch::Tensor& input_tensor,
    torch::Tensor& output_tensor,
    const torch::Tensor& permutation,
    const std::optional<torch::Tensor>& factors_opt) {
  int m = output_tensor.size(0);           // reconstructed — TODO confirm
  int row_stride = output_tensor.size(1);  // reconstructed — TODO confirm
  int topk = permutation.size(0) / m;      // reconstructed — TODO confirm

  TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");

  // BUGFIX: the patched code computed `16 / sizeof(scalar_type)` where
  // scalar_type held a c10::ScalarType enum *value*, so sizeof() returned the
  // enum's size instead of the element size. Use the tensor's element size.
  const uint32_t vec_size = 16 / static_cast<uint32_t>(output_tensor.element_size());
  // BUGFIX: clamp to >= 1 thread; row_stride < vec_size previously produced a
  // blockDim of 0, which is an invalid launch configuration.
  uint32_t threads = std::min(std::max(static_cast<uint32_t>(row_stride) / vec_size, 1U), 1024U);
  dim3 block(threads);
  // One block per output token (the old "blockIdx.x = j, blockIdx.y = i"
  // comment was stale for this 1-D grid).
  dim3 grid(m);

  auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());

  const int32_t* perm_ptr = static_cast<const int32_t*>(permutation.data_ptr());  // reconstructed
  const void* factors_ptr = nullptr;
  if (factors_opt.has_value()) {
    factors_ptr = factors_opt->data_ptr();
  }

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(output_tensor.scalar_type(), scalar_t, [&] {
    // NOTE(review): launch chevrons and template argument reconstructed; the
    // paste showed "apply_shuffle_mul_sum_kernel<<>>(". 0 bytes of dynamic
    // shared memory, current CUDA stream.
    apply_shuffle_mul_sum_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<const scalar_t*>(input_tensor.data_ptr()),
        static_cast<scalar_t*>(output_tensor.data_ptr()),
        perm_ptr,
        m,
        topk,
        row_stride,
        static_cast<const scalar_t*>(factors_ptr));
    return true;
  });
}