diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index e92eaee73..df0658f86 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -10,10 +10,6 @@ import torch try: from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig - from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( - CompressedTensorsW8A8Fp8MoEMethod, - CompressedTensorsWNA16MoEMethod, - ) from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.gguf import GGUFConfig @@ -175,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False): return original_isinstance(obj, classinfo) builtins.isinstance = patched_isinstance - - -def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"): - """ - Monkey patch the apply function of vllm's FusedMoEMethodBase. - Convert sglang arguments to vllm arguments. - """ - original_apply = class_obj.apply - sig = inspect.signature(original_apply) - param_names = list(sig.parameters.keys()) - has_correction_bias = "e_score_correction_bias" in param_names - - def new_apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ): - assert activation == "silu" - assert inplace and not no_combine - - kwargs = { - "self": self, - "layer": layer, - "x": x, - "topk_output": topk_output, - } - return original_apply(**kwargs) - - setattr(class_obj, "apply", new_apply) - - -def monkey_patch_quant_configs(): - """Apply all monkey patches in one place.""" - - monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod) - monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod) - - -# Only apply monkey patches if vllm is available -if VLLM_AVAILABLE: - monkey_patch_quant_configs() diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 274b6184c..faab1ebd5 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -30,10 +30,12 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im CompressedTensorsMoEMethod, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16, ) from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, @@ -43,23 +45,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import ( from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod -try: - from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import ( - CompressedTensors24, - ) - from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import ( - W4A16SPARSE24_SUPPORTED_BITS, - CompressedTensorsW4A16Sparse24, - ) - from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( - WNA16_SUPPORTED_BITS, - CompressedTensorsWNA16, - ) - - VLLM_AVAILABLE = True -except ImportError: - VLLM_AVAILABLE = False - logger = logging.getLogger(__name__) __all__ = ["CompressedTensorsLinearMethod"] @@ -380,19 +365,6 @@ class CompressedTensorsConfig(QuantizationConfig): # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): - if not VLLM_AVAILABLE: - raise ImportError( - "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm" - ) - if ( - self.quant_format == CompressionFormat.marlin_24.value - and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS - ): - return CompressedTensorsW4A16Sparse24( - strategy=weight_quant.strategy, - num_bits=weight_quant.num_bits, - group_size=weight_quant.group_size, - ) if ( self.quant_format == CompressionFormat.pack_quantized.value and weight_quant.num_bits in WNA16_SUPPORTED_BITS @@ -403,6 +375,10 @@ class CompressedTensorsConfig(QuantizationConfig): group_size=weight_quant.group_size, actorder=weight_quant.actorder, ) + else: + raise ImportError( + "Other method (CompressedTensorsW4A16Sparse24) is not supported now" + ) if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): @@ -426,10 +402,6 @@ class CompressedTensorsConfig(QuantizationConfig): # note: input_quant can be None if self._is_fp8_w8a16(weight_quant, input_quant): - if not VLLM_AVAILABLE: - raise ImportError( - "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm" - ) is_static_input_scheme = input_quant and not input_quant.dynamic return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, @@ -470,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig): # Find the "target" in the compressed-tensors config # that our layer conforms to. - # TODO (@robertgshaw): add compressed-tensors as dep + # TODO : add compressed-tensors as dep # so we do not have to re-write these functions # need to make accelerate optional in ct to do this @@ -508,24 +480,7 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant=input_quant, sparsity_scheme=sparsity_scheme, ): - if not VLLM_AVAILABLE: - raise ImportError( - "vllm is not installed, to use CompressedTensors24, please install vllm" - ) - # Have a valid sparsity scheme - # Validate layer is supported by Cutlass 2:4 Kernel - model_compression_config = ( - None - if sparsity_scheme is None or sparsity_scheme.format == "dense" - else self.config - ) - - scheme = CompressedTensors24( - quantized=weight_quant is not None or input_quant is not None, - weight_quant=weight_quant, - input_quant=input_quant, - model_compression_config=model_compression_config, - ) + raise ImportError("CompressedTensors24 is not supported now") elif weight_quant is None: logger.warning_once( "Acceleration for non-quantized schemes is " diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 6a7696c1d..f057e57d4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,7 +6,7 @@ import enum import logging import re from enum import Enum -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING try: from sgl_kernel import fused_marlin_moe @@ -31,9 +31,13 @@ from sglang.srt.environ import envs from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase -from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS, +) from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack +from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales from sglang.srt.layers.quantization.utils import ( all_close_1d, per_tensor_dequantize, @@ -42,6 +46,7 @@ from sglang.srt.layers.quantization.utils import ( from sglang.srt.utils import ( get_bool_env_var, get_compiler_backend, + is_cuda, is_hip, set_weight_attrs, ) @@ -57,6 +62,8 @@ if TYPE_CHECKING: ) _is_hip = is_hip() +_is_cuda = is_cuda() + _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _use_aiter: @@ -64,12 +71,9 @@ if _use_aiter: from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1 -try: - import vllm # noqa: F401 - VLLM_AVAILABLE = True -except ImportError: - VLLM_AVAILABLE = False +if _is_cuda: + from sgl_kernel import fused_marlin_moe logger = logging.getLogger(__name__) @@ -127,10 +131,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): weight_quant = quant_config.target_scheme_map["Linear"].get("weights") input_quant = quant_config.target_scheme_map["Linear"].get("input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): - if not VLLM_AVAILABLE: - raise ImportError( - "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm." - ) + + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -432,9 +434,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ): if self.num_gpu_experts != -1: num_experts = self.num_gpu_experts - # assert ( - # params_dtype == torch.float16 - # ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 # Will transpose the loaded weight along the # intermediate and hidden dim sizes. Will @@ -573,44 +572,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): getattr(layer, name).copy_(new_t) del new_t - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]] - ) - return scale_perm, scale_perm_single - - def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int - ): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int - ): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype - ) - for e in range(num_experts): - output[e] = marlin_permute_scales( - s[e], size_k, size_n, group_size, num_bits - ) - return output - - size_k2 = layer.w2_weight_packed.shape[2] - size_k13 = layer.w13_weight_packed.shape[2] - num_experts = layer.w13_weight_g_idx.shape[0] device = layer.w13_weight_g_idx.device @@ -657,42 +618,39 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): requires_grad=False, ) - from vllm import _custom_ops as vllm_ops - - marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack( + marlin_w13_qweight = gptq_marlin_moe_repack( layer.w13_weight_packed, layer.w13_g_idx_sort_indices, layer.w13_weight_packed.shape[1] * self.packed_factor, layer.w13_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w13_weight_packed", marlin_w13_qweight) - marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack( + replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, layer.w2_weight_packed.shape[1] * self.packed_factor, layer.w2_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w2_weight_packed", marlin_w2_qweight) + replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( layer.w13_weight_scale, - size_k13, + layer.w13_weight_packed.shape[2], layer.w13_weight_scale.shape[2], self.group_size, - self.num_bits, ) - replace_tensor("w13_weight_scale", marlin_w13_scales) + replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( layer.w2_weight_scale, layer.w2_weight_scale.shape[1] * (self.group_size if self.group_size != -1 else self.packed_factor), - size_k2, + layer.w2_weight_scale.shape[2], self.group_size, - self.num_bits, ) - replace_tensor("w2_weight_scale", marlin_w2_scales) + replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig @@ -716,7 +674,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): topk_weights, topk_ids, router_logits = topk_output - output = torch.ops.vllm.fused_marlin_moe( + output = fused_marlin_moe( x, layer.w13_weight_packed, layer.w2_weight_packed, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py index 2476da700..6d9871917 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -4,10 +4,13 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 +from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 __all__ = [ "CompressedTensorsScheme", "CompressedTensorsW8A8Fp8", "CompressedTensorsW8A16Fp8", "CompressedTensorsW8A8Int8", + "CompressedTensorsWNA16", + "WNA16_SUPPORTED_BITS", ] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index af4f1a0e0..35d579de4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -14,25 +14,12 @@ from sglang.srt.layers.parameter import ( from sglang.srt.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from sglang.srt.layers.quantization.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) from sglang.srt.layers.quantization.utils import convert_to_channelwise -try: - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, - ) - - MARLIN_FP8_AVAILABLE = True -except ImportError: - MARLIN_FP8_AVAILABLE = False - - def apply_fp8_marlin_linear(*args, **kwargs): - raise ImportError("vllm is not installed") - - def prepare_fp8_layer_for_marlin(*args, **kwargs): - raise ImportError("vllm is not installed") - - __all__ = ["CompressedTensorsW8A16Fp8"] SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR] @@ -43,11 +30,6 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - if not MARLIN_FP8_AVAILABLE: - raise ImportError( - "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm" - ) - @classmethod def get_min_capability(cls) -> int: # ampere and up diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py new file mode 100644 index 000000000..1d28412e8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -0,0 +1,339 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +# yapf conflicts with isort for this block +# yapf: disable +from sglang.srt.layers.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, + permute_param_layout_, +) +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, +) +from sglang.srt.layers.quantization.marlin_utils import ( + MarlinLinearLayerConfig, + apply_gptq_marlin_linear, + check_marlin_supports_shape, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace, + marlin_permute_scales, + marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, + marlin_zero_points, +) +from sglang.srt.layers.quantization.utils import ( + get_scalar_types, + replace_parameter, + unpack_cols, +) +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() + +if _is_cuda: + from sgl_kernel import gptq_marlin_repack + + +ScalarType, scalar_types = get_scalar_types() + +logger = logging.getLogger(__name__) + +__all__ = ["CompressedTensorsWNA16"] +WNA16_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128 +} +WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8} +WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size == -1 and self.strategy != "channel": + raise ValueError("Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise.") + + if num_bits not in WNA16_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] + if not self.symmetric else + WNA16_SUPPORTED_TYPES_MAP[num_bits]) + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + self.kernel_config = MarlinLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx + ) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ) + } + + zeros_args = { + "weight_loader": + weight_loader, + "data": + torch.zeros( + output_size_per_partition // self.pack_factor, + scales_and_zp_size, + dtype=torch.int32, + ) + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + + if not self.symmetric: + qzeros = PackedColumnParameter(output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + if not self.symmetric: + qzeros = PackedvLLMParameter(input_dim=1, + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + + if not self.symmetric: + layer.register_parameter("weight_zero_point", qzeros) + + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + self.w_q_name = "weight_packed" + self.w_s_name = "weight_scale" + self.w_zp_name = "weight_zero_point" + self.w_gidx_name = "weight_g_idx" + + device = getattr(layer, self.w_q_name).device + c = self.kernel_config + + check_marlin_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size, + ) + + row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(device) + + def _transform_param( + layer: torch.nn.Module, name: Optional[str], fn: Callable + ) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, torch.nn.Parameter(new_param.data, requires_grad=False) + ) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = gptq_marlin_repack( + x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits, + ) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales( + x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size, + ) + return x + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name) + ) + _transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + grouped_k = ( + c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1 + ) + _transform_param( + layer, + self.w_zp_name, + lambda x: marlin_zero_points( + unpack_cols( + x.t(), + c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1], + ), + size_k=grouped_k, + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits, + ), + ) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + _transform_param(layer, self.w_q_name, transform_w_q) + _transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + c = self.kernel_config + + def _get_weight_params( + layer: torch.nn.Module, + ) -> tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor], # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) + + w_q, w_s, w_zp, w_gidx = _get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias, + ) diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py index e0b398c25..9a521d943 100644 --- a/python/sglang/srt/layers/quantization/marlin_utils.py +++ b/python/sglang/srt/layers/quantization/marlin_utils.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import numpy @@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] USE_FP32_REDUCE_DEFAULT = True +@dataclass +class MarlinLinearLayerConfig: + full_weight_shape: tuple[int, int] # [in, out] + partition_weight_shape: tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + # For binary size and compile time, we don't support the same types for with and # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl