[Fix] Fix index out-of-bounds (OOB) in get_group_gemm_starts kernel. (#8564)
This commit is contained in:
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
|
||||
int* problem_sizes,
|
||||
int* problem_sizes_transpose,
|
||||
bool transpose = false) {
|
||||
int expert_id = threadIdx.x;
|
||||
int64_t expert_id = static_cast<int64_t>(threadIdx.x);
|
||||
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
@@ -46,11 +46,11 @@ __global__ void get_group_gemm_starts(
|
||||
problem_sizes_transpose[expert_id * 3 + 2] = k;
|
||||
}
|
||||
|
||||
int32_t expert_offset = expert_offsets[expert_id];
|
||||
int a_stride = 0;
|
||||
int b_stride = 0;
|
||||
int a_scale_stride = 0;
|
||||
int b_scale_stride = 0;
|
||||
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
|
||||
int64_t a_stride = 0;
|
||||
int64_t b_stride = 0;
|
||||
int64_t a_scale_stride = 0;
|
||||
int64_t b_scale_stride = 0;
|
||||
if (!transpose) {
|
||||
a_stride = expert_offset * k;
|
||||
b_stride = expert_id * k * n;
|
||||
|
||||
Reference in New Issue
Block a user