fix: fix apply_shuffle_mul_sum (#7444)
This commit is contained in:
114
sgl-kernel/csrc/moe/prepare_moe_input.cu
Executable file → Normal file
114
sgl-kernel/csrc/moe/prepare_moe_input.cu
Executable file → Normal file
@@ -2,9 +2,11 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "utils.h"
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
@@ -255,37 +257,67 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void apply_shuffle_mul_sum_kernel(
|
||||
const scalar_t* __restrict__ input_tensor, // [m * topk, row_stride]
|
||||
scalar_t* __restrict__ output_tensor, // [m, row_stride]
|
||||
const int32_t* __restrict__ permutation, // [m * topk]
|
||||
const scalar_t* __restrict__ input_tensor, // [m * topk, k] (expert-major layout)
|
||||
scalar_t* __restrict__ output_tensor, // [m, k] (token-major layout)
|
||||
const int32_t* __restrict__ permutation, // [m * topk] (c_map: token-major-idx -> expert-major-idx)
|
||||
int m,
|
||||
int topk,
|
||||
int row_stride,
|
||||
const scalar_t* __restrict__ factors) // [m * topk] or nullptr
|
||||
const scalar_t* __restrict__ factors) // [m * topk] (topk_weights, token-major layout)
|
||||
{
|
||||
int i = blockIdx.x; // [0, m * topk)
|
||||
int d = threadIdx.x; // [0, row_stride)
|
||||
|
||||
if (i >= m || d >= row_stride) return;
|
||||
|
||||
scalar_t sum_val = 0.0;
|
||||
|
||||
for (int j = 0; j < topk; ++j) {
|
||||
int index_2d = i * topk + j;
|
||||
int src_row = permutation[index_2d];
|
||||
if (src_row >= m) continue;
|
||||
|
||||
scalar_t val = input_tensor[src_row * row_stride + d];
|
||||
|
||||
scalar_t factor = 1.0;
|
||||
if (factors != nullptr) {
|
||||
factor = factors[index_2d];
|
||||
}
|
||||
|
||||
sum_val += factor * val;
|
||||
int i = blockIdx.x;
|
||||
if (i >= m) {
|
||||
return;
|
||||
}
|
||||
|
||||
output_tensor[i * row_stride + d] = sum_val;
|
||||
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||
using t = float;
|
||||
using vec_t = flashinfer::vec_t<t, vec_size>;
|
||||
int thread_idx = threadIdx.x;
|
||||
int stride = blockDim.x;
|
||||
|
||||
for (int d_vec_idx = thread_idx; d_vec_idx < row_stride / vec_size; d_vec_idx += stride) {
|
||||
int d = d_vec_idx * vec_size;
|
||||
vec_t sum_vec;
|
||||
sum_vec.fill(0.0f);
|
||||
|
||||
for (int j = 0; j < topk; ++j) {
|
||||
int token_major_idx = i * topk + j;
|
||||
int src_row = permutation[token_major_idx];
|
||||
|
||||
vec_t val_vec;
|
||||
val_vec.cast_load(input_tensor + src_row * row_stride + d);
|
||||
|
||||
t factor = 1.0;
|
||||
if (factors != nullptr) {
|
||||
factor = factors[token_major_idx];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < vec_size; ++k) {
|
||||
sum_vec[k] += factor * val_vec[k];
|
||||
}
|
||||
}
|
||||
sum_vec.cast_store(output_tensor + i * row_stride + d);
|
||||
}
|
||||
|
||||
// remainder part
|
||||
int remainder_start = (row_stride / vec_size) * vec_size;
|
||||
for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
|
||||
t sum_val = 0.0;
|
||||
for (int j = 0; j < topk; ++j) {
|
||||
int token_major_idx = i * topk + j;
|
||||
int src_row = permutation[token_major_idx];
|
||||
t val = input_tensor[src_row * row_stride + d];
|
||||
|
||||
t factor = 1.0;
|
||||
if (factors != nullptr) {
|
||||
factor = factors[token_major_idx];
|
||||
}
|
||||
sum_val += factor * val;
|
||||
}
|
||||
output_tensor[i * row_stride + d] = sum_val;
|
||||
}
|
||||
}
|
||||
|
||||
void get_apply_shuffle_mul_sum_caller(
|
||||
@@ -304,7 +336,11 @@ void get_apply_shuffle_mul_sum_caller(
|
||||
|
||||
TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");
|
||||
|
||||
dim3 block(std::min(256, row_stride));
|
||||
auto scalar_type = output_tensor.scalar_type();
|
||||
uint32_t vec_size = 16 / sizeof(scalar_type);
|
||||
auto blockDim = std::min(row_stride / vec_size, 1024U);
|
||||
dim3 block(blockDim);
|
||||
|
||||
dim3 grid(m); // blockIdx.x = j, blockIdx.y = i
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());
|
||||
|
||||
@@ -317,29 +353,17 @@ void get_apply_shuffle_mul_sum_caller(
|
||||
factors_ptr = factors_opt->data_ptr();
|
||||
}
|
||||
|
||||
if (output_tensor.scalar_type() == at::ScalarType::Half) {
|
||||
const at::Half* factor_data = static_cast<const at::Half*>(factors_ptr);
|
||||
apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0, stream>>>(
|
||||
input_tensor.data_ptr<at::Half>(),
|
||||
output_tensor.data_ptr<at::Half>(),
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(output_tensor.scalar_type(), scalar_t, [&] {
|
||||
apply_shuffle_mul_sum_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<const scalar_t*>(input_tensor.data_ptr()),
|
||||
static_cast<scalar_t*>(output_tensor.data_ptr()),
|
||||
perm_ptr,
|
||||
m,
|
||||
topk,
|
||||
row_stride,
|
||||
static_cast<const at::Half*>(factors_ptr));
|
||||
} else if (output_tensor.scalar_type() == at::ScalarType::BFloat16) {
|
||||
const c10::BFloat16* factor_data = static_cast<const c10::BFloat16*>(factors_ptr);
|
||||
apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0, stream>>>(
|
||||
input_tensor.data_ptr<c10::BFloat16>(),
|
||||
output_tensor.data_ptr<c10::BFloat16>(),
|
||||
perm_ptr,
|
||||
m,
|
||||
topk,
|
||||
row_stride,
|
||||
static_cast<const c10::BFloat16*>(factors_ptr));
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output dtype for cast+mul kernel: ", output_tensor.scalar_type());
|
||||
}
|
||||
static_cast<const scalar_t*>(factors_ptr));
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user