Refactor: move all quantization-related code to srt/layer/quantization (#7989)
This commit is contained in:
@@ -1,16 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -81,7 +82,7 @@ class AWQConfig(QuantizationConfig):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
@@ -92,7 +93,8 @@ class AWQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["LinearMethodBase"]:
|
||||
) -> Optional[LinearMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
|
||||
Reference in New Issue
Block a user