diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index 2f92223de..01bdf226c 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -813,14 +813,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE
 
-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE
 
     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
-            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+            recv_x_scale
+            + token_id * recv_x_scale_stride0
+            + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
         )
 
         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -841,7 +844,11 @@ def _fwd_kernel_ep_scatter_2(
                     output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                 )
                 tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+                tl.store(
+                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                    to_copy_s,
+                    mask=mask_s,
+                )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@@ -856,6 +863,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
@@ -865,7 +873,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
@@ -904,8 +920,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
     )
     return
 