From 6249e4a19ed66afa100d55fa41997b725ff4b296 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 13 Jan 2025 04:44:39 -0800 Subject: [PATCH] Revert "Integration of TurboMind AWQ" (#2866) --- python/pyproject.toml | 2 +- python/sglang/srt/configs/model_config.py | 10 +- python/sglang/srt/layers/linear.py | 1 - .../srt/layers/quantization/__init__.py | 2 - .../srt/layers/quantization/awq_turbomind.py | 287 ------------------ .../layers/quantization/turbomind_utils.py | 63 ---- python/sglang/srt/server_args.py | 1 - test/srt/test_turbomind_awq.py | 47 --- 8 files changed, 2 insertions(+), 411 deletions(-) delete mode 100644 python/sglang/srt/layers/quantization/awq_turbomind.py delete mode 100644 python/sglang/srt/layers/quantization/turbomind_utils.py delete mode 100644 test/srt/test_turbomind_awq.py diff --git a/python/pyproject.toml b/python/pyproject.toml index c29580b50..a236469a1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -28,7 +28,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", "cuda-python", "sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", - "flashinfer==0.1.6", "turbomind" + "flashinfer==0.1.6" ] # HIP (Heterogeneous-computing Interface for Portability) for AMD diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 28144f139..072c88b04 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -14,7 +14,6 @@ import json import logging -import sys from enum import IntEnum, auto from typing import List, Optional, Set, Union @@ -231,7 +230,7 @@ class ModelConfig: # Parse quantization method from the HF model config, if available. quant_cfg = self._parse_quant_hf_config() - if quant_cfg is not None and not quantization_in_turbomind(self.quantization): + if quant_cfg is not None: quant_method = quant_cfg.get("quant_method", "").lower() # Detect which checkpoint is it @@ -402,10 +401,3 @@ def is_multimodal_model(model_architectures: List[str]): def is_encoder_decoder_model(model_architectures: List[str]): return "MllamaForConditionalGeneration" in model_architectures - - -def quantization_in_turbomind(quantization: str) -> bool: - if quantization in ["awq_turbomind"]: - return True - else: - return False diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 815255d5c..ee9386c13 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -48,7 +48,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "GPTQLinearMethod", "FBGEMMFp8LinearMethod", "ModelOptFp8LinearMethod", - "AWQTurbomindLinearMethod", "IPEXAWQLinearMethod", ] diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index faf14d6fd..35b0c4d94 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig -from sglang.srt.layers.quantization.awq_turbomind import AWQTurbomindConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config @@ -38,7 +37,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "gptq_marlin_24": 
GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, "awq_marlin": AWQMarlinConfig, - "awq_turbomind": AWQTurbomindConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, diff --git a/python/sglang/srt/layers/quantization/awq_turbomind.py b/python/sglang/srt/layers/quantization/awq_turbomind.py deleted file mode 100644 index 007b20420..000000000 --- a/python/sglang/srt/layers/quantization/awq_turbomind.py +++ /dev/null @@ -1,287 +0,0 @@ -import logging -import os -import sys -from typing import Any, Dict, List, Optional - -import torch -import turbomind -from torch.nn import Parameter - -turbomind_dir = os.path.split(turbomind.__file__)[0] -sys.path.append(os.path.join(turbomind_dir, "lib")) -import _turbomind_ext -from vllm.model_executor.layers.linear import LinearBase - -from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod -from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter -from sglang.srt.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) -from sglang.srt.layers.quantization.turbomind_utils import ( - get_u4_slices, - is_layer_skipped_awq, - pack_u4_row, - unpack_awq_gemm, - verify_turbomind_supported, -) -from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.utils import is_cuda, set_weight_attrs - -logger = logging.getLogger(__name__) - - -class AWQTurbomindConfig(QuantizationConfig): - """Config class for AWQ Turbomind""" - - def __init__( - self, - weight_bits: int, - group_size: int, - zero_point: bool, - lm_head_quantized: bool, - modules_to_not_convert: Optional[List[str]] = None, - ) -> None: - self.pack_factor = 32 // weight_bits # packed into int32 - self.group_size = group_size - self.zero_point = zero_point - self.lm_head_quantized = lm_head_quantized - self.weight_bits = weight_bits - self.modules_to_not_convert = modules_to_not_convert or [] - - verify_turbomind_supported(self.weight_bits, self.group_size) - - def __repr__(self) -> str: - return ( - f"AWQTurbomindConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"zero_point={self.zero_point}, " - f"lm_head_quantized={self.lm_head_quantized}, " - f"modules_to_not_convert={self.modules_to_not_convert})" - ) - - @classmethod - def get_name(cls) -> str: - return "awq_turbomind" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half, torch.bfloat16] - - @classmethod - def get_min_capability(cls) -> int: - return 70 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AWQTurbomindConfig": - weight_bits = cls.get_from_keys(config, ["bits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - zero_point = cls.get_from_keys(config, ["zero_point"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) - modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None - ) - return cls( - weight_bits, - group_size, - zero_point, - lm_head_quantized, - modules_to_not_convert, - ) - - @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - can_convert = cls.is_awq_turbomind_compatible(hf_quant_cfg) - is_valid_user_quant = user_quant is None or user_quant == "awq_turbomind" - - if can_convert and is_valid_user_quant: - msg = f"The model is convertible to {cls.get_name()} 
during runtime. Using {cls.get_name()} kernel." - logger.info(msg) - return cls.get_name() - - if can_convert and user_quant == "awq": - logger.info( - "Detected that the model can run with awq_turbomind" - ", however you specified quantization=awq explicitly," - " so forcing awq. Use quantization=awq_turbomind for" - " faster inference" - ) - return None - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: - if isinstance(layer, LinearBase) or ( - isinstance(layer, ParallelLMHead) and self.lm_head_quantized - ): - if is_layer_skipped_awq(prefix, self.modules_to_not_convert): - return UnquantizedLinearMethod() - return AWQTurbomindLinearMethod(self) - - return None - - @classmethod - def is_awq_turbomind_compatible(cls, quant_config: Dict[str, Any]): - if not is_cuda(): - return False - - # Extract data from quant config. - quant_method = quant_config.get("quant_method", "").lower() - num_bits = quant_config.get("bits") - group_size = quant_config.get("group_size") - zero_point = quant_config.get("zero_point") - - if quant_method != "awq": - return False - - # If we cannot find the info needed in the config, cannot convert. - if num_bits is None or group_size is None or zero_point is None: - return False - - return verify_turbomind_supported(quant_bit=num_bits, group_size=group_size) - - def get_scaled_act_names(self) -> List[str]: - return [] - - -class AWQTurbomindLinearMethod(LinearMethodBase): - """Linear method for AWQ Turbomind. - - Args: - quant_config: The AWQ Turbomind quantization config. - """ - - def __init__(self, quant_config: AWQTurbomindConfig) -> None: - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - - output_size_per_partition = sum(output_partition_sizes) - weight_loader = extra_weight_attrs.get("weight_loader") - - # Normalize group_size - if self.quant_config.group_size != -1: - group_size = self.quant_config.group_size - else: - group_size = input_size - - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition, - output_size_per_partition // self.quant_config.pack_factor, - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader, - ) - - num_groups = input_size_per_partition // group_size - - qzeros = PackedvLLMParameter( - data=torch.empty( - num_groups, - output_size_per_partition // self.quant_config.pack_factor, - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader, - ) - - scales = GroupQuantScaleParameter( - data=torch.empty( - num_groups, - output_size_per_partition, - dtype=params_dtype, - ), - input_dim=0, - output_dim=1, - weight_loader=weight_loader, - ) - - layer.register_parameter("qweight", qweight) - layer.register_parameter("qzeros", qzeros) - layer.register_parameter("scales", scales) - - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.num_groups = num_groups - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - qweight_turbomind = unpack_awq_gemm(layer.qweight.data) - qzeros_turbomind = unpack_awq_gemm(layer.qzeros.data) - scales_turbomind = layer.scales.data - - 
qweight_turbomind = pack_u4_row(qweight_turbomind) - qzeros_turbomind = qzeros_turbomind.to(torch.half) - - device_id = layer.qweight.device.index - properties = torch.cuda.get_device_properties(device_id) - - def is_16xx_series(name): - import re - - pattern = r"GTX 16\d\d" - return bool(re.search(pattern, name)) - - simt = is_16xx_series(properties.name) - qweight_turbomind = qweight_turbomind.contiguous() - scales_turbomind = scales_turbomind.contiguous() - qzeros_turbomind = qzeros_turbomind.contiguous() - - self.linear = _turbomind_ext.Linear( - layer.input_size_per_partition, - layer.output_size_per_partition, - self.quant_config.weight_bits, - self.quant_config.group_size, - ) - - self.linear.post_init( - qweight_turbomind, scales_turbomind, qzeros_turbomind, simt - ) - - layer.qweight = Parameter(qweight_turbomind, requires_grad=False) - layer.scales = Parameter(scales_turbomind, requires_grad=False) - layer.qzeros = Parameter(qzeros_turbomind, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - x = x.view(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (layer.output_size_per_partition,) - out = torch.empty( - (x.shape[0], layer.output_size_per_partition), - dtype=torch.float16, - device=x.device, - ) - stream = torch.cuda.current_stream() - - self.linear.forward(x, out, stream.cuda_stream) - out = torch.from_dlpack(out) - if bias is not None: - out.add_(bias) - - return out.view(out_shape) diff --git a/python/sglang/srt/layers/quantization/turbomind_utils.py b/python/sglang/srt/layers/quantization/turbomind_utils.py deleted file mode 100644 index b8d4b97d0..000000000 --- a/python/sglang/srt/layers/quantization/turbomind_utils.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
- -from typing import List - -import torch - -from sglang.srt.utils import get_device_capability - - -def get_u4_slices(x: torch.Tensor, dtype: torch.dtype) -> List[torch.Tensor]: - assert x.dtype == torch.int32 - xs = [] - for _ in range(8): - xs.append((x & 15).to(dtype)) - x = x >> 4 - return xs - - -def unpack_awq_gemm(x: torch.Tensor) -> torch.Tensor: - """ - The int4 weights are packed into int32: - bit: 31-28 27-24 23-20 19-16 15-12 11-8 7-4 3-0 - weight: int4_1 int4_2 int4_3 int4_4 int4_5 int4_6 int4_7 int4_8 - """ - xs = get_u4_slices(x, torch.uint8) - order = [0, 4, 1, 5, 2, 6, 3, 7] - ys = [xs[i] for i in order] - return torch.stack(ys, dim=-1).view(*x.shape[:-1], -1) - - -def pack_u4_row(x: torch.Tensor) -> torch.Tensor: - assert x.dtype == torch.uint8 - xs = x.view(*x.shape[:-1], -1, 8).split(1, dim=-1) - a = torch.zeros(xs[0].shape, dtype=torch.int32, device=x.device) - for t in reversed(xs): - a = (a << 4) | t - return a.squeeze(dim=-1) - - -def verify_turbomind_supported(quant_bit: int, group_size: int) -> bool: - - if quant_bit not in [4]: - raise NotImplementedError( - f"[Tubomind] Only 4-bit is supported for now, but got {quant_bit} bit" - ) - if group_size != 128: - raise NotImplementedError( - f"[Tubomind] Only group_size 128 is supported for now, " - f"but got group_size {group_size}" - ) - - major, minor = get_device_capability() - capability = major * 10 + minor - if capability < 70: - raise NotImplementedError( - f"[Tubomind] Only capability >= 70 is supported for now, but got {capability}" - ) - - return True - - -def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]): - return any(module_name in prefix for module_name in modules_to_not_convert) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 061d320ef..be85a3670 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -375,7 +375,6 @@ class ServerArgs: "marlin", "gptq_marlin", "awq_marlin", - "awq_turbomind", "bitsandbytes", "gguf", "modelopt", diff --git a/test/srt/test_turbomind_awq.py b/test/srt/test_turbomind_awq.py deleted file mode 100644 index fa2a879d4..000000000 --- a/test/srt/test_turbomind_awq.py +++ /dev/null @@ -1,47 +0,0 @@ -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - popen_launch_server, -) - - -class TestMLA(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--quantization", - "awq_turbomind", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - assert metrics["score"] >= 0.5 - - -if __name__ == "__main__": - unittest.main()
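For reference only: the interleaved int4 layout that the deleted unpack_awq_gemm and pack_u4_row helpers convert between can be sketched in a few lines of standalone PyTorch. The helper names below mirror the removed code but are hypothetical re-implementations, assuming CPU tensors and only torch as a dependency; they are not the TurboMind kernels themselves.

import torch

# Nibble position (counting from the LSB) that holds logical weight j in an
# AWQ-packed int32, per the docstring of the removed unpack_awq_gemm.
AWQ_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def unpack_awq(packed: torch.Tensor) -> torch.Tensor:
    """Unpack AWQ int32 words into one uint8 value per logical int4 weight."""
    assert packed.dtype == torch.int32
    # Slice out the 8 nibbles (LSB first), then undo the AWQ interleaving.
    nibbles = [((packed >> (4 * i)) & 15).to(torch.uint8) for i in range(8)]
    ordered = [nibbles[i] for i in AWQ_ORDER]
    return torch.stack(ordered, dim=-1).view(*packed.shape[:-1], -1)


def pack_u4_row(x: torch.Tensor) -> torch.Tensor:
    """Re-pack sequentially: element i of every group of 8 goes to nibble i."""
    assert x.dtype == torch.uint8 and x.shape[-1] % 8 == 0
    groups = x.view(*x.shape[:-1], -1, 8).to(torch.int64)
    packed = torch.zeros(groups.shape[:-1], dtype=torch.int64)
    for i in range(8):
        packed |= groups[..., i] << (4 * i)
    return packed.to(torch.int32)


if __name__ == "__main__":
    logical = torch.randint(0, 16, (2, 64), dtype=torch.uint8)
    # Build an AWQ-style packed tensor: logical weight j lands in nibble AWQ_ORDER[j].
    awq_packed = torch.zeros(2, 64 // 8, dtype=torch.int64)
    for j, pos in enumerate(AWQ_ORDER):
        awq_packed |= logical.view(2, -1, 8)[..., j].to(torch.int64) << (4 * pos)
    awq_packed = awq_packed.to(torch.int32)
    assert torch.equal(unpack_awq(awq_packed), logical)
    print("round trip ok")

The round-trip check illustrates what the removed process_weights_after_loading relied on: an AWQ checkpoint stores eight int4 values per int32 with nibbles 0..7 (LSB to MSB) holding logical weights 0, 2, 4, 6, 1, 3, 5, 7, and the weights are re-packed sequentially along the row before being handed to the TurboMind linear kernel.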