From c28ad1990d29f3993c1eebff06673e819ac4b032 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Thu, 17 Jul 2025 06:56:26 +0800 Subject: [PATCH] [1/n] chore: decouple quantization implementation from vLLM dependency (#7992) --- .../layers/moe/fused_moe_triton/__init__.py | 5 +- .../srt/layers/quantization/__init__.py | 6 +- python/sglang/srt/layers/quantization/gptq.py | 650 +++++++++++---- .../srt/layers/quantization/marlin_utils.py | 781 ++++++++++++++++++ .../srt/layers/quantization/moe_wna16.py | 30 + .../srt/layers/quantization/quant_utils.py | 166 ---- .../srt/layers/quantization}/scalar_type.py | 0 .../sglang/srt/layers/quantization/utils.py | 163 +++- sgl-kernel/python/sgl_kernel/fused_moe.py | 3 +- sgl-kernel/tests/test_marlin_repack.py | 6 +- test/srt/test_gptqmodel_dynamic.py | 9 +- test/srt/test_int4_kernel.py | 301 ------- test/srt/test_w4a8.py | 14 - 13 files changed, 1498 insertions(+), 636 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/marlin_utils.py delete mode 100644 python/sglang/srt/layers/quantization/quant_utils.py rename {sgl-kernel/python/sgl_kernel => python/sglang/srt/layers/quantization}/scalar_type.py (100%) delete mode 100644 test/srt/test_int4_kernel.py delete mode 100644 test/srt/test_w4a8.py diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py index b68961931..839b659fe 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py @@ -1,10 +1,11 @@ from contextlib import contextmanager from typing import Any, Dict, Optional -import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_experts, get_config_file_name, + moe_align_block_size, + try_get_optimal_moe_config, ) from sglang.srt.layers.moe.fused_moe_triton.layer import ( FusedMoE, @@ -37,4 +38,6 @@ __all__ = [ "fused_moe", "fused_experts", "get_config_file_name", + "moe_align_block_size", + "try_get_optimal_moe_config", ] diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 4ee498169..7507a5b62 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -22,10 +22,6 @@ try: from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.gguf import GGUFConfig - from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod - from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinLinearMethod, - ) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config, ) @@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.gptq import ( GPTQConfig, + GPTQLinearMethod, GPTQMarlinConfig, + GPTQMarlinLinearMethod, GPTQMarlinMoEMethod, ) from sglang.srt.layers.quantization.modelopt_quant import ( diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index 9e2b3e063..3658d0b85 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -1,47 +1,55 @@ import logging +from dataclasses import dataclass from 
fractions import Fraction from typing import Any, Callable, Dict, List, Optional, Union import torch -from sglang.srt.layers.linear import LinearBase, set_weight_attrs +from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs +from sglang.srt.layers.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, + permute_param_layout_, +) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.utils import replace_parameter +from sglang.srt.layers.quantization.marlin_utils import ( + apply_gptq_marlin_linear, + check_marlin_supported, + check_marlin_supports_shape, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace, + marlin_moe_permute_scales, + marlin_permute_scales, + marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, + marlin_zero_points, + verify_marlin_supported, +) +from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types +from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols + +try: + from vllm import _custom_ops as ops +except ImportError: + ops = None + from sglang.srt.utils import is_cuda _is_cuda = is_cuda() -try: - from vllm import _custom_ops as ops - from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod - from vllm.model_executor.layers.quantization.gptq_marlin import ( - FusedMoE, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, - GPTQMarlinLinearMethod, - marlin_moe_permute_scales, - ) - from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, - ) - from vllm.scalar_type import scalar_types +if _is_cuda: + from sgl_kernel import fused_marlin_moe - VLLM_AVAILABLE = True -except ImportError: - VLLM_AVAILABLE = False - - GPTQLinearMethod = MarlinLinearMethod = Any - - FusedMoEMethodBase = QuantizeMethodBase - - class scalar_types: - uint4b8 = "uint4b8" - uint8b128 = "uint8b128" +FusedMoEMethodBase = QuantizeMethodBase logger = logging.getLogger(__name__) @@ -54,6 +62,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool: ) +def gptq_marlin_moe_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty( + (num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype, + ) + for e in range(num_experts): + output[e] = torch.ops.sgl_kernel.gptq_marlin_repack( + b_q_weight[e], perm[e], size_k, size_n, num_bits + ) + return output + + +@dataclass +class MarlinLinearLayerConfig: + full_weight_shape: tuple[int, int] # [in, out] + partition_weight_shape: tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + class GPTQConfig(QuantizationConfig): """Config class for GPTQ. 
@@ -151,11 +191,16 @@ class GPTQConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional[GPTQLinearMethod]: + ) -> Optional["LinearMethodBase"]: # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization import get_linear_quant_method - return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + if isinstance(layer, LinearBase): + return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + elif isinstance(layer, FusedMoE): + raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin") + return None class GPTQMarlinConfig(QuantizationConfig): @@ -313,14 +358,6 @@ class GPTQMarlinConfig(QuantizationConfig): if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) - # TODO: re-enable after SGLang syncs with vllm >= 0.7.3 - # if layer.num_experts > 32: - # # For MoEs with many experts the moe_wna16 kernel is faster - # return MoeWNA16Config.from_config(self.full_config).get_quant_method( - # layer, prefix - # ) - # else: - # return GPTQMarlinMoEMethod(self) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @classmethod @@ -344,112 +381,439 @@ class GPTQMarlinConfig(QuantizationConfig): if (num_bits, sym) not in cls.TYPE_MAP: return False - assert ( - VLLM_AVAILABLE - ), "vllm is not installed, to use gptq_marlin, please install vllm" - return check_marlin_supported( quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size ) -class MarlinConfig(QuantizationConfig): - """Config class for Marlin. +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. - Reference: https://github.com/IST-DASLab/marlin/tree/master + Args: + quant_config: The GPTQ quantization config. """ - def __init__( + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( self, - group_size: int, - lm_head_quantized: bool, - ) -> None: - # Group size for the quantization. - self.group_size = group_size - self.lm_head_quantized = lm_head_quantized - if self.group_size != 128 and self.group_size != -1: + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin, but got group_size of " - f"{self.group_size}" + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size." + ) + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size." ) - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // 4 + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size - # Tile size used by marlin kernels. 
- self.tile_size = 16 - - # Min out_features dim - self.min_n_threads = 64 - - # Min in_features dim - self.min_k_threads = 128 - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = 16 - - # Permutation length used by the marlin kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return ( - f"MarlinConfig(group_size={self.group_size}, " - f"lm_head_quantized={self.lm_head_quantized})" - ) - - @classmethod - def get_name(cls) -> str: - return "marlin" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] - - @classmethod - # Need to figure it out - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": - group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) - return cls(group_size, lm_head_quantized) - - @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - is_marlin_format = check_marlin_format(hf_quant_cfg) - - is_valid_user_quant = ( - user_quant is None or user_quant == "gptq" or user_quant == "marlin" - ) - - if is_marlin_format and is_valid_user_quant: - msg = "The model is serialized in {} format. Using {} kernel.".format( - cls.get_name(), cls.get_name() - ) - logger.info(msg) - return cls.get_name() - - return None - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional[MarlinLinearMethod]: - # Delay the import to avoid circular dependency - from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead - - if isinstance(layer, LinearBase) or ( - isinstance(layer, ParallelLMHead) and self.lm_head_quantized + self.use_shuffle = True + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if ( + input_size != input_size_per_partition + and self.quant_config.group_size != -1 ): - return MarlinLinearMethod(self) - return None + if self.quant_config.desc_act: + self.use_shuffle = False + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + g_idx = RowvLLMParameter( + data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) + qzeros_args = { + "data": torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": weight_loader, + } + weight_scale_args = { + "data": torch.empty( + scale_and_zero_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": weight_loader, + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) + + else: + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, 
**weight_scale_args + ) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) + + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if self.use_shuffle: + if self.quant_config.desc_act: + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) + else: + layer.g_idx.data = torch.empty( + (0,), dtype=torch.int, device=layer.g_idx.device + ) + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) + reshaped_x = x.reshape(-1, x.shape[-1]) + + output = ops.gptq_gemm( + reshaped_x, + layer.qweight, + layer.qzeros, + layer.scales, + layer.g_idx, + self.use_shuffle, + self.quant_config.weight_bits, + ) + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) + + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. + + Args: + quant_config: The GPTQ Marlin quantization config. + """ + + _kernel_backends_being_used: set[str] = set() + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + # Verify supported on platform. + verify_marlin_supported( + quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size, + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + self.kernel_config = MarlinLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act, + ) + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if marlin_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. 
+ scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + # Activation order + g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) + + qzeros_args = { + "data": torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": weight_loader, + } + weight_scale_args = { + "data": torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": weight_loader, + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) + + else: + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, "qweight").device + c = self.kernel_config + + check_marlin_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size, + ) + + row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. 
+ self.workspace = marlin_make_workspace(device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + self.w_q_name = "qweight" + self.w_s_name = "scales" + self.w_zp_name = "qzeros" + self.w_gidx_name = "g_idx" + + def _transform_param( + layer: torch.nn.Module, name: Optional[str], fn: Callable + ) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, torch.nn.Parameter(new_param.data, requires_grad=False) + ) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = torch.ops.sgl_kernel.gptq_marlin_repack( + x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits, + ) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales( + x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size, + ) + return x + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name) + ) + _transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + grouped_k = ( + c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1 + ) + _transform_param( + layer, + self.w_zp_name, + lambda x: marlin_zero_points( + unpack_cols( + x.t(), + c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1], + ), + size_k=grouped_k, + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits, + ), + ) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + _transform_param(layer, self.w_q_name, transform_w_q) + _transform_param(layer, self.w_s_name, transform_w_s) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + c = self.kernel_config + + def _get_weight_params( + layer: torch.nn.Module, + ) -> tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor], # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) + + w_q, w_s, w_zp, w_gidx = _get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias, + ) class GPTQMarlinMoEMethod(FusedMoEMethodBase): @@ -467,6 +831,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): + # 
Delay the import to avoid circular dependency + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + intermediate_size = extra_weight_attrs.pop("intermediate_size") self.is_k_full = (not self.quant_config.desc_act) or ( @@ -644,20 +1011,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): requires_grad=False, ) # Repack weights - marlin_w13_qweight = ops.gptq_marlin_moe_repack( + marlin_w13_qweight = gptq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, layer.w13_qweight.shape[1] * self.quant_config.pack_factor, layer.w13_qweight.shape[2], - self.quant_config.quant_type.size_bits, + self.quant_config.weight_bits, ) replace_parameter(layer, "w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( + marlin_w2_qweight = gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, layer.w2_qweight.shape[1] * self.quant_config.pack_factor, layer.w2_qweight.shape[2], - self.quant_config.quant_type.size_bits, + self.quant_config.weight_bits, ) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales @@ -698,13 +1065,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", ) -> torch.Tensor: + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.topk import select_experts + assert activation == "silu", "Only SiLU activation is supported." + assert ( + scoring_func == "softmax" + ), "Only softmax score func is supported for now." # The input must currently be float16 orig_dtype = x.dtype x = x.half() - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -713,11 +1086,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, + correction_bias=e_score_correction_bias, ) - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, @@ -730,6 +1102,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - quant_type_id=self.quant_config.quant_type.id, + num_bits=self.quant_config.weight_bits, is_k_full=self.is_k_full, ).to(orig_dtype) diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py new file mode 100644 index 000000000..503c3d003 --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils.py @@ -0,0 +1,781 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py + +import logging +from typing import Any, Optional + +import numpy +import torch + +from sglang.srt.layers.linear import LinearBase, LinearMethodBase +from sglang.srt.layers.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types +from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from 
sglang.srt.utils import get_device_capability + +try: + from vllm import _custom_ops as ops +except ImportError: + ops = None + +logger = logging.getLogger(__name__) + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. +# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, +): + if device_capability is None: + major, minor = get_device_capability() + capability = major * 10 + minor + device_capability = -1 if capability is None else capability + + if device_capability < 80: + return [] + + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types( + False, include_fp_type, device_capability + ) + types1 = query_marlin_supported_quant_types( + True, include_fp_type, device_capability + ) + return types0 + types1 + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4] + else: + # GPTQ style, unsigned + symmetric bias + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> tuple[bool, Optional[str]]: + + if device_capability is None: + major, minor = get_device_capability() + capability = major * 10 + minor + device_capability = -1 if capability is None else capability + + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability + ) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + output_size_per_partition = ( + getattr(layer, "output_size_per_partition", None) or layer.output_size + ) + input_size_per_partition = ( + getattr(layer, "input_size_per_partition", None) or layer.input_size + ) + + return check_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=layer.input_size, + group_size=group_size, + )[0] + + +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + hidden_size = layer.hidden_size + intermediate_size_per_partition = layer.intermediate_size_per_partition + # apply_router_weight_on_input is not supported for moe marlin + supports_router_weight = not layer.apply_router_weight_on_input + # moe marlin requires the activation to be silu + supports_activation = layer.activation == "silu" + + # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) + # down: (n, k) = (hidden_size, intermediate_size_per_partition) + # moe marlin requires n % 128 == 0 and k % 64 == 0 + supports_shape = ( + hidden_size % 128 == 0 + and intermediate_size_per_partition % max(64, group_size) == 0 + ) + supports_group_size = group_size in [-1, 32, 64, 128] + return ( + supports_shape + and supports_group_size + and supports_router_weight + and supports_activation + ) + + +def marlin_make_workspace( + device: torch.device, max_blocks_per_sm: int = 1 +) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. 
+ sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros( + sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: list[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) 
to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible." + ) + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + # TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False + if True: + return + # if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + # return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. " + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible." 
+ ) + + +def should_use_atomic_add_reduce( + m: int, n: int, k: int, device: torch.device, dtype: torch.dtype +) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + # TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False + if not True: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. + + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + lm_head_quantized: bool, + ) -> None: + super().__init__() + + # Group size for the quantization. 
+ self.group_size = group_size + self.lm_head_quantized = lm_head_quantized + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin, but got group_size of " + f"{self.group_size}" + ) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // 4 + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return ( + f"MarlinConfig(group_size={self.group_size}, " + f"lm_head_quantized={self.lm_head_quantized})" + ) + + @classmethod + def get_name(cls) -> str: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls(group_size, lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = hf_quant_cfg.get( + "checkpoint_format" + ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False) + + is_valid_user_quant = ( + user_quant is None or user_quant == "gptq" or user_quant == "marlin" + ) + + if is_marlin_format and is_valid_user_quant: + msg = "The model is serialized in {} format. Using {} kernel.".format( + cls.get_name(), cls.get_name() + ) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["MarlinLinearMethod"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): + return MarlinLinearMethod(self) + return None + + +class MarlinLinearMethod(LinearMethodBase): + """Linear method for Marlin. + + Args: + quant_config: The Marlin quantization config. + """ + + def __init__(self, quant_config: MarlinConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}" + ) + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}." 
+ ) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}." + ) + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}." + ) + if ( + self.quant_config.group_size != -1 + and input_size_per_partition % self.quant_config.group_size != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}." + ) + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2 + ) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError("Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition + * self.quant_config.tile_size + // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader, + ) + + # Determine if channelwise or not + input_groups = ( + 1 + if self.quant_config.group_size == -1 + else input_size_per_partition // self.quant_config.group_size + ) + + weight_scale_args = { + "data": torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": weight_loader, + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + else: + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // self.quant_config.min_n_threads + ) * self.quant_config.max_parallel + + workspace = BasevLLMParameter( + data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int), + weight_loader=weight_loader, + ) + + layer.register_parameter("B", qweight) + layer.register_parameter("s", scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False) + layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False) + layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm( + x_2d, qweight, scales, workspace, size_m, size_n, size_k + ) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py 
b/python/sglang/srt/layers/quantization/moe_wna16.py index 0bae43435..fe812595a 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs logger = logging.getLogger(__name__) +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + class MoeWNA16Config(QuantizationConfig): """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" diff --git a/python/sglang/srt/layers/quantization/quant_utils.py b/python/sglang/srt/layers/quantization/quant_utils.py deleted file mode 100644 index 59a1b1fdc..000000000 --- a/python/sglang/srt/layers/quantization/quant_utils.py +++ /dev/null @@ -1,166 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py - -from typing import Optional - -import numpy -import torch -from sgl_kernel.scalar_type import ScalarType - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero 
points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. - if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) diff --git a/sgl-kernel/python/sgl_kernel/scalar_type.py b/python/sglang/srt/layers/quantization/scalar_type.py similarity index 100% rename from sgl-kernel/python/sgl_kernel/scalar_type.py rename to python/sglang/srt/layers/quantization/scalar_type.py diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 40a381f3b..2371208f7 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -1,11 +1,13 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py from types import MappingProxyType -from typing import List, Mapping, Tuple, Union +from typing import List, Mapping, Optional, Tuple, Union +import numpy import torch from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant +from sglang.srt.layers.quantization.scalar_type import ScalarType from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu _is_cuda = 
is_cuda() @@ -143,3 +145,162 @@ def replace_parameter( if not isinstance(new, torch.nn.Parameter): new = torch.nn.Parameter(new, requires_grad=False) mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if 
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
diff --git a/sgl-kernel/python/sgl_kernel/fused_moe.py b/sgl-kernel/python/sgl_kernel/fused_moe.py
index f9322e228..f825131ac 100644
--- a/sgl-kernel/python/sgl_kernel/fused_moe.py
+++ b/sgl-kernel/python/sgl_kernel/fused_moe.py
@@ -2,10 +2,11 @@ import functools
 from typing import Optional
 
 import torch
-from sgl_kernel.scalar_type import scalar_types
 
 
 def get_scalar_type(num_bits: int, has_zp: bool):
+    from sglang.srt.layers.quantization.scalar_type import scalar_types
+
     if has_zp:
         assert num_bits == 4
         return scalar_types.uint4
diff --git a/sgl-kernel/tests/test_marlin_repack.py b/sgl-kernel/tests/test_marlin_repack.py
index c0f13f46b..c229ae1cd 100644
--- a/sgl-kernel/tests/test_marlin_repack.py
+++ b/sgl-kernel/tests/test_marlin_repack.py
@@ -1,12 +1,10 @@
-import math
-
 import numpy as np
 import pytest
 import torch
 
 from sgl_kernel import awq_marlin_repack
-from sgl_kernel.scalar_type import scalar_types
-from sglang.srt.layers.quantization.quant_utils import (
+from sglang.srt.layers.quantization.scalar_type import scalar_types
+from sglang.srt.layers.quantization.utils import (
     get_pack_factor,
     pack_cols,
     quantize_weights,
diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py
index 284465b8b..feda86934 100644
--- a/test/srt/test_gptqmodel_dynamic.py
+++ b/test/srt/test_gptqmodel_dynamic.py
@@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
         model_config=model_config, load_config=load_config, device_config=device_config
     )
 
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
+    from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.quantization.gptq import (
+        GPTQLinearMethod,
         GPTQMarlinLinearMethod,
     )
 
-    from sglang.srt.layers.linear import UnquantizedLinearMethod
-
     linear_method_cls = (
         GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod)
     )
@@ -162,7 +161,7 @@ class TestGPTQModelDynamicWithMarlin(CustomTestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--dtype", "float16"],
+            other_args=["--dtype", "bfloat16"],
         )
 
     @classmethod
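The helpers added to python/sglang/srt/layers/quantization/utils.py earlier in this patch (get_pack_factor, pack_cols, unpack_cols, quantize_weights) are self-contained, so a quick round-trip check is a convenient sanity test. A minimal sketch, not part of the patch; the shapes, group size, and the uint4b8 type are arbitrary choices for illustration:

# Illustrative sanity check for the helpers introduced in quantization/utils.py.
import torch

from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
    get_pack_factor,
    pack_cols,
    quantize_weights,
    unpack_cols,
)

size_k, size_n, group_size = 128, 64, 32
w = torch.randn(size_k, size_n, dtype=torch.float16)

# 4-bit GPTQ-style symmetric quantization: returns the dequantized reference,
# the integer codes, the per-group scales, and (here) no zero points.
w_ref, w_q, w_s, w_zp = quantize_weights(w, scalar_types.uint4b8, group_size)
assert w_s.shape == (size_k // group_size, size_n) and w_zp is None

# Column packing stores get_pack_factor(4) == 8 codes per int32; unpacking restores them.
packed = pack_cols(w_q, num_bits=4, size_k=size_k, size_n=size_n)
assert packed.shape == (size_k, size_n // get_pack_factor(4))
assert torch.equal(unpack_cols(packed, 4, size_k, size_n), w_q)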
diff --git a/test/srt/test_int4_kernel.py b/test/srt/test_int4_kernel.py
deleted file mode 100644
index 0665f9b91..000000000
--- a/test/srt/test_int4_kernel.py
+++ /dev/null
@@ -1,301 +0,0 @@
-import itertools
-import sys
-import unittest
-
-import torch
-
-sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm")
-
-# from sglang.srt.layers.moe.topk import select_experts
-from sgl_kernel import fused_marlin_moe
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
-
-# from vllm.model_executor.layers. import select_experts
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
-from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    marlin_quantize,
-)
-from vllm.scalar_type import scalar_types
-
-
-def stack_and_dev(tensors: list[torch.Tensor]):
-    dev = tensors[0].device
-    return torch.stack(tensors, dim=0).to(dev)
-
-
-def torch_moe(a, w1, w2, score, topk, expert_map):
-    B, D = a.shape
-    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
-    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
-    score = torch.softmax(score, dim=-1, dtype=torch.float32)
-    topk_weight, topk_ids = torch.topk(score, topk)
-    topk_weight = topk_weight.view(-1)
-    topk_ids = topk_ids.view(-1)
-    if expert_map is not None:
-        topk_ids = expert_map[topk_ids]
-    for i in range(w1.shape[0]):
-        mask = topk_ids == i
-        if mask.sum():
-            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
-                0, 1
-            )
-    return (
-        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
-    ).sum(dim=1)
-
-
-def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
-    """Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
-    A = A.to(torch.float32)
-    B = B.to(torch.float32)
-
-    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
-    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
-
-    # Reshape input
-    M = A.numel() // A.shape[-1]
-    B = B.t()  # Transpose weight matrix
-    N, K = B.shape
-    origin_C_shape = A.shape[:-1] + (K,)
-    A = A.reshape(M, N)
-    # As is per-token [M, 1], Bs is per-column [1, K]
-    C = torch.matmul(A, B)  # [M, K]
-    C = As * C * Bs.view(1, -1)  # Broadcast per-column scale
-
-    return C.reshape(origin_C_shape).to(output_dtype)
-
-
-def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
-    """This function performs fused moe with per-column int8 quantization using native torch."""
-
-    B, D = a.shape
-    # Perform per-token quantization
-    a_q, a_s = per_token_quant_int8(a)
-    # Repeat tokens to match topk
-    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
-    # Also repeat the scale
-    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]
-
-    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
-
-    # Calculate routing
-    score = torch.softmax(score, dim=-1, dtype=torch.float32)
-    topk_weight, topk_ids = torch.topk(score, topk)
-    topk_weight = topk_weight.view(-1)
-    topk_ids = topk_ids.view(-1)
-    # Process each expert
-    for i in range(w1.shape[0]):
-        mask = topk_ids == i
-        if mask.sum():
-            # First MLP layer: note that a_s is now per-token
-            inter_out = native_w8a8_per_token_matmul(
-                a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
-            )
-            # Activation function
-            act_out = SiluAndMul().forward_native(inter_out)
-            # Quantize activation output with per-token
-            act_out_q, act_out_s = per_token_quant_int8(act_out)
-
-            # Second MLP layer
-            out[mask] = native_w8a8_per_token_matmul(
-                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
-            )
-    # Apply routing weights and sum
-    return (
-        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
-    ).sum(dim=1)
-
-
-def marlin_fused_moe(
-    N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size
-):
-    quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
-    if ep_size > 1:
-        local_e = E // ep_size
-        e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e]
-        e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32)
-        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
-        w1 = w1[e_ids]
-        w2 = w2[e_ids]
-    else:
-        e_map = None
-    w_ref1_l = []
-    qweight1_l = []
-    scales1_l = []
-    zeros1_l = []
-    g_idx1_l = []
-    sort_indices1_l = []
-    s1_l = []
-    for i in range(w1.shape[0]):
-        test_perm = torch.randperm(n=K)
-        quant_res = marlin_quantize(
-            w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
-        )
-        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
-        w_ref1_l.append(w_ref1.T)
-        qweight1_l.append(qweight1)
-        scales1_l.append(scales1)
-        g_idx1_l.append(g_idx1)
-        sort_indices1_l.append(sort_indices1)
-    w_ref1 = stack_and_dev(w_ref1_l)
-    qweight1 = stack_and_dev(qweight1_l).contiguous()
-    scales1 = stack_and_dev(scales1_l)
-    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
-    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
-    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
-
-    w_ref2_l = []
-    qweight2_l = []
-    scales2_l = []
-    zeros2_l = []
-    g_idx2_l = []
-    sort_indices2_l = []
-    for i in range(w2.shape[0]):
-        test_perm = torch.randperm(n=N)
-        quant_res = marlin_quantize(
-            w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
-        )
-        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
-
-        w_ref2_l.append(w_ref2.T)
-        qweight2_l.append(qweight2)
-        scales2_l.append(scales2)
-        g_idx2_l.append(g_idx2)
-        sort_indices2_l.append(sort_indices2)
-
-    w_ref2 = stack_and_dev(w_ref2_l)
-    qweight2 = stack_and_dev(qweight2_l).contiguous()
-    scales2 = stack_and_dev(scales2_l)
-    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
-    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
-    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
-
-    topk_weights, topk_ids = fused_topk(a, score, topk, False)
-    # topk_weights, topk_ids = FusedMoE.select_experts(
-    #     hidden_states=a,
-    #     router_logits=score,
-    #     top_k=topk,
-    #     num_expert_group=E,
-    #     use_grouped_topk=False,
-    #     renormalize=False,
-    #     topk_group=None,
-    # )
-
-    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
-    marlin_output = fused_marlin_moe(
-        a,
-        qweight1,
-        qweight2,
-        scales1,
-        scales2,
-        score,
-        topk_weights,
-        topk_ids,
-        global_num_experts=E,
-        expert_map=e_map,
-        g_idx1=g_idx1,
-        g_idx2=g_idx2,
-        sort_indices1=sort_indices1,
-        sort_indices2=sort_indices2,
-        w1_zeros=zeros1,
-        w2_zeros=zeros2,
-        num_bits=num_bits,
-        is_k_full=True,
-    )
-    return marlin_output, torch_output
-
-
-class TestW8A8Int8FusedMoE(unittest.TestCase):
-    DTYPES = [torch.float16]
-    M = [1, 16]
-    N = [128]
-    K = [256]
-    E = [4, 10]
-    TOP_KS = [2, 4]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
-    NUM_BITS = [4]
-    EP_SIZE = [1, 4]
-
-    @classmethod
-    def setUpClass(cls):
-        if not torch.cuda.is_available():
-            raise unittest.SkipTest("CUDA is not available")
-        torch.set_default_device("cuda")
-
-    def _w4a8_int8_fused_moe(
-        self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size
-    ):
-        torch.manual_seed(seed)
-        a = torch.randn((M, K), dtype=dtype) / 10
-
-        # Generate int8 weights
-        w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2
-        w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2
-
-        score = torch.randn((M, E), dtype=dtype)
-
-        with torch.inference_mode():
-            marlin_out, ref_out = marlin_fused_moe(
-                N=N,
-                E=E,
-                K=K,
-                a=a,
-                w1=w1_fp16,
-                w2=w2_fp16,
-                num_bits=num_bits,
-                group_size=-1,
-                act_order=False,
-                score=score,
-                topk=topk,
-                ep_size=ep_size,
-            )
-        # Check results
-        if (
-            torch.mean(
-                torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
-            )
-            / torch.mean(torch.abs(ref_out.to(torch.float32)))
-            > 0.1
-        ):
-            print(f"marlin_out: {marlin_out}")
-            print(f"ref_out: {ref_out}")
-            print(
-                torch.mean(
-                    torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
-                )
-                / torch.mean(torch.abs(ref_out.to(torch.float32)))
-            )
-        torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0)
-
-    def test_w4a8_int8_fused_moe(self):
-        for params in itertools.product(
-            self.M,
-            self.N,
-            self.K,
-            self.E,
-            self.TOP_KS,
-            self.BLOCK_SIZE,
-            self.DTYPES,
-            self.SEEDS,
-            self.NUM_BITS,
-            self.EP_SIZE,
-        ):
-            with self.subTest(
-                M=params[0],
-                N=params[1],
-                K=params[2],
-                E=params[3],
-                topk=params[4],
-                block_size=params[5],
-                dtype=params[6],
-                seed=params[7],
-                num_bits=params[8],
-                ep_size=params[9],
-            ):
-                self._w4a8_int8_fused_moe(*params)
-
-
-if __name__ == "__main__":
-    unittest.main(verbosity=2)
diff --git a/test/srt/test_w4a8.py b/test/srt/test_w4a8.py
deleted file mode 100644
index 75d41ee5f..000000000
--- a/test/srt/test_w4a8.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import sgl_kernel
-import torch
-
-x = torch.randn(10, 10, device="cuda")
-qweight = torch.randn(10, 10, device="cuda")
-s1_scales = torch.randn(10, device="cuda")
-input_scales = torch.randn(10, device="cuda")
-s1_szeros = torch.randn(10, device="cuda")
-input_sum = torch.randn(10, device="cuda")
-output_buffer = torch.randn(10, device="cuda")
-
-torch.ops.sgl_kernel.gemm_forward_cuda.default(
-    x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer
-)
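For completeness, the zero-point path of the new quantize_weights helper requires an unsigned type and an explicit group size, and it returns zero points laid out per (group, column) exactly like the scales. Again an illustrative sketch rather than part of the patch; the tensor size and group size are arbitrary:

# Illustrative: quantize_weights with zero points enabled.
import torch

from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import quantize_weights

w = torch.randn(128, 64, dtype=torch.float16)
w_ref, w_q, w_s, w_zp = quantize_weights(
    w, scalar_types.uint4, group_size=32, zero_points=True
)
# One zero point per (group, column), matching the per-group scale layout.
assert w_zp is not None and w_zp.shape == (128 // 32, 64)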