Clean up import vllm in quantization/__init__.py (#4834)
This commit is contained in:
@@ -9,12 +9,24 @@ import torch
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig,
|
||||
AWQMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
CompressedTensorsW8A8Fp8MoEMethod,
|
||||
CompressedTensorsWNA16MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQMarlin24Config,
|
||||
)
|
||||
@@ -22,24 +34,24 @@ try:
|
||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# Define empty classes as placeholders when vllm is not available
|
||||
class DummyConfig:
|
||||
pass
|
||||
def override_quantization_method(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
|
||||
DummyConfig
|
||||
)
|
||||
DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
|
||||
GPTQMarlin24Config
|
||||
) = DummyConfig
|
||||
MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
|
||||
DeepSpeedFPConfig
|
||||
) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
|
||||
MarlinConfig
|
||||
) = QQQConfig = Int8TpuConfig = DummyConfig
|
||||
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||
@@ -47,9 +59,14 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
# Base quantization methods that don't depend on vllm
|
||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
@@ -61,26 +78,25 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
}
|
||||
|
||||
# Add vllm-dependent methods if available
|
||||
QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
|
||||
if VLLM_AVAILABLE:
|
||||
VLLM_QUANTIZATION_METHODS = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"marlin": MarlinConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"qqq": QQQConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
}
|
||||
QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
|
||||
# VLLM-dependent quantization methods
|
||||
VLLM_QUANTIZATION_METHODS = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"marlin": MarlinConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"qqq": QQQConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
}
|
||||
|
||||
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
f"Invalid quantization method: {quantization}. "
|
||||
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
|
||||
)
|
||||
if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
|
||||
raise ValueError(
|
||||
f"{quantization} quantization requires some operators from vllm. "
|
||||
"Pleaes install vllm by `pip install vllm==0.7.2`"
|
||||
)
|
||||
|
||||
return QUANTIZATION_METHODS[quantization]
|
||||
|
||||
|
||||
@@ -153,13 +175,6 @@ def get_linear_quant_method(
|
||||
prefix: str,
|
||||
linear_method_cls: type,
|
||||
):
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
cloned_config = deepcopy(config)
|
||||
parallel_lm_head_quantized = (
|
||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||
@@ -186,31 +201,17 @@ def get_linear_quant_method(
|
||||
|
||||
|
||||
def gptq_get_quant_method(self, layer, prefix):
|
||||
if not VLLM_AVAILABLE:
|
||||
return None
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
if isinstance(self, GPTQConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
||||
)
|
||||
elif isinstance(self, GPTQMarlinConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
||||
)
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
|
||||
if isinstance(self, GPTQConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
||||
)
|
||||
elif isinstance(self, GPTQMarlinConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
@@ -229,33 +230,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
|
||||
builtins.isinstance = original_isinstance
|
||||
return
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||
FusedMoE as PatchedFusedMoE,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
||||
)
|
||||
|
||||
def patched_isinstance(obj, classinfo):
|
||||
if classinfo is LinearBase:
|
||||
return original_isinstance(obj, PatchedLinearBase)
|
||||
if classinfo is FusedMoE:
|
||||
return original_isinstance(obj, PatchedFusedMoE)
|
||||
if classinfo is VocabParallelEmbedding:
|
||||
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
||||
return original_isinstance(obj, classinfo)
|
||||
def patched_isinstance(obj, classinfo):
|
||||
if classinfo is LinearBase:
|
||||
return original_isinstance(obj, PatchedLinearBase)
|
||||
if classinfo is FusedMoE:
|
||||
return original_isinstance(obj, PatchedFusedMoE)
|
||||
if classinfo is VocabParallelEmbedding:
|
||||
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
||||
return original_isinstance(obj, classinfo)
|
||||
|
||||
builtins.isinstance = patched_isinstance
|
||||
except ImportError:
|
||||
return
|
||||
builtins.isinstance = patched_isinstance
|
||||
|
||||
|
||||
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
@@ -263,91 +259,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
Monkey patch the apply function of vllm's FusedMoEMethodBase.
|
||||
Convert sglang arguments to vllm arguments.
|
||||
"""
|
||||
if not VLLM_AVAILABLE:
|
||||
return
|
||||
original_apply = class_obj.apply
|
||||
sig = inspect.signature(original_apply)
|
||||
param_names = list(sig.parameters.keys())
|
||||
has_correction_bias = "e_score_correction_bias" in param_names
|
||||
|
||||
try:
|
||||
original_apply = class_obj.apply
|
||||
sig = inspect.signature(original_apply)
|
||||
param_names = list(sig.parameters.keys())
|
||||
has_correction_bias = "e_score_correction_bias" in param_names
|
||||
def new_apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
assert activation == "silu"
|
||||
assert inplace and not no_combine
|
||||
|
||||
def new_apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
assert activation == "silu"
|
||||
assert inplace and not no_combine
|
||||
kwargs = {
|
||||
"self": self,
|
||||
"layer": layer,
|
||||
"x": x,
|
||||
"router_logits": router_logits,
|
||||
"top_k": top_k,
|
||||
"renormalize": renormalize,
|
||||
"use_grouped_topk": use_grouped_topk,
|
||||
"topk_group": topk_group,
|
||||
"num_expert_group": num_expert_group,
|
||||
"custom_routing_function": custom_routing_function,
|
||||
}
|
||||
if correction_bias is not None:
|
||||
if not has_correction_bias:
|
||||
raise ValueError(
|
||||
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
||||
)
|
||||
kwargs["e_score_correction_bias"] = correction_bias
|
||||
return original_apply(**kwargs)
|
||||
|
||||
kwargs = {
|
||||
"self": self,
|
||||
"layer": layer,
|
||||
"x": x,
|
||||
"router_logits": router_logits,
|
||||
"top_k": top_k,
|
||||
"renormalize": renormalize,
|
||||
"use_grouped_topk": use_grouped_topk,
|
||||
"topk_group": topk_group,
|
||||
"num_expert_group": num_expert_group,
|
||||
"custom_routing_function": custom_routing_function,
|
||||
}
|
||||
if correction_bias is not None:
|
||||
if not has_correction_bias:
|
||||
raise ValueError(
|
||||
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
||||
)
|
||||
kwargs["e_score_correction_bias"] = correction_bias
|
||||
return original_apply(**kwargs)
|
||||
|
||||
setattr(class_obj, "apply", new_apply)
|
||||
except (ImportError, AttributeError):
|
||||
return
|
||||
setattr(class_obj, "apply", new_apply)
|
||||
|
||||
|
||||
def monkey_patch_quant_configs():
|
||||
"""Apply all monkey patches in one place."""
|
||||
if not VLLM_AVAILABLE:
|
||||
return
|
||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
CompressedTensorsW8A8Fp8MoEMethod,
|
||||
CompressedTensorsWNA16MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
||||
|
||||
monkey_patch_moe_apply(AWQMoEMethod)
|
||||
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
||||
except ImportError:
|
||||
return
|
||||
monkey_patch_moe_apply(AWQMoEMethod)
|
||||
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
||||
|
||||
|
||||
# Only apply monkey patches if vllm is available
|
||||
if VLLM_AVAILABLE:
|
||||
monkey_patch_quant_configs()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import awq_dequantize
|
||||
|
||||
@@ -24,6 +24,7 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_bool_env_var,
|
||||
get_device_core_count,
|
||||
get_device_name,
|
||||
get_device_sm,
|
||||
@@ -43,7 +44,7 @@ if _is_cuda:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
sm_version = get_device_sm()
|
||||
if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
|
||||
if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
||||
_enable_jit_deepgemm = True
|
||||
|
||||
|
||||
|
||||
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
try:
|
||||
import vllm
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
|
||||
|
||||
class scalar_types:
|
||||
uint4b8 = "uint4b8"
|
||||
uint8b128 = "uint8b128"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["GPTQLinearMethod"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
|
||||
) -> Optional[GPTQLinearMethod]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
if VLLM_AVAILABLE:
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
else:
|
||||
raise ImportError("vllm is not installed")
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
|
||||
)
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
if not VLLM_AVAILABLE:
|
||||
return False
|
||||
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
)
|
||||
|
||||
if not _is_cuda:
|
||||
return False
|
||||
|
||||
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["MarlinLinearMethod"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
|
||||
# Delay import to avoid circular dependency
|
||||
) -> Optional[MarlinLinearMethod]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
if isinstance(layer, LinearBase) or (
|
||||
|
||||
Reference in New Issue
Block a user