[Fix]Fix index oob in get_group_gemm_starts kernel. (#8564)

This commit is contained in:
Qi Yuhang
2025-07-31 10:49:35 +08:00
committed by GitHub
parent 66a398f49d
commit 9b9e82539b

View File

@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
int* problem_sizes,
int* problem_sizes_transpose,
bool transpose = false) {
int expert_id = threadIdx.x;
int64_t expert_id = static_cast<int64_t>(threadIdx.x);
if (expert_id >= gridDim.x * blockDim.x) {
return;
@@ -46,11 +46,11 @@ __global__ void get_group_gemm_starts(
problem_sizes_transpose[expert_id * 3 + 2] = k;
}
int32_t expert_offset = expert_offsets[expert_id];
int a_stride = 0;
int b_stride = 0;
int a_scale_stride = 0;
int b_scale_stride = 0;
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t a_stride = 0;
int64_t b_stride = 0;
int64_t a_scale_stride = 0;
int64_t b_scale_stride = 0;
if (!transpose) {
a_stride = expert_offset * k;
b_stride = expert_id * k * n;