From 67510e017237fc53769b627cac4d82e0bc9daf5e Mon Sep 17 00:00:00 2001 From: lizhigong <306128847@qq.com> Date: Tue, 21 Oct 2025 14:06:43 +0800 Subject: [PATCH] adaptation part w4A8 quantization (cherry picked from commit 68277eac30f16dbd332455527f6b9a874c22b66d) --- python/sglang/srt/configs/model_config.py | 2 + .../srt/layers/quantization/__init__.py | 2 + .../srt/layers/quantization/slimquant_w4a8.py | 408 ++++++++++++++++++ .../quantization/slimquant_w4a8_marlin.py | 272 ++++++++++++ .../srt/layers/quantization/w4a8_utils.py | 92 ++++ python/sglang/srt/server_args.py | 1 + 6 files changed, 777 insertions(+) create mode 100644 python/sglang/srt/layers/quantization/slimquant_w4a8.py create mode 100644 python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py create mode 100644 python/sglang/srt/layers/quantization/w4a8_utils.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 71b420d50..3985d0350 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -614,6 +614,7 @@ class ModelConfig: "petit_nvfp4", "quark", "mxfp4", + "slimquant_w4a8_marlin", ] optimized_quantization_methods = [ "fp8", @@ -633,6 +634,7 @@ class ModelConfig: "qoq", "w4afp8", "petit_nvfp4", + "slimquant_w4a8_marlin", ] compatible_quantization_methods = { "modelopt_fp4": ["modelopt"], diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index df0658f86..2e6a06a04 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -57,6 +57,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config +from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig from sglang.srt.utils import is_cuda, is_hip, mxfp_supported _is_mxfp_supported = mxfp_supported() @@ -83,6 +84,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "w4afp8": W4AFp8Config, "petit_nvfp4": PetitNvFp4Config, "fbgemm_fp8": FBGEMMFp8Config, + "slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig, } diff --git a/python/sglang/srt/layers/quantization/slimquant_w4a8.py b/python/sglang/srt/layers/quantization/slimquant_w4a8.py new file mode 100644 index 000000000..485424014 --- /dev/null +++ b/python/sglang/srt/layers/quantization/slimquant_w4a8.py @@ -0,0 +1,408 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch +from sglang.srt.layers.linear import set_weight_attrs +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from torch.nn.parameter import Parameter +from sglang.srt.layers.linear import LinearBase +from sglang.srt.layers.quantization.base_config import LinearMethodBase, QuantizationConfig, QuantizeMethodBase, FusedMoEMethodBase +from sglang.srt.layers.parameter import ( + ChannelQuantScaleParameter, + _ColumnvLLMParameter, + RowvLLMParameter, +) +from lmslim.layers.gemm.int8_utils import ( + per_token_group_quant_int8, + per_token_quant_int8) +from sglang.srt import _custom_ops as ops +from vllm.utils import W8a8GetCacheJSON + +import os + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + pass + +W8A8_TRITONJSON=W8a8GetCacheJSON() + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + scales= scale_a* scale_b.T + gemmout= torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32)) + output = (scales *gemmout).to(out_dtype) + if bias is not None: + output = output + bias + return output.to(out_dtype) + + +class SlimQuantW4A8Int8Config(QuantizationConfig): + """Config class for W8A8 Int8 Quantization. + + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "slimquant_w4a8" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if isinstance(layer, LinearBase): + return SlimQuantW4A8Int8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return SlimQuantW4A8Int8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: SlimQuantW4A8Int8Config): + self.quantization_config = quantization_config + self.tritonsingleton= W8a8GetCacheJSON() + self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + n=layer.weight.shape[0] + k=layer.weight.shape[1] + + if self.w8a8_strategy==1: + if {n,k} not in self.tritonsingleton.weight_shapes: + self.tritonsingleton.weight_shapes.append({n,k}) + json_file=self.tritonsingleton.get_w8a8json_name(n,k) + configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k) + + if configs_dict: + self.tritonsingleton.triton_json_dict.update(configs_dict) + + for key, value in configs_dict.items(): + m=int(key.split('_')[0]) + ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value) + else: + weight_data=layer.weight.data + _weight=weight_data.T.contiguous().reshape(n,-1) + layer.weight.data=_weight + + layer.weight = Parameter(layer.weight.t(), requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + self.logical_widths = output_partition_sizes + + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + input_quant_args: Optional[list[torch.Tensor]] = None, + silu_quant_args: Optional[list[torch.Tensor]] = None + ): + # if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: + # assert len(input_quant_args) == 2 + # x_q, x_scale = input_quant_args + # elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None: + # x_q, x_scale = silu_quant_args + # else: + x_q, x_scale = per_token_quant_int8(x) + + if self.w8a8_strategy==1: + m=x_q.shape[0] + k=x_q.shape[1] + n=layer.weight.shape[1] + + if len(W8A8_TRITONJSON.triton_json_dict)==0: + best_config=None + + elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict: + if m<=16: + m_=m + elif m<=64: + m_= (m + 3) & -4 #取值到最近的4的倍数 + elif m<=160: + m_=(m + 7) & -8 + + elif m<200: #256 + m_=160 + elif m<480: #512 + m_=256 + elif m<960: #1024 + m_=512 + elif m<2048: + m_=1024 + elif m<4096: + m_=2048 + elif m<6000: + m_=4096 + else: + m_=8192 + + best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"] + + else: + best_config=None + + #if best_config==None: + # print("m:{},n:{},k:{}".format(m,n,k)) + # print("config not found!") + + return ops.triton_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias,best_config=best_config) + elif self.w8a8_strategy==2: + return ops.cutlass_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias) + else: + return ops.rocblas_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias) + + +class SlimQuantW4A8Int8MoEMethod: + """MoE method for W4A8INT8. + Supports loading INT8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + self.tritonsingleton= W8a8GetCacheJSON() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + tp_size = get_tensor_model_parallel_world_size() + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E=layer.w13_weight.shape[0] + N1=layer.w13_weight.shape[1] + N2=layer.w2_weight.shape[1] + K=N1//2 + if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: + self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) + + TOPK= self.tritonsingleton.topk + + json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True) + configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) + + #warmup + if configs_dict: + self.tritonsingleton.triton_moejson_dict.update(configs_dict) + + layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + **_ + ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.") + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate + ) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w4a8=True, + per_channel_quant=True, + activation=activation, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + w1_scale=(layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_nn_moe=use_nn_moe, + ) diff --git a/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py new file mode 100644 index 000000000..1452615a8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py @@ -0,0 +1,272 @@ +from typing import Any, Callable, Dict, List, Optional +import torch +from sglang.srt import _custom_ops as ops +from sglang.srt.utils import set_weight_attrs +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from torch.nn.parameter import Parameter +from sglang.srt.layers.linear import LinearBase +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl +from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase) +from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod + +try: + from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin +except Exception: + print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") + + +class MarlinMoeWorkspace: + """ + Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE. + global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device + """ + _instances = {} + def __new__(cls, device): + if device not in cls._instances: + instance = super().__new__(cls) + instance._initialized = False + cls._instances[device] = instance + return cls._instances[device] + + def __init__(self, device): + if self._initialized: + return + sms = torch.cuda.get_device_properties(device).multi_processor_count + self.workspace = torch.zeros( + 500, dtype=torch.int, device=device, requires_grad=False + ) + self.global_reduce_buffer = torch.zeros( + sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False + ) + self._initialized = True + + def get_buffers(self): + return self.workspace, self.global_reduce_buffer + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + scales= scale_a* scale_b.T + gemmout= torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32)) + output = (scales *gemmout).to(out_dtype) + if bias is not None: + output = output + bias + return output.to(out_dtype) + + +class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): + """Config class for W4A8 Int8 Quantization. + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "slimquant_w4a8_marlin" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig": + return cls() + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[str]: + if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \ + and user_quant == "slimquant_w4a8_marlin": + return cls.get_name() + return None + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if isinstance(layer, LinearBase): + return SlimQuantW4A8Int8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return SlimQuantW4A8Int8MarlinMoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SlimQuantW4A8Int8MarlinMoEMethod: + """MoE method for W4A8INT8 Marlin. + Supports loading INT8 checkpoints with static weight scale and + dynamic/static activation scale. + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + tp_size = get_tensor_model_parallel_world_size() + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + + layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False) + layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + **_ + ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.") + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate + ) + workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() + return fused_experts_impl_w4a8_marlin( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + workspace=workspace, + global_reduce_buffer=global_reduce_buffer, + inplace=True, + use_int4_w4a8=True, + per_channel_quant=True, + activation=activation, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + w1_scale=(layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_nn_moe=use_nn_moe, + ) diff --git a/python/sglang/srt/layers/quantization/w4a8_utils.py b/python/sglang/srt/layers/quantization/w4a8_utils.py new file mode 100644 index 000000000..a05652e38 --- /dev/null +++ b/python/sglang/srt/layers/quantization/w4a8_utils.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + +try: + from lightop import awq_marlin_repack_w4a8 + use_lightop = False +except Exception: + use_lightop = False + +def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor: + """ + 将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。 + 每个int8包含两个int4,分别提取到int32的低4位,其余位为0。 + + Args: + tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。 + + Returns: + torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。 + """ + if tensor_int8.dtype != torch.int8: + raise ValueError("Input tensor must be of type torch.int8") + + N, K_half = tensor_int8.shape + tensor_uint8 = tensor_int8.to(torch.uint8) + high4 = tensor_uint8 & 0x0F + low4 = (tensor_uint8 >> 4) & 0x0F + unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device) + unpacked[:, 0::2] = low4.to(torch.int32) + unpacked[:, 1::2] = high4.to(torch.int32) + + return unpacked + +def get_weight_perms(interleave: bool=True): + perm = [] + for i in range(64): + + for col in range(4): + cur_col = (i % 16) * 4 + col + for row in range(8): + cur_row = (i // 16) * 8 + row + cur_idx = cur_row * 64 + cur_col + perm.append(cur_idx) + + perm = np.array(perm) + if interleave: + interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() + + perm = torch.from_numpy(perm) + + return perm + +def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8): + size_k, size_n = q_w.shape + q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // k_tile, size_n * k_tile)) + q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape) + + orig_device = q_w.device + q_w = q_w.contiguous().to(torch.int32) + M, N = q_w.shape + assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})" + q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device) + for i in range(pack_factor): + q_packed += q_w[:, i::pack_factor] << (4 * i) + + return q_packed + +def w4a8_2_marlin_weight(w4a8_w): + full_w4a8_w = unpack_int8_to_int4(w4a8_w) + full_w4a8_w = full_w4a8_w.T + weight_perm = get_weight_perms() + marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8) + return marlin_q_w + +def w4a8_weight_repack_impl(input): + if use_lightop: + size_batch = input.shape[0] + size_n = input.shape[1] + size_k = input.shape[2] * 2 + output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32) + awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n) + else: + w_marlin_list = [] + for e in range(input.shape[0]): + w_marlin_in = w4a8_2_marlin_weight(input[e]) + w_marlin_list.append(w_marlin_in) + output = torch.stack(w_marlin_list, dim=0) + + return output diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7935de6f6..7d26160a6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -93,6 +93,7 @@ QUANTIZATION_CHOICES = [ "w4afp8", "mxfp4", "compressed-tensors", # for Ktransformers + "slimquant_w4a8_marlin", ] ATTENTION_BACKEND_CHOICES = [