From b80ae5e9ffada9f4609927fe8c60a55bdee0fadc Mon Sep 17 00:00:00 2001 From: maxiao1 Date: Sat, 25 Oct 2025 16:33:07 +0800 Subject: [PATCH] adaptation w4a8 tp --- python/sglang/srt/layers/layernorm.py | 4 ++-- .../srt/layers/quantization/slimquant_w4a8_marlin.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 728e71e5c..34d6eb55a 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -170,9 +170,7 @@ class RMSNorm(CustomOp): output = torch.empty_like(x) residual_out = torch.empty_like(x) fused_add_rms_norm( - output, x, - residual_out, residual, self.weight.data, self.variance_epsilon, @@ -180,7 +178,9 @@ class RMSNorm(CustomOp): return output, residual_out except TypeError: fused_add_rms_norm( + output, x, + residual_out, residual, self.weight.data, self.variance_epsilon, diff --git a/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py index 0d3303380..50d29a8f3 100644 --- a/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py +++ b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, List, Optional -from sglang.srt.layers.moe.token_dispatcher.base import CombineInput -from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput +# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput + import torch from sglang.srt import _custom_ops as ops from sglang.srt.utils import set_weight_attrs @@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod: def apply( self, layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: + dispatch_output, + ) : + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput x = dispatch_output.hidden_states topk_output = dispatch_output.topk_output from sglang.srt.layers.moe.topk import apply_topk_weights_cpu @@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: use_int4_w4a8=True, per_channel_quant=True, activation=layer.moe_runner_config.activation, - expert_map=layer.expert_map_gpu, + # expert_map=layer.expert_map_gpu, apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, global_num_experts=layer.moe_runner_config.num_experts, w1_scale=(layer.w13_weight_scale),