diff --git a/sgl-kernel/csrc/moe/cutlass_moe_helper.cu b/sgl-kernel/csrc/moe/cutlass_moe_helper.cu index e8af2093e..576ad233b 100644 --- a/sgl-kernel/csrc/moe/cutlass_moe_helper.cu +++ b/sgl-kernel/csrc/moe/cutlass_moe_helper.cu @@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts( int* problem_sizes, int* problem_sizes_transpose, bool transpose = false) { - int expert_id = threadIdx.x; + int64_t expert_id = static_cast(threadIdx.x); if (expert_id >= gridDim.x * blockDim.x) { return; @@ -46,11 +46,11 @@ __global__ void get_group_gemm_starts( problem_sizes_transpose[expert_id * 3 + 2] = k; } - int32_t expert_offset = expert_offsets[expert_id]; - int a_stride = 0; - int b_stride = 0; - int a_scale_stride = 0; - int b_scale_stride = 0; + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t a_stride = 0; + int64_t b_stride = 0; + int64_t a_scale_stride = 0; + int64_t b_scale_stride = 0; if (!transpose) { a_stride = expert_offset * k; b_stride = expert_id * k * n;