Refactor: move all quantization-related code to srt/layer/quantization (#7989)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -40,6 +37,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||
is_activation_quantization_format,
|
||||
should_ignore_layer,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
|
||||
try:
|
||||
import vllm
|
||||
@@ -97,7 +95,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self.config = config
|
||||
self.packed_modules_mapping = packed_modules_mapping
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
def get_linear_method(self) -> CompressedTensorsLinearMethod:
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
@@ -117,7 +115,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
# TODO (@robertgshaw2): support module names
|
||||
@@ -138,7 +137,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
|
||||
ignore: List[str] = cast(List[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
|
||||
@@ -357,7 +356,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def _get_scheme_from_parts(
|
||||
self, weight_quant: BaseModel, input_quant: BaseModel
|
||||
) -> "CompressedTensorsScheme":
|
||||
) -> CompressedTensorsScheme:
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
@@ -435,7 +434,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: Optional[str] = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
) -> Optional[CompressedTensorsScheme]:
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user