[Fix]Fix index oob in get_group_gemm_starts kernel. (#8564)
This commit is contained in:
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
|
|||||||
int* problem_sizes,
|
int* problem_sizes,
|
||||||
int* problem_sizes_transpose,
|
int* problem_sizes_transpose,
|
||||||
bool transpose = false) {
|
bool transpose = false) {
|
||||||
int expert_id = threadIdx.x;
|
int64_t expert_id = static_cast<int64_t>(threadIdx.x);
|
||||||
|
|
||||||
if (expert_id >= gridDim.x * blockDim.x) {
|
if (expert_id >= gridDim.x * blockDim.x) {
|
||||||
return;
|
return;
|
||||||
@@ -46,11 +46,11 @@ __global__ void get_group_gemm_starts(
|
|||||||
problem_sizes_transpose[expert_id * 3 + 2] = k;
|
problem_sizes_transpose[expert_id * 3 + 2] = k;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t expert_offset = expert_offsets[expert_id];
|
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
|
||||||
int a_stride = 0;
|
int64_t a_stride = 0;
|
||||||
int b_stride = 0;
|
int64_t b_stride = 0;
|
||||||
int a_scale_stride = 0;
|
int64_t a_scale_stride = 0;
|
||||||
int b_scale_stride = 0;
|
int64_t b_scale_stride = 0;
|
||||||
if (!transpose) {
|
if (!transpose) {
|
||||||
a_stride = expert_offset * k;
|
a_stride = expert_offset * k;
|
||||||
b_stride = expert_id * k * n;
|
b_stride = expert_id * k * n;
|
||||||
|
|||||||
Reference in New Issue
Block a user