Files
sglang/python/sglang/srt/layers/quantization/__init__.py
2025-01-13 04:44:39 -08:00

125 lines
4.6 KiB
Python

# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
from typing import Dict, Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
# Registry mapping each supported quantization method name to its config class.
# "fp8" and "modelopt" resolve to sglang's own classes; every other entry is
# imported from vLLM. Insertion order is preserved and is the order in which
# names appear in the "Available methods" list of the ValueError raised by
# get_quantization_config.
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    # get_quant_method on this class is replaced by apply_monkey_patches().
    "fp8": Fp8Config,
    "fbgemm_fp8": FBGEMMFp8Config,
    "marlin": MarlinConfig,
    "modelopt": ModelOptFp8Config,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    # get_quant_method on this class is replaced by apply_monkey_patches().
    "gptq_marlin": GPTQMarlinConfig,
    # get_quant_method on this class is replaced by apply_monkey_patches().
    "awq_marlin": AWQMarlinConfig,
    "gptq": GPTQConfig,
    # NOTE: this key uses a hyphen, unlike the underscore-separated names.
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
}
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    """Look up the config class registered under a quantization method name.

    Args:
        quantization: Key into ``QUANTIZATION_METHODS`` (for example ``"fp8"``).

    Returns:
        The config class (not an instance) registered for ``quantization``.

    Raises:
        ValueError: If ``quantization`` is not a registered name. The message
            lists every registered name in registry order.
    """
    # Happy path first: a known name maps straight to its class.
    if quantization in QUANTIZATION_METHODS:
        return QUANTIZATION_METHODS[quantization]

    raise ValueError(
        f"Invalid quantization method: {quantization}. "
        f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
    )
def fp8_get_quant_method(self, layer, prefix):
    """Choose the quantization method object for ``layer`` under an FP8 config.

    Installed as ``Fp8Config.get_quant_method`` by ``apply_monkey_patches``.

    Args:
        self: The config instance the method is bound to; its
            ``ignored_layers`` attribute is consulted for linear layers.
        layer: The layer to pick a method for.
        prefix: Passed to ``is_layer_skipped`` together with
            ``self.ignored_layers`` (presumably the layer's qualified name;
            confirm against the vLLM helper).

    Returns:
        ``UnquantizedLinearMethod()`` for a ``LinearBase`` layer that
        ``is_layer_skipped`` reports as skipped, ``Fp8LinearMethod(self)`` for
        any other ``LinearBase`` layer, ``Fp8MoEMethod(self)`` for a
        ``FusedMoE`` layer, and ``None`` for every other layer type.
    """
    # Imports are resolved at call time rather than at module import.
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        is_layer_skipped,
    )
    from sglang.srt.layers.linear import UnquantizedLinearMethod
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod

    # The linear check runs first, exactly as in the if/elif ordering.
    if isinstance(layer, LinearBase):
        skipped = is_layer_skipped(prefix, self.ignored_layers)
        return UnquantizedLinearMethod() if skipped else Fp8LinearMethod(self)

    if isinstance(layer, FusedMoE):
        return Fp8MoEMethod(self)

    return None
def gptq_get_quant_method(self, layer, prefix):
    """Choose the GPTQ-Marlin quantization method object for ``layer``.

    Installed as ``GPTQMarlinConfig.get_quant_method`` by
    ``apply_monkey_patches``.

    Args:
        self: The config instance the method is bound to; handed to the
            returned method object.
        layer: The layer to pick a method for.
        prefix: Accepted for signature compatibility; not used here.

    Returns:
        ``GPTQMarlinLinearMethod(self)`` for a ``LinearBase`` layer,
        ``GPTQMarlinMoEMethod(self)`` for a ``FusedMoE`` layer, and ``None``
        for every other layer type.
    """
    # Imports are resolved at call time rather than at module import.
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    # The linear check runs first, exactly as in the if/elif ordering.
    if isinstance(layer, LinearBase):
        return GPTQMarlinLinearMethod(self)

    if isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)

    return None
def awq_get_quant_method(self, layer, prefix):
    """Choose the AWQ-Marlin quantization method object for ``layer``.

    Installed as ``AWQMarlinConfig.get_quant_method`` by
    ``apply_monkey_patches``.

    Args:
        self: The config instance the method is bound to; handed to the
            returned method object.
        layer: The layer to pick a method for.
        prefix: Accepted for signature compatibility; not used here.

    Returns:
        ``AWQMarlinLinearMethod(self)`` for a ``LinearBase`` layer,
        ``AWQMoEMethod(self)`` for a ``FusedMoE`` layer, and ``None`` for
        every other layer type.
    """
    # Imports are resolved at call time rather than at module import.
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    # The linear check runs first, exactly as in the if/elif ordering.
    if isinstance(layer, LinearBase):
        return AWQMarlinLinearMethod(self)

    if isinstance(layer, FusedMoE):
        return AWQMoEMethod(self)

    return None
def apply_monkey_patches():
    """Install this module's ``get_quant_method`` overrides in one place.

    Rebinds the ``get_quant_method`` attribute on ``Fp8Config``,
    ``GPTQMarlinConfig`` and ``AWQMarlinConfig`` to the module-level
    ``fp8_get_quant_method``, ``gptq_get_quant_method`` and
    ``awq_get_quant_method`` functions respectively. The classes are mutated
    in place, so the change is visible to every holder of those classes.
    """
    # Plain attribute assignment on the class is equivalent to setattr with a
    # constant name; order matches the original (fp8, gptq, awq).
    Fp8Config.get_quant_method = fp8_get_quant_method
    GPTQMarlinConfig.get_quant_method = gptq_get_quant_method
    AWQMarlinConfig.get_quant_method = awq_get_quant_method
# Apply patches when module is imported.
# This is an import-time side effect: importing this package rebinds
# get_quant_method on Fp8Config, GPTQMarlinConfig and AWQMarlinConfig.
apply_monkey_patches()

# Names exported by a star-import of this package. The *_get_quant_method
# helpers and apply_monkey_patches are not listed here.
__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]