Fix bug of deepseek-v3 under DP+EP mode with large batchsize/seqlen (#6449)
This commit is contained in:
@@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
"""
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
g_id = tl.program_id(0)
|
||||
y_ptr += g_id * group_size
|
||||
y_q_ptr += g_id * group_size
|
||||
y_ptr += g_id.to(tl.int64) * group_size
|
||||
y_q_ptr += g_id.to(tl.int64) * group_size
|
||||
|
||||
# Convert g_id the flattened block coordinate to 2D so we can index
|
||||
# into the output y_scales matrix
|
||||
|
||||
@@ -35,12 +35,12 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
const int scale_num_rows = 0,
|
||||
const int scale_stride = 0) {
|
||||
const int threads_per_group = 16;
|
||||
const int local_group_id = threadIdx.x / threads_per_group;
|
||||
const int64_t local_group_id = threadIdx.x / threads_per_group;
|
||||
const int lane_id = threadIdx.x % threads_per_group;
|
||||
|
||||
const int block_group_id = blockIdx.x * groups_per_block;
|
||||
const int global_group_id = block_group_id + local_group_id;
|
||||
const int block_group_offset = global_group_id * group_size;
|
||||
const int64_t block_group_id = blockIdx.x * groups_per_block;
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user