Refactor: move all quantization-related code to srt/layers/quantization (#7989)

This commit is contained in:
Cheng Wan
2025-07-17 00:47:07 -07:00
committed by GitHub
parent 02404a1e35
commit 49b8777460
22 changed files with 1095 additions and 1175 deletions

View File

@@ -1,16 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
import torch
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
@@ -81,7 +82,7 @@ class AWQConfig(QuantizationConfig):
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
@@ -92,7 +93,8 @@ class AWQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["LinearMethodBase"]:
) -> Optional[LinearMethodBase]:
from sglang.srt.layers.linear import LinearBase
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):