Clean up import vllm in quantization/__init__.py (#4834)
This commit is contained in:
@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
|
||||
|
||||
[project.optional-dependencies]
|
||||
runtime_common = [
|
||||
"compressed-tensors",
|
||||
"datasets",
|
||||
"decord",
|
||||
"fastapi",
|
||||
@@ -56,7 +57,12 @@ srt = [
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
||||
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
|
||||
srt_hip = [
|
||||
"sglang[runtime_common]",
|
||||
"torch",
|
||||
"vllm==0.6.7.dev2",
|
||||
"outlines==0.1.11"
|
||||
]
|
||||
|
||||
# xpu is not enabled in public vllm and torch whl,
|
||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||
|
||||
@@ -22,11 +22,7 @@ import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||
from sglang.srt.layers.quantization import (
|
||||
BASE_QUANTIZATION_METHODS,
|
||||
QUANTIZATION_METHODS,
|
||||
VLLM_AVAILABLE,
|
||||
)
|
||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -239,12 +235,7 @@ class ModelConfig:
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _verify_quantization(self) -> None:
|
||||
# Select supported quantization methods based on vllm availability
|
||||
if VLLM_AVAILABLE:
|
||||
supported_quantization = [*QUANTIZATION_METHODS]
|
||||
else:
|
||||
supported_quantization = [*BASE_QUANTIZATION_METHODS]
|
||||
|
||||
supported_quantization = [*QUANTIZATION_METHODS]
|
||||
rocm_supported_quantization = [
|
||||
"awq",
|
||||
"gptq",
|
||||
@@ -282,11 +273,7 @@ class ModelConfig:
|
||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Detect which checkpoint it is
|
||||
# Only iterate through currently available quantization methods
|
||||
available_methods = (
|
||||
QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
|
||||
)
|
||||
for _, method in available_methods.items():
|
||||
for _, method in QUANTIZATION_METHODS.items():
|
||||
quantization_override = method.override_quantization_method(
|
||||
quant_cfg, self.quantization
|
||||
)
|
||||
|
||||
@@ -17,12 +17,12 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
|
||||
@@ -9,12 +9,24 @@ import torch
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig,
|
||||
AWQMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
CompressedTensorsW8A8Fp8MoEMethod,
|
||||
CompressedTensorsWNA16MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQMarlin24Config,
|
||||
)
|
||||
@@ -22,24 +34,24 @@ try:
|
||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# Define empty classes as placeholders when vllm is not available
|
||||
class DummyConfig:
|
||||
pass
|
||||
def override_quantization_method(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
|
||||
DummyConfig
|
||||
)
|
||||
DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
|
||||
GPTQMarlin24Config
|
||||
) = DummyConfig
|
||||
MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
|
||||
DeepSpeedFPConfig
|
||||
) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
|
||||
MarlinConfig
|
||||
) = QQQConfig = Int8TpuConfig = DummyConfig
|
||||
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||
@@ -47,9 +59,14 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
# Base quantization methods that don't depend on vllm
|
||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
@@ -61,26 +78,25 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
}
|
||||
|
||||
# Add vllm-dependent methods if available
|
||||
QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
|
||||
if VLLM_AVAILABLE:
|
||||
VLLM_QUANTIZATION_METHODS = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"marlin": MarlinConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"qqq": QQQConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
}
|
||||
QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
|
||||
# VLLM-dependent quantization methods
|
||||
VLLM_QUANTIZATION_METHODS = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"marlin": MarlinConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"qqq": QQQConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
}
|
||||
|
||||
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
f"Invalid quantization method: {quantization}. "
|
||||
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
|
||||
)
|
||||
if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
|
||||
raise ValueError(
|
||||
f"{quantization} quantization requires some operators from vllm. "
|
||||
"Pleaes install vllm by `pip install vllm==0.7.2`"
|
||||
)
|
||||
|
||||
return QUANTIZATION_METHODS[quantization]
|
||||
|
||||
|
||||
@@ -153,13 +175,6 @@ def get_linear_quant_method(
|
||||
prefix: str,
|
||||
linear_method_cls: type,
|
||||
):
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
cloned_config = deepcopy(config)
|
||||
parallel_lm_head_quantized = (
|
||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||
@@ -186,31 +201,17 @@ def get_linear_quant_method(
|
||||
|
||||
|
||||
def gptq_get_quant_method(self, layer, prefix):
|
||||
if not VLLM_AVAILABLE:
|
||||
return None
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
if isinstance(self, GPTQConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
||||
)
|
||||
elif isinstance(self, GPTQMarlinConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
||||
)
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
|
||||
if isinstance(self, GPTQConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
||||
)
|
||||
elif isinstance(self, GPTQMarlinConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
@@ -229,33 +230,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
|
||||
builtins.isinstance = original_isinstance
|
||||
return
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||
FusedMoE as PatchedFusedMoE,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
||||
)
|
||||
|
||||
def patched_isinstance(obj, classinfo):
|
||||
if classinfo is LinearBase:
|
||||
return original_isinstance(obj, PatchedLinearBase)
|
||||
if classinfo is FusedMoE:
|
||||
return original_isinstance(obj, PatchedFusedMoE)
|
||||
if classinfo is VocabParallelEmbedding:
|
||||
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
||||
return original_isinstance(obj, classinfo)
|
||||
def patched_isinstance(obj, classinfo):
|
||||
if classinfo is LinearBase:
|
||||
return original_isinstance(obj, PatchedLinearBase)
|
||||
if classinfo is FusedMoE:
|
||||
return original_isinstance(obj, PatchedFusedMoE)
|
||||
if classinfo is VocabParallelEmbedding:
|
||||
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
||||
return original_isinstance(obj, classinfo)
|
||||
|
||||
builtins.isinstance = patched_isinstance
|
||||
except ImportError:
|
||||
return
|
||||
builtins.isinstance = patched_isinstance
|
||||
|
||||
|
||||
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
@@ -263,91 +259,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
Monkey patch the apply function of vllm's FusedMoEMethodBase.
|
||||
Convert sglang arguments to vllm arguments.
|
||||
"""
|
||||
if not VLLM_AVAILABLE:
|
||||
return
|
||||
original_apply = class_obj.apply
|
||||
sig = inspect.signature(original_apply)
|
||||
param_names = list(sig.parameters.keys())
|
||||
has_correction_bias = "e_score_correction_bias" in param_names
|
||||
|
||||
try:
|
||||
original_apply = class_obj.apply
|
||||
sig = inspect.signature(original_apply)
|
||||
param_names = list(sig.parameters.keys())
|
||||
has_correction_bias = "e_score_correction_bias" in param_names
|
||||
def new_apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
assert activation == "silu"
|
||||
assert inplace and not no_combine
|
||||
|
||||
def new_apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
assert activation == "silu"
|
||||
assert inplace and not no_combine
|
||||
kwargs = {
|
||||
"self": self,
|
||||
"layer": layer,
|
||||
"x": x,
|
||||
"router_logits": router_logits,
|
||||
"top_k": top_k,
|
||||
"renormalize": renormalize,
|
||||
"use_grouped_topk": use_grouped_topk,
|
||||
"topk_group": topk_group,
|
||||
"num_expert_group": num_expert_group,
|
||||
"custom_routing_function": custom_routing_function,
|
||||
}
|
||||
if correction_bias is not None:
|
||||
if not has_correction_bias:
|
||||
raise ValueError(
|
||||
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
||||
)
|
||||
kwargs["e_score_correction_bias"] = correction_bias
|
||||
return original_apply(**kwargs)
|
||||
|
||||
kwargs = {
|
||||
"self": self,
|
||||
"layer": layer,
|
||||
"x": x,
|
||||
"router_logits": router_logits,
|
||||
"top_k": top_k,
|
||||
"renormalize": renormalize,
|
||||
"use_grouped_topk": use_grouped_topk,
|
||||
"topk_group": topk_group,
|
||||
"num_expert_group": num_expert_group,
|
||||
"custom_routing_function": custom_routing_function,
|
||||
}
|
||||
if correction_bias is not None:
|
||||
if not has_correction_bias:
|
||||
raise ValueError(
|
||||
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
||||
)
|
||||
kwargs["e_score_correction_bias"] = correction_bias
|
||||
return original_apply(**kwargs)
|
||||
|
||||
setattr(class_obj, "apply", new_apply)
|
||||
except (ImportError, AttributeError):
|
||||
return
|
||||
setattr(class_obj, "apply", new_apply)
|
||||
|
||||
|
||||
def monkey_patch_quant_configs():
|
||||
"""Apply all monkey patches in one place."""
|
||||
if not VLLM_AVAILABLE:
|
||||
return
|
||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
CompressedTensorsW8A8Fp8MoEMethod,
|
||||
CompressedTensorsWNA16MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
||||
|
||||
monkey_patch_moe_apply(AWQMoEMethod)
|
||||
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
||||
except ImportError:
|
||||
return
|
||||
monkey_patch_moe_apply(AWQMoEMethod)
|
||||
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
||||
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
||||
|
||||
|
||||
# Only apply monkey patches if vllm is available
|
||||
if VLLM_AVAILABLE:
|
||||
monkey_patch_quant_configs()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import awq_dequantize
|
||||
|
||||
@@ -24,6 +24,7 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_bool_env_var,
|
||||
get_device_core_count,
|
||||
get_device_name,
|
||||
get_device_sm,
|
||||
@@ -43,7 +44,7 @@ if _is_cuda:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
sm_version = get_device_sm()
|
||||
if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
|
||||
if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
||||
_enable_jit_deepgemm = True
|
||||
|
||||
|
||||
|
||||
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
try:
|
||||
import vllm
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
|
||||
|
||||
class scalar_types:
|
||||
uint4b8 = "uint4b8"
|
||||
uint8b128 = "uint8b128"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["GPTQLinearMethod"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
|
||||
) -> Optional[GPTQLinearMethod]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
if VLLM_AVAILABLE:
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
else:
|
||||
raise ImportError("vllm is not installed")
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
|
||||
)
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
if not VLLM_AVAILABLE:
|
||||
return False
|
||||
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
)
|
||||
|
||||
if not _is_cuda:
|
||||
return False
|
||||
|
||||
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["MarlinLinearMethod"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
|
||||
# Delay import to avoid circular dependency
|
||||
) -> Optional[MarlinLinearMethod]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
if isinstance(layer, LinearBase) or (
|
||||
|
||||
@@ -53,8 +53,6 @@ class TpModelWorker:
|
||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||
):
|
||||
self.worker = self
|
||||
|
||||
# Parse args
|
||||
self.tp_rank = tp_rank
|
||||
|
||||
@@ -134,6 +132,9 @@ class TpModelWorker:
|
||||
)[0]
|
||||
set_random_seed(self.random_seed)
|
||||
|
||||
# A self-reference so that this class has the same member as TpModelWorkerClient
|
||||
self.worker = self
|
||||
|
||||
def get_worker_info(self):
|
||||
return (
|
||||
self.max_total_num_tokens,
|
||||
|
||||
@@ -73,7 +73,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
Reference in New Issue
Block a user