Clean up import vllm in quantization/__init__.py (#4834)
This commit is contained in:
12
.github/workflows/pr-test.yml
vendored
12
.github/workflows/pr-test.yml
vendored
@@ -4,19 +4,15 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- "python/pyproject.toml"
|
- "python/**"
|
||||||
- "python/sglang/**"
|
|
||||||
- "test/**"
|
|
||||||
- "docs/**"
|
|
||||||
- "scripts/**"
|
- "scripts/**"
|
||||||
|
- "test/**"
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- "python/pyproject.toml"
|
- "python/**"
|
||||||
- "python/sglang/**"
|
|
||||||
- "test/**"
|
|
||||||
- "docs/**"
|
|
||||||
- "scripts/**"
|
- "scripts/**"
|
||||||
|
- "test/**"
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
version:
|
version:
|
||||||
|
|||||||
12
.github/workflows/vllm-dependency-test.yml
vendored
12
.github/workflows/vllm-dependency-test.yml
vendored
@@ -4,19 +4,15 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- "python/pyproject.toml"
|
- "python/**"
|
||||||
- "python/sglang/**"
|
|
||||||
- "test/**"
|
|
||||||
- "docs/**"
|
|
||||||
- "scripts/**"
|
- "scripts/**"
|
||||||
|
- "test/**"
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- "python/pyproject.toml"
|
- "python/**"
|
||||||
- "python/sglang/**"
|
|
||||||
- "test/**"
|
|
||||||
- "docs/**"
|
|
||||||
- "scripts/**"
|
- "scripts/**"
|
||||||
|
- "test/**"
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: vllm-dependency-test-${{ github.ref }}
|
group: vllm-dependency-test-${{ github.ref }}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
runtime_common = [
|
runtime_common = [
|
||||||
|
"compressed-tensors",
|
||||||
"datasets",
|
"datasets",
|
||||||
"decord",
|
"decord",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
@@ -56,7 +57,12 @@ srt = [
|
|||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
||||||
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
|
srt_hip = [
|
||||||
|
"sglang[runtime_common]",
|
||||||
|
"torch",
|
||||||
|
"vllm==0.6.7.dev2",
|
||||||
|
"outlines==0.1.11"
|
||||||
|
]
|
||||||
|
|
||||||
# xpu is not enabled in public vllm and torch whl,
|
# xpu is not enabled in public vllm and torch whl,
|
||||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||||
|
|||||||
@@ -22,11 +22,7 @@ import torch
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||||
from sglang.srt.layers.quantization import (
|
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||||
BASE_QUANTIZATION_METHODS,
|
|
||||||
QUANTIZATION_METHODS,
|
|
||||||
VLLM_AVAILABLE,
|
|
||||||
)
|
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -239,12 +235,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||||
def _verify_quantization(self) -> None:
|
def _verify_quantization(self) -> None:
|
||||||
# Select supported quantization methods based on vllm availability
|
supported_quantization = [*QUANTIZATION_METHODS]
|
||||||
if VLLM_AVAILABLE:
|
|
||||||
supported_quantization = [*QUANTIZATION_METHODS]
|
|
||||||
else:
|
|
||||||
supported_quantization = [*BASE_QUANTIZATION_METHODS]
|
|
||||||
|
|
||||||
rocm_supported_quantization = [
|
rocm_supported_quantization = [
|
||||||
"awq",
|
"awq",
|
||||||
"gptq",
|
"gptq",
|
||||||
@@ -282,11 +273,7 @@ class ModelConfig:
|
|||||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||||
|
|
||||||
# Detect which checkpoint it is
|
# Detect which checkpoint it is
|
||||||
# Only iterate through currently available quantization methods
|
for _, method in QUANTIZATION_METHODS.items():
|
||||||
available_methods = (
|
|
||||||
QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
|
|
||||||
)
|
|
||||||
for _, method in available_methods.items():
|
|
||||||
quantization_override = method.override_quantization_method(
|
quantization_override = method.override_quantization_method(
|
||||||
quant_cfg, self.quantization
|
quant_cfg, self.quantization
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ from typing import Callable, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||||
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
|
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
|
||||||
|
|
||||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -9,12 +9,24 @@ import torch
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||||
|
AWQMarlinConfig,
|
||||||
|
AWQMoEMethod,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||||
|
CompressedTensorsW8A8Fp8MoEMethod,
|
||||||
|
CompressedTensorsWNA16MoEMethod,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||||
|
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||||
|
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||||
|
GPTQMarlinLinearMethod,
|
||||||
|
GPTQMarlinMoEMethod,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||||
GPTQMarlin24Config,
|
GPTQMarlin24Config,
|
||||||
)
|
)
|
||||||
@@ -22,24 +34,24 @@ try:
|
|||||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
|
||||||
|
|
||||||
VLLM_AVAILABLE = True
|
VLLM_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
VLLM_AVAILABLE = False
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
# Define empty classes as placeholders when vllm is not available
|
# Define empty classes as placeholders when vllm is not available
|
||||||
class DummyConfig:
|
class DummyConfig:
|
||||||
pass
|
def override_quantization_method(self, *args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
|
AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
|
||||||
DummyConfig
|
DeepSpeedFPConfig
|
||||||
)
|
) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
|
||||||
DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
|
MarlinConfig
|
||||||
GPTQMarlin24Config
|
) = QQQConfig = Int8TpuConfig = DummyConfig
|
||||||
) = DummyConfig
|
|
||||||
MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
|
|
||||||
|
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||||
@@ -47,9 +59,14 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
|
|||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||||
|
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
UnquantizedEmbeddingMethod,
|
||||||
|
)
|
||||||
|
|
||||||
# Base quantization methods that don't depend on vllm
|
# Base quantization methods that don't depend on vllm
|
||||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||||
@@ -61,26 +78,25 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"compressed-tensors": CompressedTensorsConfig,
|
"compressed-tensors": CompressedTensorsConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add vllm-dependent methods if available
|
# VLLM-dependent quantization methods
|
||||||
QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
|
VLLM_QUANTIZATION_METHODS = {
|
||||||
if VLLM_AVAILABLE:
|
"aqlm": AQLMConfig,
|
||||||
VLLM_QUANTIZATION_METHODS = {
|
"awq": AWQConfig,
|
||||||
"aqlm": AQLMConfig,
|
"deepspeedfp": DeepSpeedFPConfig,
|
||||||
"awq": AWQConfig,
|
"tpu_int8": Int8TpuConfig,
|
||||||
"deepspeedfp": DeepSpeedFPConfig,
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
"tpu_int8": Int8TpuConfig,
|
"marlin": MarlinConfig,
|
||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
"gguf": GGUFConfig,
|
||||||
"marlin": MarlinConfig,
|
"gptq_marlin_24": GPTQMarlin24Config,
|
||||||
"gguf": GGUFConfig,
|
"awq_marlin": AWQMarlinConfig,
|
||||||
"gptq_marlin_24": GPTQMarlin24Config,
|
"bitsandbytes": BitsAndBytesConfig,
|
||||||
"awq_marlin": AWQMarlinConfig,
|
"qqq": QQQConfig,
|
||||||
"bitsandbytes": BitsAndBytesConfig,
|
"experts_int8": ExpertsInt8Config,
|
||||||
"qqq": QQQConfig,
|
"gptq_marlin": GPTQMarlinConfig,
|
||||||
"experts_int8": ExpertsInt8Config,
|
"gptq": GPTQConfig,
|
||||||
"gptq_marlin": GPTQMarlinConfig,
|
}
|
||||||
"gptq": GPTQConfig,
|
|
||||||
}
|
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
|
||||||
QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
|
|
||||||
|
|
||||||
|
|
||||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||||
@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|||||||
f"Invalid quantization method: {quantization}. "
|
f"Invalid quantization method: {quantization}. "
|
||||||
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
|
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
|
||||||
)
|
)
|
||||||
|
if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
|
||||||
|
raise ValueError(
|
||||||
|
f"{quantization} quantization requires some operators from vllm. "
|
||||||
|
"Pleaes install vllm by `pip install vllm==0.7.2`"
|
||||||
|
)
|
||||||
|
|
||||||
return QUANTIZATION_METHODS[quantization]
|
return QUANTIZATION_METHODS[quantization]
|
||||||
|
|
||||||
|
|
||||||
@@ -153,13 +175,6 @@ def get_linear_quant_method(
|
|||||||
prefix: str,
|
prefix: str,
|
||||||
linear_method_cls: type,
|
linear_method_cls: type,
|
||||||
):
|
):
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
|
||||||
ParallelLMHead,
|
|
||||||
UnquantizedEmbeddingMethod,
|
|
||||||
)
|
|
||||||
|
|
||||||
cloned_config = deepcopy(config)
|
cloned_config = deepcopy(config)
|
||||||
parallel_lm_head_quantized = (
|
parallel_lm_head_quantized = (
|
||||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||||
@@ -186,31 +201,17 @@ def get_linear_quant_method(
|
|||||||
|
|
||||||
|
|
||||||
def gptq_get_quant_method(self, layer, prefix):
|
def gptq_get_quant_method(self, layer, prefix):
|
||||||
if not VLLM_AVAILABLE:
|
if isinstance(layer, FusedMoE):
|
||||||
return None
|
return GPTQMarlinMoEMethod(self)
|
||||||
|
|
||||||
try:
|
if isinstance(self, GPTQConfig):
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
return get_linear_quant_method(
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
||||||
GPTQMarlinLinearMethod,
|
)
|
||||||
GPTQMarlinMoEMethod,
|
elif isinstance(self, GPTQMarlinConfig):
|
||||||
|
return get_linear_quant_method(
|
||||||
|
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
|
||||||
|
|
||||||
if isinstance(layer, FusedMoE):
|
|
||||||
return GPTQMarlinMoEMethod(self)
|
|
||||||
|
|
||||||
if isinstance(self, GPTQConfig):
|
|
||||||
return get_linear_quant_method(
|
|
||||||
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
|
||||||
)
|
|
||||||
elif isinstance(self, GPTQMarlinConfig):
|
|
||||||
return get_linear_quant_method(
|
|
||||||
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -229,33 +230,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
|
|||||||
builtins.isinstance = original_isinstance
|
builtins.isinstance = original_isinstance
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
VocabParallelEmbedding,
|
||||||
VocabParallelEmbedding,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
|
||||||
FusedMoE as PatchedFusedMoE,
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
)
|
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
)
|
||||||
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
def patched_isinstance(obj, classinfo):
|
def patched_isinstance(obj, classinfo):
|
||||||
if classinfo is LinearBase:
|
if classinfo is LinearBase:
|
||||||
return original_isinstance(obj, PatchedLinearBase)
|
return original_isinstance(obj, PatchedLinearBase)
|
||||||
if classinfo is FusedMoE:
|
if classinfo is FusedMoE:
|
||||||
return original_isinstance(obj, PatchedFusedMoE)
|
return original_isinstance(obj, PatchedFusedMoE)
|
||||||
if classinfo is VocabParallelEmbedding:
|
if classinfo is VocabParallelEmbedding:
|
||||||
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
||||||
return original_isinstance(obj, classinfo)
|
return original_isinstance(obj, classinfo)
|
||||||
|
|
||||||
builtins.isinstance = patched_isinstance
|
builtins.isinstance = patched_isinstance
|
||||||
except ImportError:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||||
@@ -263,91 +259,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
|||||||
Monkey patch the apply function of vllm's FusedMoEMethodBase.
|
Monkey patch the apply function of vllm's FusedMoEMethodBase.
|
||||||
Convert sglang arguments to vllm arguments.
|
Convert sglang arguments to vllm arguments.
|
||||||
"""
|
"""
|
||||||
if not VLLM_AVAILABLE:
|
original_apply = class_obj.apply
|
||||||
return
|
sig = inspect.signature(original_apply)
|
||||||
|
param_names = list(sig.parameters.keys())
|
||||||
|
has_correction_bias = "e_score_correction_bias" in param_names
|
||||||
|
|
||||||
try:
|
def new_apply(
|
||||||
original_apply = class_obj.apply
|
self,
|
||||||
sig = inspect.signature(original_apply)
|
layer: torch.nn.Module,
|
||||||
param_names = list(sig.parameters.keys())
|
x: torch.Tensor,
|
||||||
has_correction_bias = "e_score_correction_bias" in param_names
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
):
|
||||||
|
assert activation == "silu"
|
||||||
|
assert inplace and not no_combine
|
||||||
|
|
||||||
def new_apply(
|
kwargs = {
|
||||||
self,
|
"self": self,
|
||||||
layer: torch.nn.Module,
|
"layer": layer,
|
||||||
x: torch.Tensor,
|
"x": x,
|
||||||
router_logits: torch.Tensor,
|
"router_logits": router_logits,
|
||||||
top_k: int,
|
"top_k": top_k,
|
||||||
renormalize: bool,
|
"renormalize": renormalize,
|
||||||
use_grouped_topk: bool,
|
"use_grouped_topk": use_grouped_topk,
|
||||||
topk_group: Optional[int] = None,
|
"topk_group": topk_group,
|
||||||
num_expert_group: Optional[int] = None,
|
"num_expert_group": num_expert_group,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
"custom_routing_function": custom_routing_function,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
}
|
||||||
activation: str = "silu",
|
if correction_bias is not None:
|
||||||
inplace: bool = True,
|
if not has_correction_bias:
|
||||||
no_combine: bool = False,
|
raise ValueError(
|
||||||
):
|
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
||||||
assert activation == "silu"
|
)
|
||||||
assert inplace and not no_combine
|
kwargs["e_score_correction_bias"] = correction_bias
|
||||||
|
return original_apply(**kwargs)
|
||||||
|
|
||||||
kwargs = {
|
setattr(class_obj, "apply", new_apply)
|
||||||
"self": self,
|
|
||||||
"layer": layer,
|
|
||||||
"x": x,
|
|
||||||
"router_logits": router_logits,
|
|
||||||
"top_k": top_k,
|
|
||||||
"renormalize": renormalize,
|
|
||||||
"use_grouped_topk": use_grouped_topk,
|
|
||||||
"topk_group": topk_group,
|
|
||||||
"num_expert_group": num_expert_group,
|
|
||||||
"custom_routing_function": custom_routing_function,
|
|
||||||
}
|
|
||||||
if correction_bias is not None:
|
|
||||||
if not has_correction_bias:
|
|
||||||
raise ValueError(
|
|
||||||
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
|
||||||
)
|
|
||||||
kwargs["e_score_correction_bias"] = correction_bias
|
|
||||||
return original_apply(**kwargs)
|
|
||||||
|
|
||||||
setattr(class_obj, "apply", new_apply)
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_quant_configs():
|
def monkey_patch_quant_configs():
|
||||||
"""Apply all monkey patches in one place."""
|
"""Apply all monkey patches in one place."""
|
||||||
if not VLLM_AVAILABLE:
|
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||||
return
|
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
||||||
|
|
||||||
try:
|
monkey_patch_moe_apply(AWQMoEMethod)
|
||||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
||||||
CompressedTensorsW8A8Fp8MoEMethod,
|
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
||||||
CompressedTensorsWNA16MoEMethod,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
|
||||||
GPTQMarlinMoEMethod,
|
|
||||||
)
|
|
||||||
|
|
||||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
|
||||||
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
|
||||||
|
|
||||||
monkey_patch_moe_apply(AWQMoEMethod)
|
|
||||||
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
|
||||||
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
|
||||||
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
|
||||||
except ImportError:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
# Only apply monkey patches if vllm is available
|
# Only apply monkey patches if vllm is available
|
||||||
if VLLM_AVAILABLE:
|
if VLLM_AVAILABLE:
|
||||||
monkey_patch_quant_configs()
|
monkey_patch_quant_configs()
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"get_quantization_config",
|
|
||||||
"QUANTIZATION_METHODS",
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import awq_dequantize
|
from sgl_kernel import awq_dequantize
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
|
get_bool_env_var,
|
||||||
get_device_core_count,
|
get_device_core_count,
|
||||||
get_device_name,
|
get_device_name,
|
||||||
get_device_sm,
|
get_device_sm,
|
||||||
@@ -43,7 +44,7 @@ if _is_cuda:
|
|||||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||||
|
|
||||||
sm_version = get_device_sm()
|
sm_version = get_device_sm()
|
||||||
if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
|
if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
||||||
_enable_jit_deepgemm = True
|
_enable_jit_deepgemm = True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
|
|||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import vllm
|
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||||
|
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||||
|
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||||
|
GPTQMarlinLinearMethod,
|
||||||
|
GPTQMarlinMoEMethod,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
|
check_marlin_supported,
|
||||||
|
)
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
VLLM_AVAILABLE = True
|
VLLM_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
VLLM_AVAILABLE = False
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
|
GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
|
||||||
|
|
||||||
|
class scalar_types:
|
||||||
|
uint4b8 = "uint4b8"
|
||||||
|
uint8b128 = "uint8b128"
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["GPTQLinearMethod"]:
|
) -> Optional[GPTQLinearMethod]:
|
||||||
if not VLLM_AVAILABLE:
|
# Delay the import to avoid circular dependency
|
||||||
raise ImportError("vllm is not installed")
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
|
||||||
|
|
||||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||||
|
|
||||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||||
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
class GPTQMarlinConfig(QuantizationConfig):
|
class GPTQMarlinConfig(QuantizationConfig):
|
||||||
"""Config class for GPTQ Marlin"""
|
"""Config class for GPTQ Marlin"""
|
||||||
|
|
||||||
if VLLM_AVAILABLE:
|
# (num_bits, is_sym) -> quant_type
|
||||||
from vllm.scalar_type import scalar_types
|
TYPE_MAP = {
|
||||||
|
(4, True): scalar_types.uint4b8,
|
||||||
# (num_bits, is_sym) -> quant_type
|
(8, True): scalar_types.uint8b128,
|
||||||
TYPE_MAP = {
|
}
|
||||||
(4, True): scalar_types.uint4b8,
|
|
||||||
(8, True): scalar_types.uint8b128,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ImportError("vllm is not installed")
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
|
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# (num_bits, is_sym) -> quant_type
|
||||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
if not VLLM_AVAILABLE:
|
# Delay the import to avoid circular dependency
|
||||||
raise ImportError("vllm is not installed")
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
|
||||||
GPTQMarlinLinearMethod,
|
|
||||||
GPTQMarlinMoEMethod,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||||
|
|
||||||
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||||
if not VLLM_AVAILABLE:
|
|
||||||
return False
|
|
||||||
|
|
||||||
quant_method = quant_config.get("quant_method", "").lower()
|
quant_method = quant_config.get("quant_method", "").lower()
|
||||||
num_bits = quant_config.get("bits")
|
num_bits = quant_config.get("bits")
|
||||||
group_size = quant_config.get("group_size")
|
group_size = quant_config.get("group_size")
|
||||||
sym = quant_config.get("sym")
|
sym = quant_config.get("sym")
|
||||||
desc_act = quant_config.get("desc_act")
|
desc_act = quant_config.get("desc_act")
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
|
||||||
check_marlin_supported,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not _is_cuda:
|
if not _is_cuda:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["MarlinLinearMethod"]:
|
) -> Optional[MarlinLinearMethod]:
|
||||||
if not VLLM_AVAILABLE:
|
# Delay the import to avoid circular dependency
|
||||||
raise ImportError("vllm is not installed")
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
|
||||||
|
|
||||||
# Delay import to avoid circular dependency
|
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
|
||||||
if isinstance(layer, LinearBase) or (
|
if isinstance(layer, LinearBase) or (
|
||||||
|
|||||||
@@ -53,8 +53,6 @@ class TpModelWorker:
|
|||||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||||
):
|
):
|
||||||
self.worker = self
|
|
||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
|
|
||||||
@@ -134,6 +132,9 @@ class TpModelWorker:
|
|||||||
)[0]
|
)[0]
|
||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
|
|
||||||
|
# A self-reference so that this class has the same member as TpModelWorkerClient
|
||||||
|
self.worker = self
|
||||||
|
|
||||||
def get_worker_info(self):
|
def get_worker_info(self):
|
||||||
return (
|
return (
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
|
from sglang.srt.utils import add_prefix, is_cuda, is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
# Install the dependency in CI.
|
||||||
set -euxo pipefail
|
set -euxo pipefail
|
||||||
|
|
||||||
# Install the dependency in CI.
|
# Use repo from environment variables, passed from GitHub Actions
|
||||||
|
|
||||||
|
|
||||||
# Use repo from environment variable, passed from GitHub Actions
|
|
||||||
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
|
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
|
||||||
|
|
||||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||||
@@ -17,17 +15,12 @@ pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2
|
|||||||
rm -rf /root/.cache/flashinfer
|
rm -rf /root/.cache/flashinfer
|
||||||
# Force reinstall flashinfer and torch_memory_saver
|
# Force reinstall flashinfer and torch_memory_saver
|
||||||
pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
|
pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
|
||||||
|
pip install sgl-kernel==0.0.5.post3 --force-reinstall
|
||||||
|
|
||||||
pip install torch_memory_saver --force-reinstall
|
pip install torch_memory_saver
|
||||||
|
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm
|
||||||
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets
|
|
||||||
|
|
||||||
# For compiling xgrammar kernels
|
# For compiling xgrammar kernels
|
||||||
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
||||||
|
|
||||||
# For DeepSeek-VL2
|
|
||||||
pip install timm
|
|
||||||
|
|
||||||
pip install sgl-kernel==0.0.5.post3 --force-reinstall
|
|
||||||
|
|
||||||
pip uninstall vllm -y || true
|
pip uninstall vllm -y || true
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class TestEAGLEEngine(CustomTestCase):
|
|||||||
"mem_fraction_static": 0.7,
|
"mem_fraction_static": 0.7,
|
||||||
"cuda_graph_max_bs": 4,
|
"cuda_graph_max_bs": 4,
|
||||||
}
|
}
|
||||||
NUM_CONFIGS = 3
|
NUM_CONFIGS = 2
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.prompt = "Today is a sunny day and I like"
|
self.prompt = "Today is a sunny day and I like"
|
||||||
@@ -61,8 +61,6 @@ class TestEAGLEEngine(CustomTestCase):
|
|||||||
configs = [
|
configs = [
|
||||||
# Basic config
|
# Basic config
|
||||||
self.BASE_CONFIG,
|
self.BASE_CONFIG,
|
||||||
# Disable cuda graph
|
|
||||||
{**self.BASE_CONFIG, "disable_cuda_graph": True},
|
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
{**self.BASE_CONFIG, "chunked_prefill_size": 4},
|
{**self.BASE_CONFIG, "chunked_prefill_size": 4},
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TestTritonAttnBackend(CustomTestCase):
|
|||||||
"triton",
|
"triton",
|
||||||
"--enable-torch-compile",
|
"--enable-torch-compile",
|
||||||
"--cuda-graph-max-bs",
|
"--cuda-graph-max-bs",
|
||||||
16,
|
4,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user