Clean up import vllm in quantization/__init__.py (#4834)
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
 _is_cuda = is_cuda()
 
+try:
+    import vllm
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+        check_marlin_supported,
+    )
+    from vllm.scalar_type import scalar_types
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+
+    class scalar_types:
+        uint4b8 = "uint4b8"
+        uint8b128 = "uint8b128"
+
 
 logger = logging.getLogger(__name__)
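This hunk replaces scattered per-call-site vllm imports with a single guarded import block: everything vllm-related is imported once, and on ImportError the names are rebound to harmless placeholders so annotations and class-level references still resolve. A minimal, self-contained sketch of the same optional-dependency guard (the package and class names below are hypothetical, not part of this commit):

    from typing import Any

    try:
        # One guarded block for the whole optional backend, so a single
        # ImportError cleanly disables every feature that depends on it.
        from some_optional_pkg import FastLinear  # hypothetical dependency

        OPTIONAL_AVAILABLE = True
    except ImportError:
        OPTIONAL_AVAILABLE = False

        # Placeholder keeps annotations like Optional[FastLinear] importable
        # even when the package is absent.
        FastLinear = Any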
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["GPTQLinearMethod"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-
+    ) -> Optional[GPTQLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.quantization import get_linear_quant_method
 
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
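With the guarded module-level import in place, the method body no longer needs its own vllm import or availability check, and the return annotation can drop its quotes: GPTQLinearMethod is now always a bound name (the real class, or the Any fallback). The remaining function-level import targets sglang itself and exists only to break an import cycle; a sketch of that idiom, with hypothetical module names:

    # pkg/registry.py (hypothetical) imports pkg/gptq.py at module load, so
    # pkg/gptq.py must not import pkg.registry at its own top level -- Python
    # would re-enter pkg.registry while it is still half-initialized.

    def get_quant_method(config, layer, prefix):
        # Deferred to call time: both modules are fully loaded by now,
        # so the back-edge gptq -> registry is harmless.
        from pkg.registry import get_linear_quant_method  # hypothetical helper

        return get_linear_quant_method(config, layer, prefix)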
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
 
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-    if VLLM_AVAILABLE:
-        from vllm.scalar_type import scalar_types
-
-        # (num_bits, is_sym) -> quant_type
-        TYPE_MAP = {
-            (4, True): scalar_types.uint4b8,
-            (8, True): scalar_types.uint8b128,
-        }
-    else:
-        raise ImportError("vllm is not installed")
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
 
     def __init__(
         self,
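The class body previously mixed the table definition with availability control flow, and the else: raise executed at class-definition time, making the whole module unimportable without vllm. Now TYPE_MAP is defined unconditionally, keyed by (num_bits, is_sym) tuples, with the fallback scalar_types class supplying string stand-ins when vllm is absent. A hedged sketch of how such a table is typically consumed:

    # Mirrors TYPE_MAP's (num_bits, is_sym) -> quant_type shape; the values
    # here are illustrative strings, not vllm scalar types.
    TYPE_MAP = {
        (4, True): "uint4b8",
        (8, True): "uint8b128",
    }

    def resolve_quant_type(bits: int, sym: bool) -> str:
        try:
            return TYPE_MAP[(bits, sym)]
        except KeyError:
            raise ValueError(f"Unsupported quantization config: bits={bits}, sym={sym}")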
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                 "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
             )
 
+        # (num_bits, is_sym) -> quant_type
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
 
     def __repr__(self) -> str:
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
+    ) -> Optional[QuantizeMethodBase]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.quantization import get_linear_quant_method
 
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        if not VLLM_AVAILABLE:
+            return False
+
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
         sym = quant_config.get("sym")
         desc_act = quant_config.get("desc_act")
 
-        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-            check_marlin_supported,
-        )
-
         if not _is_cuda:
             return False
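Note the asymmetry with get_quant_method: a compatibility probe is called speculatively while sglang is still deciding which quantization format a checkpoint uses, so when vllm is missing it now returns False up front instead of failing on an unguarded import deep in the body. A minimal sketch of that contract (names hypothetical):

    OPTIONAL_AVAILABLE = False  # stand-in for the module-level guard flag

    def is_backend_compatible(cfg: dict) -> bool:
        # Probes run against every candidate config format, so a missing
        # optional backend must mean "not compatible", never an exception.
        if not OPTIONAL_AVAILABLE:
            return False
        return cfg.get("quant_method", "").lower() == "gptq"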
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["MarlinLinearMethod"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-
-        # Delay import to avoid circular dependency
+    ) -> Optional[MarlinLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 
         if isinstance(layer, LinearBase) or (
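The same cleanup, applied a third time for MarlinConfig. Across all three methods the quoted annotations (Optional["MarlinLinearMethod"] and friends) become plain ones, which stays safe without vllm because the guard block binds each name to typing.Any as a fallback; a compact sketch of why the quotes can go:

    from typing import Any, Optional

    try:
        from some_optional_pkg import FastLinear  # hypothetical, as above
    except ImportError:
        FastLinear = Any  # the name is always bound, so annotations resolve

    # No string quoting needed: FastLinear exists either way.
    def get_method() -> Optional[FastLinear]:
        return None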