From 143ec5f36c61cbd8ab96d953437a16075e442166 Mon Sep 17 00:00:00 2001
From: lizhigong <306128847@qq.com>
Date: Tue, 21 Oct 2025 16:27:31 +0800
Subject: [PATCH] Adapt W4A8 quantization

(cherry picked from commit 848c5b8290ac896431f6843c77c1a8341e1cdb46)
---
 python/sglang/srt/_custom_ops.py              |  31 ++++
 .../srt/layers/quantization/slimquant_w4a8.py |   7 +
 .../quantization/slimquant_w4a8_marlin.py     | 140 ++++++++++++------
 3 files changed, 131 insertions(+), 47 deletions(-)

diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index de47707c1..cf63dd6c8 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -5,6 +5,15 @@ from typing import List, Optional, Tuple
 import torch
 from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
+try:
+    from lmslim import quant_ops
+    from lmslim import quant_tools
+except Exception:
+    print("INFO: Please install lmslim if you want to run GPTQ, AWQ, or W8A8 models.\n")
+try:
+    import lightop
+except Exception:
+    print("INFO: Please install lightop if you want to run AWQ with Marlin.\n")
 
 logger = logging.getLogger(__name__)
 
 use_vllm_custom_allreduce = get_bool_env_var(
@@ -175,3 +184,25 @@ def mscclpp_allreduce(
     context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
 ) -> None:
     return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
+
+def triton_scaled_mm(a: torch.Tensor,
+                     b: torch.Tensor,
+                     scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype,
+                     bias: Optional[torch.Tensor] = None,
+                     best_config: Optional[list] = None) -> torch.Tensor:
+    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)
+
+
+def triton_int8_gemm_helper(m: int,
+                            n: int,
+                            k: int,
+                            per_token_act_quant: bool,
+                            per_out_channel_weight_quant: bool,
+                            use_bias: bool,
+                            out_dtype: torch.dtype = torch.float16,
+                            device: str = "cuda:0",
+                            best_config: Optional[list] = None,
+                            repeat: Optional[int] = 2):
+    return quant_tools.triton_int8_gemm_helper(m, n, k, per_token_act_quant, per_out_channel_weight_quant, use_bias, out_dtype, device, best_config, repeat)
\ No newline at end of file
diff --git a/python/sglang/srt/layers/quantization/slimquant_w4a8.py b/python/sglang/srt/layers/quantization/slimquant_w4a8.py
index 485424014..c34ee6f02 100644
--- a/python/sglang/srt/layers/quantization/slimquant_w4a8.py
+++ b/python/sglang/srt/layers/quantization/slimquant_w4a8.py
@@ -16,6 +16,7 @@ from lmslim.layers.gemm.int8_utils import (
     per_token_quant_int8)
 from sglang.srt import _custom_ops as ops
 from vllm.utils import W8a8GetCacheJSON
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 
 import os
 
@@ -343,6 +344,12 @@
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
diff --git a/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
index 1452615a8..0d3303380 100644
--- a/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
+++ b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
@@ -1,4 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional
+from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
+from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
 import torch
 from sglang.srt import _custom_ops as ops
 from sglang.srt.utils import set_weight_attrs
@@ -9,6 +11,7 @@
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
 from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase)
 from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 try:
     from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
@@ -146,13 +149,13 @@
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
         from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
         tp_size = get_tensor_model_parallel_world_size()
-
+        intermediate_size = intermediate_size_per_partition
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -205,51 +208,28 @@
         layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
         layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        use_nn_moe: Optional[bool] = False,
-        routed_scaling_factor: Optional[float] = None,
-        use_fused_gate: Optional[bool] = False,
-        **_
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
-        # Expert selection
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-            use_fused_gate=use_fused_gate
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+        topk_weights, topk_ids, _ = topk_output
+        x, topk_weights = apply_topk_weights_cpu(
+            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
         )
         workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
-        return fused_experts_impl_w4a8_marlin(
+        output = fused_experts_impl_w4a8_marlin(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -260,13 +240,79 @@
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             workspace=workspace,
             global_reduce_buffer=global_reduce_buffer,
             inplace=True,
             use_int4_w4a8=True,
             per_channel_quant=True,
-            activation=activation,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            global_num_experts=global_num_experts,
+            activation=layer.moe_runner_config.activation,
+            expert_map=layer.expert_map_gpu,
+            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
+            global_num_experts=layer.moe_runner_config.num_experts,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            use_nn_moe=use_nn_moe,
+            use_nn_moe=False,
         )
+        return StandardCombineInput(hidden_states=output)
+    # def _apply(
+    #     self,
+    #     layer: torch.nn.Module,
+    #     x: torch.Tensor,
+    #     router_logits: torch.Tensor,
+    #     top_k: int,
+    #     #renormalize: bool,
+    #     #use_grouped_topk: bool = False,
+    #     topk_group: Optional[int] = None,
+    #     num_expert_group: Optional[int] = None,
+    #     global_num_experts: int = -1,
+    #     expert_map: Optional[torch.Tensor] = None,
+    #     custom_routing_function: Optional[Callable] = None,
+    #     scoring_func: str = "softmax",
+    #     e_score_correction_bias: Optional[torch.Tensor] = None,
+    #     apply_router_weight_on_input: bool = False,
+    #     activation: str = "silu",
+    #     enable_eplb: bool = False,
+    #     use_nn_moe: Optional[bool] = False,
+    #     routed_scaling_factor: Optional[float] = None,
+    #     use_fused_gate: Optional[bool] = False,
+    #     **_
+    # ) -> torch.Tensor:
+    #     from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
+    #     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+    #     if enable_eplb:
+    #         raise NotImplementedError(
+    #             "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
+    #     # Expert selection
+    #     topk_weights, topk_ids = FusedMoE.select_experts(
+    #         hidden_states=x,
+    #         router_logits=router_logits,
+    #         #use_grouped_topk=use_grouped_topk,
+    #         top_k=top_k,
+    #         #renormalize=renormalize,
+    #         topk_group=topk_group,
+    #         num_expert_group=num_expert_group,
+    #         custom_routing_function=custom_routing_function,
+    #         scoring_func=scoring_func,
+    #         e_score_correction_bias=e_score_correction_bias,
+    #         routed_scaling_factor=routed_scaling_factor,
+    #         use_fused_gate=use_fused_gate
+    #     )
+    #     workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
+    #     return fused_experts_impl_w4a8_marlin(
+    #         x,
+    #         layer.w13_weight,
+    #         layer.w2_weight,
+    #         topk_weights=topk_weights,
+    #         topk_ids=topk_ids,
+    #         workspace=workspace,
+    #         global_reduce_buffer=global_reduce_buffer,
+    #         inplace=True,
+    #         use_int4_w4a8=True,
+    #         per_channel_quant=True,
+    #         activation=activation,
+    #         expert_map=expert_map,
+    #         apply_router_weight_on_input=apply_router_weight_on_input,
+    #         global_num_experts=global_num_experts,
+    #         w1_scale=(layer.w13_weight_scale),
+    #         w2_scale=(layer.w2_weight_scale),
+    #         a1_scale=layer.w13_input_scale,
+    #         a2_scale=layer.w2_input_scale,
+    #         use_nn_moe=use_nn_moe,
+    #     )
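
Note on the new `_custom_ops` wrappers: `ops.triton_scaled_mm` simply forwards
pre-quantized int8 operands and their scales to lmslim. Below is a minimal
caller sketch, assuming `a` is an [M, K] int8 activation with per-row
(per-token) scales and `b` is a [K, N] int8 weight with per-output-channel
scales; the exact scale shapes and memory layouts lmslim expects are not shown
in this patch, so treat the shapes here as illustrative.

    import torch
    from sglang.srt import _custom_ops as ops

    def quantize_rows_int8(x: torch.Tensor):
        # Symmetric per-row int8 quantization: one fp32 scale per row.
        scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-8) / 127.0
        q = torch.round(x.float() / scale).clamp(-128, 127).to(torch.int8)
        return q, scale

    a_fp16 = torch.randn(16, 4096, dtype=torch.float16, device="cuda")     # activations [M, K]
    w_fp16 = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")  # weights [N, K]

    a_q, a_scale = quantize_rows_int8(a_fp16)  # int8 [M, K], fp32 scales [M, 1]
    w_q, w_scale = quantize_rows_int8(w_fp16)  # int8 [N, K], fp32 scales [N, 1]

    # b is passed as [K, N]; assumed scale shapes: scale_a [M, 1], scale_b [N, 1].
    out = ops.triton_scaled_mm(a_q, w_q.t(), a_scale, w_scale, torch.float16)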
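`ops.triton_int8_gemm_helper` likewise just forwards to lmslim's `quant_tools`;
judging from its parameters it exercises (and presumably tunes or benchmarks)
an int8 GEMM of a given shape, with `repeat` controlling the number of runs.
An illustrative call using only the wrapper defined above, all values
hypothetical:

    import torch
    from sglang.srt import _custom_ops as ops

    # Exercise a per-token-activation / per-output-channel-weight int8 GEMM
    # of shape M=16, N=11008, K=4096 on the first GPU.
    ops.triton_int8_gemm_helper(
        m=16, n=11008, k=4096,
        per_token_act_quant=True,
        per_out_channel_weight_quant=True,
        use_bias=False,
        out_dtype=torch.float16,
        device="cuda:0",
    )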
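Note on the MoE interface change: expert selection no longer happens inside
`apply`. The method now receives a `StandardDispatchOutput` carrying the hidden
states and the precomputed top-k routing, and returns a `StandardCombineInput`.
A self-contained sketch of that contract with stand-in dataclasses (the real
types live in `sglang.srt.layers.moe.token_dispatcher` and carry more fields;
the body of `apply` below is a placeholder, not the fused Marlin kernel):

    from dataclasses import dataclass
    from typing import Tuple
    import torch

    @dataclass
    class DispatchOutputSketch:          # stand-in for StandardDispatchOutput
        hidden_states: torch.Tensor     # [num_tokens, hidden_size]
        topk_output: Tuple[torch.Tensor, torch.Tensor, object]  # (weights, ids, extra)

    @dataclass
    class CombineInputSketch:            # stand-in for StandardCombineInput
        hidden_states: torch.Tensor

    class MoEMethodSketch:
        def create_moe_runner(self, layer, moe_runner_config):
            # The patch caches the config; the fused Marlin path later reads
            # apply_router_weight_on_input / activation / num_experts from it.
            self.moe_runner_config = moe_runner_config

        def apply(self, layer, dispatch_output: DispatchOutputSketch) -> CombineInputSketch:
            x = dispatch_output.hidden_states
            topk_weights, topk_ids, _ = dispatch_output.topk_output
            # A real method runs the fused expert kernel here; this sketch just
            # weights each token by the sum of its top-k routing weights.
            out = x * topk_weights.sum(dim=-1, keepdim=True).to(x.dtype)
            return CombineInputSketch(hidden_states=out)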