From 1b2e8f76d9ed08733d8dde22cb64f72410cb2262 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Sat, 24 May 2025 03:39:18 +0800 Subject: [PATCH] [2/2] Support Qserve (#6521) --- python/sglang/srt/configs/model_config.py | 3 + .../srt/layers/quantization/__init__.py | 2 + .../srt/layers/quantization/int8_kernel.py | 23 +- python/sglang/srt/layers/quantization/qoq.py | 244 ++++++++++++++++++ python/sglang/srt/server_args.py | 1 + 5 files changed, 268 insertions(+), 5 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/qoq.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 982cae8dd..bc0fe0cb1 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -349,6 +349,7 @@ class ModelConfig: "w8a8_int8", "w8a8_fp8", "moe_wna16", + "qoq", ] compatible_quantization_methods = { "modelopt_fp4": ["modelopt"], @@ -458,6 +459,8 @@ def _get_and_verify_dtype( # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. config_dtype = getattr(config, "torch_dtype", None) + if isinstance(config_dtype, str): + config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None) if config_dtype is None: config_dtype = torch.float32 diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index cf8f405dd..e1c053bfa 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -67,6 +67,7 @@ from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptFp8Config, ) from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config +from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -80,6 +81,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "w8a8_fp8": W8A8Fp8Config, "moe_wna16": MoeWNA16Config, "compressed-tensors": CompressedTensorsConfig, + "qoq": QoQConfig, } # VLLM-dependent quantization methods diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py index 32367a5bf..7c6c3dbd4 100644 --- a/python/sglang/srt/layers/quantization/int8_kernel.py +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -22,9 +22,11 @@ def _per_token_quant_int8( x_ptr, xq_ptr, scale_ptr, + x_sum_ptr, stride_x, stride_xq, N, + CAL_SUM: tl.constexpr, BLOCK: tl.constexpr, ): # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 @@ -38,16 +40,23 @@ def _per_token_quant_int8( scale_x = absmax / 127 x_q = x * (127 / absmax) x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8) + if CAL_SUM: + x_sum = tl.sum(x, axis=0) + tl.store(x_sum_ptr + row_id, x_sum.to(x_sum_ptr.dtype.element_ty)) tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) - tl.store(scale_ptr + row_id, scale_x) + tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty)) -def per_token_quant_int8(x): +def per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=False): M = x.numel() // x.shape[-1] N = x.shape[-1] x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) - scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32) + scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype) + if cal_sum: + x_sum = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype) + else: + x_sum = None BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) @@ -57,15 +66,19 @@ def per_token_quant_int8(x): x, x_q, scales, + x_sum, stride_x=x.stride(-2), stride_xq=x_q.stride(-2), N=N, + CAL_SUM=cal_sum, BLOCK=BLOCK, num_warps=num_warps, num_stages=1, ) - - return x_q, scales + if cal_sum: + return x_q, scales, x_sum + else: + return x_q, scales @triton.jit diff --git a/python/sglang/srt/layers/quantization/qoq.py b/python/sglang/srt/layers/quantization/qoq.py new file mode 100644 index 000000000..3e3a3dfb6 --- /dev/null +++ b/python/sglang/srt/layers/quantization/qoq.py @@ -0,0 +1,244 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.linear import LinearMethodBase +from sglang.srt.layers.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() +if _is_cuda: + from sgl_kernel import qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm + + +QoQ_SUPPORTED_WEIGHT_BITS = [4] +QoQ_SUPPORTED_GROUP_SIZES = [-1, 128] + + +class QoQConfig(QuantizationConfig): + """Config class for QoQ Quantization. + + - Weight: static, per-channel/group, asymmetric + - Activation: dynamic, per-token, symmetric + + Reference: https://arxiv.org/abs/2405.04532 + https://github.com/mit-han-lab/omniserve + """ + + def __init__(self, weight_bits: int, group_size: int) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + + # Verify + if self.weight_bits not in QoQ_SUPPORTED_WEIGHT_BITS: + raise ValueError( + f"QoQ does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {QoQ_SUPPORTED_WEIGHT_BITS} " + "are supported." + ) + if self.group_size not in QoQ_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"QoQ does not support group_size = {self.group_size}. " + f"Only group_sizes = {QoQ_SUPPORTED_GROUP_SIZES} " + "are supported." + ) + + # 4 bits packed into 8 bit datatype. + self.pack_factor = 8 // self.weight_bits + + def __repr__(self) -> str: + return "QoQConfig(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size + ) + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_name(self) -> str: + return "qoq" + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QoQConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + return QoQLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class QoQLinearMethod(LinearMethodBase): + """Linear method for QoQ. + + Args: + quant_config: The QoQ quantization config. + """ + + def __init__(self, quant_config: QoQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % 32 != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by 32." + ) + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}." + ) + if ( + self.quant_config.group_size != -1 + and input_size_per_partition % self.quant_config.group_size != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}." + ) + + qweight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("qweight", qweight) + + s1_scales = ChannelQuantScaleParameter( + data=torch.empty(output_size_per_partition, dtype=torch.float16), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("s1_scales", s1_scales) + + if self.quant_config.group_size == -1: + s1_szeros = ChannelQuantScaleParameter( + data=torch.empty(output_size_per_partition, dtype=torch.float16), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("s1_szeros", s1_szeros) + else: + s2_scales = GroupQuantScaleParameter( + data=torch.empty( + ( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + ), + dtype=torch.int8, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("s2_scales", s2_scales) + + s2_zeros = GroupQuantScaleParameter( + data=torch.empty( + ( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + ), + dtype=torch.int8, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("s2_zeros", s2_zeros) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.s1_scales = Parameter(layer.s1_scales.data, requires_grad=False) + if self.quant_config.group_size == -1: + layer.s1_szeros = Parameter(layer.s1_szeros.data, requires_grad=False) + else: + layer.s2_scales = Parameter(layer.s2_scales.data, requires_grad=False) + layer.s2_zeros = Parameter(layer.s2_zeros.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + assert x.dtype == torch.float16, "QoQ only supports float16 input now" + if self.quant_config.group_size == -1: + x_q, x_scale, x_sum = per_token_quant_int8( + x, scale_dtype=x.dtype, cal_sum=True + ) + out = qserve_w4a8_per_chn_gemm( + x_q, layer.qweight, layer.s1_scales, x_scale, layer.s1_szeros, x_sum + ) + else: + x_q, x_scale = per_token_quant_int8(x, scale_dtype=x.dtype) + out = qserve_w4a8_per_group_gemm( + x_q, + layer.qweight, + layer.s2_zeros, + layer.s2_scales, + layer.s1_scales, + x_scale, + ) + if bias is not None: + out = out + bias + return out diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9402128df..40b41b036 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -577,6 +577,7 @@ class ServerArgs: "w8a8_int8", "w8a8_fp8", "moe_wna16", + "qoq", ], help="The quantization method.", )