From 90f44b74e6c27e695834b3dcb1ab0f83709c9e08 Mon Sep 17 00:00:00 2001 From: SijiaYang Date: Tue, 12 Aug 2025 04:41:19 +0800 Subject: [PATCH] fix: w4afp8 accuracy problem and rebase (#8752) Signed-off-by: yangsijia.614 Co-authored-by: Jinwu --- .../sglang/srt/layers/moe/cutlass_w4a8_moe.py | 9 ++-- .../sglang/srt/layers/moe/ep_moe/kernels.py | 43 +++++++++++++++++++ .../sglang/srt/layers/quantization/w4afp8.py | 29 ++++++++----- 3 files changed, 66 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index 0a2b44bd1..7a03511c4 100644 --- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -11,7 +11,7 @@ from sgl_kernel import ( ) from sglang.srt.layers.moe.ep_moe.kernels import ( - post_reorder_triton_kernel, + post_reorder_triton_kernel_for_cutlass_moe, pre_reorder_triton_kernel_for_cutlass_moe, run_cutlass_moe_ep_preproess, ) @@ -199,14 +199,13 @@ def cutlass_w4a8_moe( ) output = torch.empty_like(a) - post_reorder_triton_kernel[(m,)]( + post_reorder_triton_kernel_for_cutlass_moe[(m,)]( c2, output, src2dst, - topk_ids_, + local_topk_ids, topk_weights, - start_expert_id, - end_expert_id, + num_experts, topk, k, 0, diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index d3ec90a7c..f1649d5c9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -581,6 +581,49 @@ def post_reorder_triton_kernel( ) +@triton.jit +def post_reorder_triton_kernel_for_cutlass_moe( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + num_experts, + topk, + hidden_size, + dst_start, + BLOCK_SIZE: tl.constexpr, +): + InDtype = down_output_ptr.dtype.element_ty + + src_idx_int32 = tl.program_id(0) + src_idx = src_idx_int32.to(tl.int64) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + store_ptr = output_ptr + src_idx * hidden_size + + vec = tl.arange(0, BLOCK_SIZE) + + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + vec + mask = offset < hidden_size + + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id != num_experts: + dst_idx_int32 = tl.load(src2dst_ptr + idx) + dst_idx = dst_idx_int32.to(tl.int64) + dst_idx = dst_idx - dst_start + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + @triton.jit def compute_m_range( pid, diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index ba11a4b6e..7a471870a 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + assert "weight_loader" in extra_weight_attrs # Fused gate_up_proj (column parallel) @@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) w13_weight_scale = torch.nn.Parameter( torch.zeros( num_experts, @@ -274,8 +279,11 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): def apply( self, layer: EPMoE, - hidden_states: torch.Tensor, + x: torch.Tensor, topk_output: TopKOutput, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + routed_scaling_factor: Optional[float] = None, **kwargs, ) -> torch.Tensor: @@ -284,19 +292,17 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, _ = topk_output local_topk_ids = topk_ids - if layer.expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where( - layer.expert_map[topk_ids] != layer.num_experts, - layer.expert_map[topk_ids], - layer.num_experts, - ) + local_topk_ids = torch.where( + topk_ids == -1, + layer.num_experts, + topk_ids, + ) - return cutlass_w4a8_moe( + output = cutlass_w4a8_moe( layer.start_expert_id, layer.end_expert_id, layer.num_experts, - hidden_states, + x, layer.w13_weight, layer.w2_weight, layer.w13_weight_scale_inv, @@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): layer.w13_input_scale, layer.w2_input_scale, ) + if routed_scaling_factor is not None: + output *= routed_scaling_factor + return output