Refactor: move all quantization-related code to srt/layer/quantization (#7989)
@@ -1,7 +1,11 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from __future__ import annotations

import re
from copy import deepcopy
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union

import numpy
import torch
@@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

if TYPE_CHECKING:
    from sglang.srt.layers.quantization.base_config import QuantizationConfig

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
@@ -147,6 +154,94 @@ def replace_parameter(
    mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))


# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )
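

# Illustrative sketch, not part of this commit: how a dynamic "+:" rule is
# expected to rewrite a GPTQ-style config via override_config(). The
# _DemoGPTQConfig class, its attribute values, and the layer prefix are
# hypothetical stand-ins; only the attributes touched above are modeled.
def _demo_override_config():
    class _DemoGPTQConfig:
        weight_bits = 4
        group_size = 128
        desc_act = False
        # Positive rule: layers under model.layers.0 get 8-bit, group size 64
        dynamic = {r"+:model\.layers\.0\..*": {"bits": 8, "group_size": 64}}

        def get_name(self):
            return "gptq"

    cfg = _DemoGPTQConfig()
    override_config(cfg, prefix="model.layers.0.self_attn.qkv_proj")
    # The rule overrides bits and group_size; pack_factor follows 32 // bits.
    assert (cfg.weight_bits, cfg.group_size, cfg.pack_factor) == (8, 64, 4)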


def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules have quant property overrides
        # on the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value
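

# Illustrative sketch, not part of this commit: how the "-:" (exclude) and
# "+:" (override) dynamic patterns above are resolved. The config object and
# layer names are hypothetical stand-ins for a QuantizationConfig.dynamic map.
def _demo_dynamic_override():
    class _DemoConfig:
        dynamic = {
            r"-:.*lm_head": {},  # negative match: skip quantized init
            r"+:model\.layers\.0\..*": {"bits": 8},  # positive match: override
        }

    cfg = _DemoConfig()
    assert get_dynamic_override(cfg, "model.lm_head") is False
    assert get_dynamic_override(cfg, "model.layers.0.mlp", "bits", 4) == 8
    # No pattern matches, so the caller-supplied default is returned.
    assert get_dynamic_override(cfg, "model.layers.5.mlp", "bits", 4) == 4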


def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.quantization.unquant import (
        UnquantizedEmbeddingMethod,
        UnquantizedLinearMethod,
    )
    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = (
        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
    )

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = Positive match
        if get_dynamic_override(cloned_config, layer_name=prefix) is False:
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
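

# Editor's usage note, not part of this commit: get_pack_factor() reports how
# many quantized values fit in one 32-bit word, the same 32 // weight_bits
# relation used by override_config() above to set config.pack_factor.
def _demo_pack_factor():
    assert get_pack_factor(4) == 8  # eight 4-bit values per int32
    assert get_pack_factor(8) == 4  # four 8-bit values per int32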