Clean up import vllm in quantization/__init__.py (#4834)

This commit is contained in:
Lianmin Zheng
2025-03-28 10:34:10 -07:00
committed by GitHub
parent ef9a378a20
commit 74e0ac1dbd
14 changed files with 191 additions and 254 deletions

View File

@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
try:
import vllm
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
)
from vllm.scalar_type import scalar_types
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
class scalar_types:
uint4b8 = "uint4b8"
uint8b128 = "uint8b128"
logger = logging.getLogger(__name__)
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["GPTQLinearMethod"]:
if not VLLM_AVAILABLE:
raise ImportError("vllm is not installed")
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
) -> Optional[GPTQLinearMethod]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.quantization import get_linear_quant_method
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
if VLLM_AVAILABLE:
from vllm.scalar_type import scalar_types
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
else:
raise ImportError("vllm is not installed")
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(
self,
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
)
# (num_bits, is_sym) -> quant_type
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
def __repr__(self) -> str:
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if not VLLM_AVAILABLE:
raise ImportError("vllm is not installed")
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
) -> Optional[QuantizeMethodBase]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import get_linear_quant_method
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
if not VLLM_AVAILABLE:
return False
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
)
if not _is_cuda:
return False
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["MarlinLinearMethod"]:
if not VLLM_AVAILABLE:
raise ImportError("vllm is not installed")
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
# Delay import to avoid circular dependency
) -> Optional[MarlinLinearMethod]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
if isinstance(layer, LinearBase) or (