Refactor: move all quantization-related code to srt/layer/quantization (#7989)

Cheng Wan
2025-07-17 00:47:07 -07:00
committed by GitHub
parent 02404a1e35
commit 49b8777460
22 changed files with 1095 additions and 1175 deletions


@@ -1,7 +1,11 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from __future__ import annotations
import re
from copy import deepcopy
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
import numpy
import torch
@@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
if TYPE_CHECKING:
    from sglang.srt.layers.quantization.base_config import QuantizationConfig
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
@@ -147,6 +154,94 @@ def replace_parameter(
    mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))

# Match dynamic rules against the module name (prefix) and override the quant
# config if the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )
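
# Illustrative note (not part of the diff): given a hypothetical GPTQ config with
# weight_bits=4 and a dynamic rule {"+:.*\.down_proj": {"bits": 8}}, calling
# override_config(cfg, "model.layers.0.mlp.down_proj") would set weight_bits=8
# and pack_factor = 32 // 8 == 4 on that config. The rule pattern and prefix
# here are made-up examples.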

def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules get quant property overrides applied
        # on top of the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value
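
# Illustrative sketch (not part of the diff): how dynamic rules are resolved.
# The rule patterns and layer names below are hypothetical; only the "+:"/"-:"
# prefix convention and the per-key lookup come from get_dynamic_override.
#
#   config.dynamic = {
#       "+:.*\.down_proj": {"bits": 8, "group_size": 64},  # override these layers
#       "-:.*\.lm_head": {},                                # skip quantized init
#   }
#   get_dynamic_override(config, "model.layers.0.mlp.down_proj", "bits", 4)     # -> 8
#   get_dynamic_override(config, "model.lm_head")                               # -> False
#   get_dynamic_override(config, "model.layers.0.self_attn.q_proj", "bits", 4)  # -> 4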

def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.quantization.unquant import (
        UnquantizedEmbeddingMethod,
        UnquantizedLinearMethod,
    )
    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = (
        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
    )

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = Positive match
        if get_dynamic_override(cloned_config, layer_name=prefix) is False:
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None
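
# Illustrative note (not part of the diff): per module prefix, get_linear_quant_method
# resolves to one of three outcomes (the prefixes in the sketch above apply here too):
#   - a "-:" rule matches the prefix -> UnquantizedLinearMethod (or
#     UnquantizedEmbeddingMethod for a quantized ParallelLMHead), i.e. the module
#     stays unquantized
#   - otherwise override_config applies any "+:" rule overrides (when prefix is
#     non-empty) and linear_method_cls(cloned_config) is returned
#   - the layer is neither LinearBase nor a quantized ParallelLMHead -> None,
#     so no quant method is selected here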

def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
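
# Worked example (not part of the diff): with 32-bit packing,
#   get_pack_factor(4) == 8   # eight 4-bit values per int32
#   get_pack_factor(8) == 4   # four 8-bit values per int32
#   get_pack_factor(3)        # raises AssertionError, since 32 % 3 != 0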