Add a CUDA kernel for fusing mapping and weighted sum for MoE. (#6916)

Co-authored-by: Elfie Guo <elfiegxf@gmail.com>
This commit is contained in:
Elfie Guo
2025-06-07 15:24:39 -07:00
committed by GitHub
parent 62fec60d81
commit 3e56f557fd
7 changed files with 146 additions and 12 deletions

114
sgl-kernel/csrc/moe/prepare_moe_input.cu Normal file → Executable file
View File

@@ -252,3 +252,117 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
return;
}
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
const scalar_t* __restrict__ input_tensor, // [m * topk, row_stride]
scalar_t* __restrict__ output_tensor, // [m, row_stride]
const int32_t* __restrict__ permutation, // [m * topk]
int m,
int topk,
int row_stride,
const scalar_t* __restrict__ factors) // [m * topk] or nullptr
{
int i = blockIdx.x; // [0, m * topk)
int d = threadIdx.x; // [0, row_stride)
if (i >= m || d >= row_stride) return;
scalar_t sum_val = 0.0;
for (int j = 0; j < topk; ++j) {
int index_2d = i * topk + j;
int src_row = permutation[index_2d];
if (src_row >= m) continue;
scalar_t val = input_tensor[src_row * row_stride + d];
scalar_t factor = 1.0;
if (factors != nullptr) {
factor = factors[index_2d];
}
sum_val += factor * val;
}
output_tensor[i * row_stride + d] = sum_val;
}
void get_apply_shuffle_mul_sum_caller(
const torch::Tensor& input_tensor, // [m * topk, row_stride], bf16/f16
torch::Tensor& output_tensor, // [m, row_stride], bf16/f16
const torch::Tensor& permutation, // [m * topk], int32
const std::optional<torch::Tensor>& factors_opt) // optional [m * topk], bf16/f16
{
TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");
int m = output_tensor.size(0);
int topk = int(permutation.size(0) / m);
int row_stride = output_tensor.size(1);
TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");
dim3 block(std::min(256, row_stride));
dim3 grid(m); // blockIdx.x = j, blockIdx.y = i
auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());
const int32_t* perm_ptr = permutation.data_ptr<int32_t>();
void* factors_ptr = nullptr;
if (factors_opt.has_value()) {
TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
TORCH_CHECK(factors_opt->numel() == m * topk, "Factors must have shape [m * topk]");
factors_ptr = factors_opt->data_ptr();
}
if (output_tensor.scalar_type() == at::ScalarType::Half) {
const at::Half* factor_data = static_cast<const at::Half*>(factors_ptr);
apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0, stream>>>(
input_tensor.data_ptr<at::Half>(),
output_tensor.data_ptr<at::Half>(),
perm_ptr,
m,
topk,
row_stride,
static_cast<const at::Half*>(factors_ptr));
} else if (output_tensor.scalar_type() == at::ScalarType::BFloat16) {
const c10::BFloat16* factor_data = static_cast<const c10::BFloat16*>(factors_ptr);
apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0, stream>>>(
input_tensor.data_ptr<c10::BFloat16>(),
output_tensor.data_ptr<c10::BFloat16>(),
perm_ptr,
m,
topk,
row_stride,
static_cast<const c10::BFloat16*>(factors_ptr));
} else {
TORCH_CHECK(false, "Unsupported output dtype for cast+mul kernel: ", output_tensor.scalar_type());
}
}
/**
* @brief Applies a permutation-based shuffle, element-wise multiplication, and reduction over the second dimension.
*
* This function performs the equivalent of the following PyTorch expression:
*
* (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
*
* Specifically:
* - `input` is shuffled using the `permutation` tensor.
* - The shuffled tensor is reshaped and multiplied element-wise with `factors` (e.g., top-k weights).
* - The result is summed along dimension 1 (the top-k dimension), and stored in `output`.
*
* @param input Input tensor of shape (m * topk, k), representing c2.
* @param output Output tensor of shape (m, k), where the final reduced results are stored.
* @param permutation Index tensor (e.g., c_map) that maps positions in `input` to shuffled layout.
* @param factors Optional scaling factors (e.g., top-k weights), shape (m * topk) or (m, topk).
*/
void apply_shuffle_mul_sum(
const torch::Tensor& input,
torch::Tensor& output,
const torch::Tensor& permutation,
const std::optional<torch::Tensor>& factors) {
get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
}