Refactor: move all quantization-related code to srt/layers/quantization (#7989)

This commit is contained in:
Cheng Wan
2025-07-17 00:47:07 -07:00
committed by GitHub
parent 02404a1e35
commit 49b8777460
22 changed files with 1095 additions and 1175 deletions

View File

@@ -1,5 +1,6 @@
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
from contextlib import suppress
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
)
from pydantic import BaseModel
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
@@ -40,6 +37,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
is_activation_quantization_format,
should_ignore_layer,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
try:
import vllm
@@ -97,7 +95,7 @@ class CompressedTensorsConfig(QuantizationConfig):
self.config = config
self.packed_modules_mapping = packed_modules_mapping
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
def get_linear_method(self) -> CompressedTensorsLinearMethod:
return CompressedTensorsLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -117,7 +115,8 @@ class CompressedTensorsConfig(QuantizationConfig):
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
@@ -138,7 +137,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
ignore: List[str] = cast(List[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
@@ -357,7 +356,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def _get_scheme_from_parts(
self, weight_quant: BaseModel, input_quant: BaseModel
) -> "CompressedTensorsScheme":
) -> CompressedTensorsScheme:
# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -435,7 +434,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_scheme(
self, layer: torch.nn.Module, layer_name: Optional[str] = None
) -> Optional["CompressedTensorsScheme"]:
) -> Optional[CompressedTensorsScheme]:
"""
compressed-tensors supports non uniform in the following way: