2024-09-19 20:53:11 +08:00
|
|
|
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
2025-02-12 22:09:52 +08:00
|
|
|
from typing import Callable, Dict, Optional, Type
|
2024-09-19 20:53:11 +08:00
|
|
|
|
2025-02-12 22:09:52 +08:00
|
|
|
import torch
|
2024-09-19 20:53:11 +08:00
|
|
|
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
|
|
|
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
2025-02-12 22:09:52 +08:00
|
|
|
from vllm.model_executor.layers.quantization.awq_marlin import (
|
|
|
|
|
AWQMarlinConfig,
|
|
|
|
|
AWQMoEMethod,
|
|
|
|
|
)
|
2024-09-19 20:53:11 +08:00
|
|
|
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
2024-11-25 17:06:36 +08:00
|
|
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
2024-09-19 20:53:11 +08:00
|
|
|
CompressedTensorsConfig,
|
|
|
|
|
)
|
|
|
|
|
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
|
|
|
|
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
|
|
|
|
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
|
|
|
|
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
|
|
|
|
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
|
|
|
|
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
|
|
|
|
|
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
|
|
|
|
|
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
|
|
|
|
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
|
|
|
|
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
|
|
|
|
|
|
|
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
2024-12-07 19:28:53 +08:00
|
|
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
2025-01-08 18:04:30 +08:00
|
|
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
2025-01-14 17:07:49 +08:00
|
|
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
2024-09-19 20:53:11 +08:00
|
|
|
|
|
|
|
|
# Registry: quantization method name -> config class.
# Insertion order is significant only for display: get_quantization_config
# lists these keys, in this order, when it rejects an unknown name.
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,  # vLLM
    "awq": AWQConfig,  # vLLM
    "deepspeedfp": DeepSpeedFPConfig,  # vLLM
    "tpu_int8": Int8TpuConfig,  # vLLM
    "fp8": Fp8Config,  # sglang
    "fbgemm_fp8": FBGEMMFp8Config,  # vLLM
    "marlin": MarlinConfig,  # vLLM
    "modelopt": ModelOptFp8Config,  # sglang
    "gguf": GGUFConfig,  # vLLM
    "gptq_marlin_24": GPTQMarlin24Config,  # vLLM
    "gptq_marlin": GPTQMarlinConfig,  # vLLM (get_quant_method patched below)
    "awq_marlin": AWQMarlinConfig,  # vLLM (get_quant_method patched below)
    "gptq": GPTQConfig,  # vLLM
    "compressed-tensors": CompressedTensorsConfig,  # vLLM
    "bitsandbytes": BitsAndBytesConfig,  # vLLM
    "qqq": QQQConfig,  # vLLM
    "experts_int8": ExpertsInt8Config,  # vLLM
    "w8a8_int8": W8A8Int8Config,  # sglang
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    """Return the config class registered for a quantization method name.

    Args:
        quantization: A key of ``QUANTIZATION_METHODS`` (for example ``"awq"``).

    Returns:
        The ``QuantizationConfig`` subclass registered under that name.

    Raises:
        ValueError: If the name is not registered. The message lists every
            available method name.
    """
    if quantization in QUANTIZATION_METHODS:
        return QUANTIZATION_METHODS[quantization]
    raise ValueError(
        f"Invalid quantization method: {quantization}. "
        f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
    )
|
|
|
|
|
|
|
|
|
|
|
2024-12-03 07:12:33 -08:00
|
|
|
def gptq_get_quant_method(self, layer, prefix):
    """Select the vLLM GPTQ-Marlin quant method for an sglang layer.

    Installed as ``GPTQMarlinConfig.get_quant_method`` by
    ``apply_monkey_patches``. Dispatch uses sglang's layer classes rather
    than vLLM's.

    Args:
        self: The ``GPTQMarlinConfig`` instance (this function is bound as a method).
        layer: The module requesting a quantization method.
        prefix: The layer's name. It is unused here and kept for signature
            compatibility with ``get_quant_method``.

    Returns:
        ``GPTQMarlinLinearMethod`` for sglang ``LinearBase`` layers, and for
        ``ParallelLMHead`` when the config reports ``lm_head_quantized``.
        ``GPTQMarlinMoEMethod`` for sglang ``FusedMoE`` layers.
        ``None`` for anything else.
    """
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )

    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

    # Treat a quantized LM head as a linear layer, matching the AWQ-Marlin
    # dispatch. Without this, a checkpoint with a quantized lm_head got None
    # here. getattr keeps configs lacking the attribute on the old behaviour
    # (LM head not quantized).
    lm_head_quantized = getattr(self, "lm_head_quantized", False)
    if isinstance(layer, LinearBase) or (
        isinstance(layer, ParallelLMHead) and lm_head_quantized
    ):
        return GPTQMarlinLinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def awq_get_quant_method(self, layer, prefix):
    """Select the quant method for an sglang layer under an AWQ-Marlin config.

    Installed as ``AWQMarlinConfig.get_quant_method`` by
    ``apply_monkey_patches``.

    Args:
        self: The ``AWQMarlinConfig`` instance (this function is bound as a method).
        layer: The module requesting a quantization method.
        prefix: The layer's name. It is passed, together with
            ``self.modules_to_not_convert``, to vLLM's ``is_layer_skipped_awq``.

    Returns:
        For linear-like layers (sglang ``LinearBase``, or ``ParallelLMHead``
        when ``self.lm_head_quantized`` is truthy):
            ``UnquantizedLinearMethod()`` if ``is_layer_skipped_awq`` says
            to skip the layer, otherwise ``AWQMarlinLinearMethod``.
        ``AWQMoEMethod`` for sglang ``FusedMoE`` layers.
        ``None`` for anything else.
    """
    from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )

    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

    linear_like = isinstance(layer, LinearBase) or (
        isinstance(layer, ParallelLMHead) and self.lm_head_quantized
    )
    if not linear_like:
        return AWQMoEMethod(self) if isinstance(layer, FusedMoE) else None

    if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
        return UnquantizedLinearMethod()
    return AWQMarlinLinearMethod(self)
|
|
|
|
|
|
|
|
|
|
|
2025-02-12 22:09:52 +08:00
|
|
|
# Save vLLM's own AWQMoEMethod.apply now, at import time, before
# apply_monkey_patches() replaces it with awq_moe_method_apply. The
# replacement delegates to this saved function.
original_awq_moe_method_apply = getattr(AWQMoEMethod, "apply")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def awq_moe_method_apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Stand-in for ``AWQMoEMethod.apply`` that accepts extra keyword arguments.

    Every named parameter is forwarded positionally, in declaration order, to
    ``original_awq_moe_method_apply``. That is the ``AWQMoEMethod.apply``
    captured from vLLM before patching.

    Any other keyword arguments land in ``kwargs`` and are discarded, not
    forwarded. NOTE(review): confirm that callers do not rely on any of those
    extra keywords being honoured.

    Returns:
        Whatever the original ``apply`` returns.
    """
    # The order here must match the positional order of the wrapped apply().
    forwarded = (
        layer,
        x,
        router_logits,
        top_k,
        renormalize,
        use_grouped_topk,
        topk_group,
        num_expert_group,
        custom_routing_function,
        scoring_func,
        e_score_correction_bias,
    )
    return original_awq_moe_method_apply(self, *forwarded)
|
|
|
|
|
|
|
|
|
|
|
2025-01-16 18:00:03 +08:00
|
|
|
def patch_vllm_linear_base_isinstance():
    """Redirect ``isinstance`` checks against vLLM's ``LinearBase`` to sglang's.

    Replaces ``builtins.isinstance`` process-wide with a wrapper. When
    ``classinfo`` is exactly vLLM's ``LinearBase`` class object, the check is
    answered against ``sglang.srt.layers.linear.LinearBase`` instead. Every
    other call is forwarded unchanged to the original ``isinstance``. That
    includes tuples that merely contain vLLM's ``LinearBase``.

    Presumably this lets vLLM code that tests ``isinstance(layer, LinearBase)``
    accept sglang layers. TODO confirm against the vLLM call sites.

    The patch is installed at most once. Calling this again (e.g. after a
    module reload) is a no-op instead of stacking another wrapper, which would
    add a Python call frame to every ``isinstance`` call in the process.
    """
    import builtins

    from vllm.model_executor.layers.linear import LinearBase

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase

    # Already patched by a previous call: don't wrap our own wrapper.
    if getattr(builtins.isinstance, "_sglang_linear_base_patch", False):
        return

    original_isinstance = builtins.isinstance

    def patched_isinstance(obj, classinfo):
        # Identity comparison: only the exact vLLM class object is redirected.
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        return original_isinstance(obj, classinfo)

    # Marker checked above to keep the patch idempotent.
    patched_isinstance._sglang_linear_base_patch = True
    builtins.isinstance = patched_isinstance
|
|
|
|
|
|
|
|
|
|
|
2024-11-25 17:06:36 +08:00
|
|
|
def apply_monkey_patches():
    """Install sglang's overrides on vLLM quantization classes.

    Replaces the following with the module-level functions defined in this
    file:
        - ``GPTQMarlinConfig.get_quant_method``
        - ``AWQMarlinConfig.get_quant_method``
        - ``AWQMoEMethod.apply``

    It then calls ``patch_vllm_linear_base_isinstance``.
    """
    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod

    # (target class, attribute to replace, replacement), applied in order.
    overrides = (
        (GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method),
        (AWQMarlinConfig, "get_quant_method", awq_get_quant_method),
        (AWQMoEMethod, "apply", awq_moe_method_apply),
    )
    for target_cls, attr_name, replacement in overrides:
        setattr(target_cls, attr_name, replacement)

    patch_vllm_linear_base_isinstance()


# Import-time side effect: importing this module is what installs the patches.
apply_monkey_patches()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Names exported by ``from ... import *``.
__all__ = ["QuantizationConfig", "get_quantization_config", "QUANTIZATION_METHODS"]
|