diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md
index ebac6f41a..7fd380f91 100644
--- a/benchmark/deepseek_v3/README.md
+++ b/benchmark/deepseek_v3/README.md
@@ -178,6 +178,8 @@ python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1
 ### Example: Serving with 8 A100/A800 with AWQ Quantization
 
+**Recommended Usage**
+
 Add `--quantization moe_wna16` flag to enable moe wna16 kernel for better performance.
 One example is as follows:
 
@@ -185,6 +187,13 @@ One example is as follows:
 python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
 ```
 
+Alternatively, you can use `--quantization awq_marlin` as follows:
+
+```bash
+python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16
+```
+
+Note that `awq_marlin` currently supports only `float16`, which may lead to some precision loss.
 
 ### Example: Serving with 16 A100/A800 with int8 Quantization
 
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index e0f436343..9995b72d0 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -7,10 +7,6 @@ import torch
 
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import (
-        AWQMarlinConfig,
-        AWQMoEMethod,
-    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
         CompressedTensorsW8A8Fp8MoEMethod,
@@ -36,14 +32,14 @@ except ImportError:
 
     def override_quantization_method(self, *args, **kwargs):
         return None
 
-    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
-        DeepSpeedFPConfig
-    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
-        MarlinConfig
-    ) = QQQConfig = Int8TpuConfig = DummyConfig
+    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
+        ExpertsInt8Config
+    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
+        Int8TpuConfig
+    ) = DummyConfig
 
-from sglang.srt.layers.quantization.awq import AWQConfig
+from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
@@ -63,10 +59,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import (
-    get_dynamic_override,
-    get_linear_quant_method,
-)
+from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -237,7 +230,6 @@ def monkey_patch_quant_configs():
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(AWQMoEMethod)
     monkey_patch_moe_apply(GPTQMarlinMoEMethod)
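+    # AWQMoEMethod now lives in sglang.srt.layers.quantization.awq and implements
+    # the sglang apply() signature natively, so it no longer needs this patch.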
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod) monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 6265f2217..453267383 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -2,21 +2,52 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional +import warnings +from typing import Any, Callable, Dict, List, Optional import torch +from sglang.srt.layers.linear import LinearBase, set_weight_attrs from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, LinearMethodBase, QuantizationConfig, + QuantizeMethodBase, ) +from sglang.srt.layers.quantization.marlin_utils import ( + apply_awq_marlin_linear, + awq_to_marlin_zero_points, + check_marlin_supported, + check_marlin_supports_layer, + check_moe_marlin_supports_layer, + marlin_make_empty_g_idx, + marlin_make_workspace, + marlin_moe_permute_scales, + marlin_permute_scales, + moe_awq_to_marlin_zero_points, + verify_marlin_supported, + verify_marlin_supports_shape, +) +from sglang.srt.layers.quantization.scalar_type import scalar_types from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.layers.quantization.utils import replace_parameter + +try: + from vllm import _custom_ops as ops + + warnings.warn( + f"Using kernels directly from vllm. This might lead to performance degradation or " + f"missing functionalities as certain kernels may not be optimized. " + ) +except ImportError: + ops = None + from sglang.srt.utils import is_cuda _is_cuda = is_cuda() if _is_cuda: - from sgl_kernel import awq_dequantize + from sgl_kernel import awq_dequantize, fused_marlin_moe logger = logging.getLogger(__name__) @@ -103,6 +134,176 @@ class AWQConfig(QuantizationConfig): return None +class AWQMarlinConfig(QuantizationConfig): + """Config class for AWQ Marlin""" + + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any], + ) -> None: + super().__init__() + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.zero_point = zero_point + self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits + self.modules_to_not_convert = modules_to_not_convert or [] + self.full_config = full_config + + if self.weight_bits not in self.TYPE_MAP: + raise ValueError( + f"Unsupported num_bits = {self.weight_bits}. 
" + f"Supported num_bits = {self.TYPE_MAP.keys()}" + ) + + self.quant_type = self.TYPE_MAP[self.weight_bits] + + verify_marlin_supported( + self.quant_type, group_size=self.group_size, has_zp=self.zero_point + ) + + def __repr__(self) -> str: + return ( + f"AWQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_name(cls) -> str: + return "awq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig: + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + return cls( + weight_bits, + group_size, + zero_point, + lm_head_quantized, + modules_to_not_convert, + config, + ) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin" + ) + + if can_convert and is_valid_user_quant: + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "awq": + logger.info( + "Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. Use quantization=awq_marlin for" + " faster inference" + ) + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + # Check if the layer is supported by AWQMarlin. + if not check_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501 + prefix, + ) + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) + return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config + + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMoeMarlin. " + "Falling back to Moe WNA16 kernels." + ) + return MoeWNA16Config.from_config(self.full_config).get_quant_method( + layer, prefix + ) + return AWQMoEMethod(self) + return None + + @classmethod + def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]): + # Extract data from quant config. 
+ quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + zero_point = quant_config.get("zero_point") + + if not _is_cuda: + return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. + if num_bits is None or group_size is None or zero_point is None: + return False + + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported( + quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point + ) + + class AWQLinearMethod(LinearMethodBase): """Linear method for AWQ. @@ -204,3 +405,382 @@ class AWQLinearMethod(LinearMethodBase): if bias is not None: out.add_(bias) return out.reshape(out_shape) + + +class AWQMarlinLinearMethod(LinearMethodBase): + """Linear method for AWQ Marlin. + + Args: + quant_config: The AWQ Marlin quantization config. + """ + + def __init__(self, quant_config: AWQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size, + ) + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + num_groups = input_size_per_partition // group_size + + qzeros = PackedvLLMParameter( + data=torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + scales = GroupQuantScaleParameter( + data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.num_groups = num_groups + + # TODO: Update this docs + # Checkpoints are serialized in AutoAWQ format, which is different from the + # marlin format. This function is called after the weights are loaded. 
+ # Here, we handle the repacking + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.qweight.device + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) + + # Allocate marlin workspace + layer.workspace = marlin_make_workspace(device) + + # Repack weights from AWQ format to marlin format. + marlin_qweight = ops.awq_marlin_repack( + layer.qweight, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "qweight", marlin_qweight) + + # Permute scales from AWQ format to marlin format. + marlin_scales = marlin_permute_scales( + layer.scales, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "scales", marlin_scales) + + # Permute zero-points from AWQ format to marlin format. + marlin_zp = awq_to_marlin_zero_points( + layer.qzeros, + size_k=layer.num_groups, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "qzeros", marlin_zp) + + # Not-used + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_awq_marlin_linear( + input=x, + weight=layer.qweight, + weight_scale=layer.scales, + weight_zp=layer.qzeros, + g_idx=layer.g_idx, + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=layer.workspace, + quant_type=self.quant_config.quant_type, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + bias=bias, + ) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + if self.quant_config.weight_bits != 4: + raise ValueError("AWQMoEMethod only supports 4bit now.") + self.quant_type = scalar_types.uint4 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + extra_weight_attrs.update( + { + "is_transposed": True, + "quant_method": FusedMoeWeightScaleSupported.GROUP.value, + } + ) + + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 
respectively. + w13_scales = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + intermediate_size_per_partition * 2, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w2, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + device = layer.w13_qweight.device + layer.workspace = marlin_make_workspace(device, 4) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + + # hidden_size->intermediate_size + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + replace_parameter(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: 
torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.topk import select_experts + + assert activation == "silu", "Only SiLU activation is supported." + assert ( + scoring_func == "softmax" + ), "Only softmax score func is supported for now." + + # The input must currently be float16 + orig_dtype = x.dtype + x = x.half() + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + num_bits=self.quant_config.weight_bits, + ).to(orig_dtype) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 51d70255d..89e0eb84a 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -11,7 +11,7 @@ import numpy import torch from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant -from sglang.srt.layers.quantization.scalar_type import ScalarType +from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu if TYPE_CHECKING: @@ -247,6 +247,36 @@ def get_pack_factor(num_bits): return 32 // num_bits +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + def pack_cols( q_w: torch.Tensor, num_bits: int, @@ -399,3 +429,56 @@ def quantize_weights( w_s if group_size is not None else None, maybe_w_zp, ) + + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert 
w.is_floating_point(), "w must be float"
+    assert (
+        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
+    ), f"Unsupported gptq type = {quant_type}"
+    assert group_size in SUPPORTED_GROUP_SIZES + [
+        size_k
+    ], f"Unsupported groupsize = {group_size}"
+
+    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
+
+    # Apply act_order
+    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
+    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
+    if act_order:
+        assert (
+            group_size < size_k
+        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
+            group_size, size_k
+        )
+
+        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
+
+    return w_ref, w_q, w_s, g_idx, rand_perm
+
+
+def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
+    orig_device = q_w.device
+
+    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx
+
+    g_idx = g_idx[sort_indices].contiguous()
+    q_w = q_w[sort_indices, :].contiguous()
+
+    return (
+        q_w.to(device=orig_device),
+        g_idx.to(device=orig_device),
+        sort_indices.to(device=orig_device),
+    )
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index bb1efde29..12aa9cb39 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -355,6 +355,7 @@ class DeepseekV2MoE(nn.Module):
             self.shared_experts.gate_up_proj.quant_method, "quant_config"
         ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
             "awq",
+            "awq_marlin",
             "moe_wna16",
         }
         self.shared_experts_is_int8 = (
@@ -929,7 +930,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             has_fused_proj
             and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
             and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
-            in {"awq", "moe_wna16"}
+            in {"awq", "awq_marlin", "moe_wna16"}
         )
         self.use_min_latency_fused_a_gemm = (
             has_fused_proj
@@ -2551,6 +2552,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             cat_dim = 0
             if self.quant_config is not None and (
                 self.quant_config.get_name() == "awq"
+                or self.quant_config.get_name() == "awq_marlin"
                 or self.quant_config.get_name() == "moe_wna16"
             ):
                 cat_dim = 1
diff --git a/python/sglang/test/test_marlin_moe.py b/python/sglang/test/test_marlin_moe.py
new file mode 100644
index 000000000..e5b4c986a
--- /dev/null
+++ b/python/sglang/test/test_marlin_moe.py
@@ -0,0 +1,286 @@
+from typing import Optional
+
+import pytest
+import torch
+from sgl_kernel import fused_marlin_moe
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
+
+
+def stack_and_dev(tensors: list[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
+def torch_experts(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    quant_dtype: Optional[torch.dtype] = None,
+    apply_router_weights_on_input: bool = False,
+) -> torch.Tensor:
+    assert (
+        global_num_experts == -1
+        or (global_num_experts == w1.shape[0] and expert_map is None)
+        or (expert_map is not None and global_num_experts == expert_map.shape[0])
+    )
+
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+
+    if apply_router_weights_on_input:
+        assert topk == 1
+        a = a * topk_weight.to(a.dtype)
+
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    num_experts = w1.shape[0]
+
+    topk_ids = topk_ids.view(-1)
+    if expert_map is not None:
+        topk_ids = expert_map[topk_ids]
+
+    f32 = torch.float32
+
+    for i in range(num_experts):
+        mask = topk_ids == i
+        if mask.sum():
+            if quant_dtype is None:
+                tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                tmp2 = SiluAndMul()(tmp1)
+                out[mask] = tmp2 @ w2[i].transpose(0, 1)
+            else:
+                # Only the unquantized reference path is needed by this test.
+                raise NotImplementedError(f"quant_dtype = {quant_dtype}")
+
+    if apply_router_weights_on_input:
+        return out
+    else:
+        return (
+            (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
+            .sum(dim=1)
+            .to(out.dtype)
+        )
+
+
+def torch_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    score: torch.Tensor,
+    topk: int,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    return torch_experts(
+        a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
+    )
+
+
+def marlin_moe_generate_valid_test_cases():
+    import itertools
+
+    m_list = [1, 123, 666]
+    n_list = [128, 1024]
+    k_list = [256, 2048]
+    e_list = [4, 12]
+    topk_list = [2, 3]
+    dtype_list = [torch.half, torch.bfloat16]
+    group_size_list = [128]
+    act_order_list = [True, False]
+    quant_type_list = [
+        scalar_types.uint4,
+        scalar_types.uint4b8,
+    ]
+    is_k_full_list = [True, False]
+
+    all_combinations = itertools.product(
+        m_list,
+        n_list,
+        k_list,
+        e_list,
+        topk_list,
+        dtype_list,
+        group_size_list,
+        act_order_list,
+        quant_type_list,
+        is_k_full_list,
+    )
+
+    def is_valid(
+        m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
+    ):
+        # act_order needs a real group size and a type without zero points
+        if act_order:
+            if group_size in (-1, k, n):
+                return False
+            if quant_type not in [scalar_types.uint4b8]:
+                return False
+        elif not is_k_full:
+            return False
+
+        return True
+
+    cases = []
+    for case in all_combinations:
+        if is_valid(*case):
+            cases.append(case)
+    return cases
+
+
+@pytest.mark.flaky(reruns=2)
+@pytest.mark.parametrize(
+    ("m, n, k, e, topk, dtype, group_size, " "act_order, quant_type, is_k_full"),
+    marlin_moe_generate_valid_test_cases(),
+)
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    group_size: int,
+    act_order: bool,
+    quant_type: ScalarType,
+    is_k_full: bool,
+):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device not available")
+
+    torch.manual_seed(0)
+
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+        if has_zp:
+            return
+    else:
+        if not is_k_full:
+            return
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
+
+    e_map = None
+
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    zeros1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        if has_zp:
+            w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            zeros1_l.append(zeros1)
+        else:
+            test_perm = torch.randperm(k)
+            w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size,
act_order, test_perm
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            g_idx1_l.append(g_idx1)
+            sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
+    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
+    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
+
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    zeros2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        if has_zp:
+            w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            zeros2_l.append(zeros2)
+        else:
+            test_perm = torch.randperm(n)
+            w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            g_idx2_l.append(g_idx2)
+            sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
+    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
+    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    from sglang.srt.layers.moe.topk import fused_topk_torch_native
+
+    topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)
+
+    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
+
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        scales1,
+        scales2,
+        score,
+        topk_weights,
+        topk_ids,
+        g_idx1=g_idx1,
+        g_idx2=g_idx2,
+        sort_indices1=sort_indices1,
+        sort_indices2=sort_indices2,
+        w1_zeros=zeros1,
+        w2_zeros=zeros2,
+        num_bits=4,
+        is_k_full=is_k_full,
+    )
+
+    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
+
+
+if __name__ == "__main__":
+    # Run this test file directly with pytest
+    pytest.main([__file__])
diff --git a/python/sglang/test/test_marlin_utils.py b/python/sglang/test/test_marlin_utils.py
new file mode 100644
index 000000000..920cb7d8b
--- /dev/null
+++ b/python/sglang/test/test_marlin_utils.py
@@ -0,0 +1,171 @@
+"""
+Adapted from
+https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
+"""
+
+# SPDX-License-Identifier: Apache-2.0
+"""Utility functions used for tests and benchmarks"""
+
+from typing import Optional
+
+import numpy as np
+import torch
+
+from sglang.srt.layers.quantization.marlin_utils import (
+    GPTQ_MARLIN_TILE,
+    marlin_permute_scales,
+    marlin_zero_points,
+)
+from sglang.srt.layers.quantization.scalar_type import ScalarType
+from sglang.srt.layers.quantization.utils import (
+    get_pack_factor,
+    gptq_quantize_weights,
+    quantize_weights,
+    sort_weights,
+)
+
+
+class MarlinWorkspace:
+
+    def __init__(self, out_features, min_thread_n, max_parallel):
+        assert (
+            out_features % min_thread_n == 0
+        ), "out_features = {} is not divisible by min_thread_n = {}".format(
+            out_features, min_thread_n
+        )
+
+        max_workspace_size = (out_features // min_thread_n) * max_parallel
+
+        self.scratch = torch.zeros(max_workspace_size,
dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: list[int] = [] + for i in range(32): + perm1: list[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, 
group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index feda86934..9be711d12 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -24,7 +24,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state - from sglang.srt.layers.quantization import get_dynamic_override + from sglang.srt.layers.quantization.utils import get_dynamic_override from sglang.srt.model_loader import get_model from sglang.srt.server_args import PortArgs, ServerArgs
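
For readers who want to exercise the new kernel path in isolation, the sketch below mirrors the zero-point (AWQ) branch of `test_fused_marlin_moe` above: it quantizes a toy MoE layer with `awq_marlin_quantize` and feeds it to `fused_marlin_moe`. It is a minimal sketch, not part of this patch; it assumes a CUDA device with the `sgl_kernel` wheel installed, that `fused_marlin_moe`'s `g_idx`/`sort_indices` arguments default to `None` (as in the test's AWQ branch), and uses arbitrary sizes chosen to satisfy Marlin's tiling constraints.

```python
import torch
from sgl_kernel import fused_marlin_moe

from sglang.srt.layers.moe.topk import fused_topk_torch_native
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.test.test_marlin_utils import awq_marlin_quantize


def stack(tensors):
    return torch.stack(tensors, dim=0)


# Hypothetical problem sizes: 4 experts, top-2 routing, group size 128.
m, n, k, e, topk, group_size = 16, 128, 256, 4, 2, 128
dtype = torch.half

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

# Quantize each expert to AWQ uint4 and repack into the Marlin layout.
qw1, s1, zp1, qw2, s2, zp2 = [], [], [], [], [], []
for i in range(e):
    _, q, s, zp = awq_marlin_quantize(
        w1[i].transpose(1, 0), scalar_types.uint4, group_size
    )
    qw1.append(q)
    s1.append(s)
    zp1.append(zp)
    _, q, s, zp = awq_marlin_quantize(
        w2[i].transpose(1, 0), scalar_types.uint4, group_size
    )
    qw2.append(q)
    s2.append(s)
    zp2.append(zp)

# Route tokens, then run the fused Marlin MoE kernel.
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)

out = fused_marlin_moe(
    a,
    stack(qw1).contiguous(),
    stack(qw2).contiguous(),
    stack(s1),
    stack(s2),
    score,
    topk_weights,
    topk_ids,
    w1_zeros=stack(zp1),
    w2_zeros=stack(zp2),
    num_bits=4,
)
print(out.shape)  # torch.Size([16, 256])
```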