@@ -450,7 +450,6 @@ class ModelConfig:
            "petit_nvfp4",
            "quark",
            "mxfp4",
            "auto-round",
        ]
        optimized_quantization_methods = [
            "fp8",
@@ -192,14 +192,6 @@ class FusedMoE(torch.nn.Module):

        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()

        moe_quant_params = {}
        if self.quant_method.__class__.__name__ in (
            "GPTQMarlinMoEMethod",
            "CompressedTensorsWNA16MarlinMoEMethod",
            "CompressedTensorsWNA16MoEMethod",
        ):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_config = quant_config
        self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
@@ -251,7 +243,6 @@ class FusedMoE(torch.nn.Module):
                else self.weight_loader_fused
            ),
            with_bias=with_bias,
            **moe_quant_params,
        )

        self.quant_method.create_moe_runner(self, self.moe_runner_config)
@@ -41,7 +41,6 @@ except ImportError as e:
    )


from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -87,7 +86,6 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "w4afp8": W4AFp8Config,
    "petit_nvfp4": PetitNvFp4Config,
    "fbgemm_fp8": FBGEMMFp8Config,
    "auto-round": AutoRoundConfig,
}
@@ -1,360 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from collections.abc import Iterable, Mapping
from fractions import Fraction
from typing import Any, Optional, Union

import torch

logger = logging.getLogger(__name__)

from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter

ScalarType, scalar_types = get_scalar_types()


from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead


class AutoRoundConfig(QuantizationConfig):
    """Config class for AutoRound.

    Reference: https://arxiv.org/pdf/2309.05516
    """

    SUPPORTED_BITS = {2, 3, 4, 8}
    SUPPORTED_DTYPES = {"int"}
    SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
    SUPPORTED_BACKENDS = {"auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin"}

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        sym: bool = True,
        packing_format: str = "auto_round:auto_gptq",
        block_name_to_quantize: Optional[Union[str, list[str]]] = None,
        extra_config: Optional[dict[str, Any]] = None,
        data_type: str = "int",
        backend: str = "auto",
    ) -> None:
        super().__init__()
        if weight_bits not in self.SUPPORTED_BITS:
            raise ValueError(
                f"Unsupported weight_bits: {weight_bits}, "
                f"currently only support {self.SUPPORTED_BITS}"
            )
        if data_type not in self.SUPPORTED_DTYPES:
            raise ValueError(
                f"Unsupported data_type: {data_type},"
                f" currently only support {self.SUPPORTED_DTYPES}"
            )
        if packing_format not in self.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported packing_format: {packing_format}, "
                f"currently only support {self.SUPPORTED_FORMATS}"
            )
        if backend not in self.SUPPORTED_BACKENDS:
            raise ValueError(
                f"Unsupported backend: {backend}, "
                f"currently only support {self.SUPPORTED_BACKENDS}"
            )

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.sym = sym
        self.packing_format = packing_format
        self.block_name_to_quantize = (
            block_name_to_quantize.split(",")
            if isinstance(block_name_to_quantize, str)
            else block_name_to_quantize
        )
        self.extra_config = extra_config
        self.data_type = data_type
        self.backend = backend
        self.pack_factor = Fraction(32, weight_bits)

    def __repr__(self) -> str:
        return (
            f"AutoRoundConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, sym={self.sym})"
        )

    @classmethod
    def get_name(cls): ## use str will trigger preci issue
        return "auto-round"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
        return cls(
            weight_bits=cls.get_from_keys(config, ["bits"]),
            group_size=cls.get_from_keys(config, ["group_size"]),
            sym=cls.get_from_keys(config, ["sym"]),
            packing_format=cls.get_from_keys_or(
                config,
                ["packing_format"],
                "auto_round:auto_gptq",
            ),
            block_name_to_quantize=cls.get_from_keys_or(
                config, ["block_name_to_quantize", "to_quant_block_names"], None
            ),
            extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
            data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
            backend=cls.get_from_keys_or(
                config, ["backend", "vllm_backend", "sglang_backend"], "auto"
            ),
        )

    def get_scaled_act_names(self) -> list[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError

    def get_layer_config(self, layer, layer_name: str):

        def get_config(name: str, quantized: bool = True):
            cfg = self.extra_config.get(name, {}) if self.extra_config else {}
            return (
                cfg.get("bits", self.weight_bits if quantized else 16),
                cfg.get("group_size", self.group_size if quantized else -1),
                cfg.get("sym", self.sym if quantized else True),
            )

        # 1. Exact match from config
        if self.extra_config and layer_name in self.extra_config:
            return get_config(layer_name)

        # 2. Determine whether layer should be quantized
        quantized = not isinstance(layer, ParallelLMHead)
        if self.block_name_to_quantize:
            quantized = any(
                layer_name.startswith(name) for name in self.block_name_to_quantize
            )

        # 3. Handle fused MoE
        if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
            moe_configs = [
                get_config(name, quantized)
                for name in self.extra_config
                if name.startswith(layer_name)
            ]
            if moe_configs:
                if len(set(moe_configs)) == 1:
                    return moe_configs[0]
                raise ValueError(
                    f"Fused MoE layer '{layer_name}' requires "
                    f"consistent quant config for all sub-layers"
                )

        # 4. Handle fused QKV or other patterns
        if self.extra_config:
            for fusion_key, sub_keys in self.packed_modules_mapping.items():
                if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
                    sub_names = [
                        layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
                    ]
                    sub_configs = [get_config(name, quantized) for name in sub_names]
                    if len(set(sub_configs)) == 1:
                        return sub_configs[0]
                    raise ValueError(
                        f"Fused module '{layer_name}' requires "
                        f"consistent quant config for {sub_names}"
                    )

        # 5. Fallback
        return get_config(layer_name, quantized)

    def check_quantized(self, weight_bits: int) -> bool:
        return weight_bits < 16

    def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization.marlin_utils import (
            check_marlin_supported,
            check_marlin_supports_layer,
            check_moe_marlin_supports_layer,
        )

        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            if isinstance(layer, (LinearBase, ParallelLMHead)):
                return UnquantizedLinearMethod()
            else:
                return None
        logger.debug(
            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
            prefix,
            layer.__class__.__name__,
            weight_bits,
            group_size,
            sym,
        )
        if backend == "auto" or "marlin" in backend:
            AWQ_TYPE_MAP = {
                4: scalar_types.uint4,
                8: scalar_types.uint8,
            }
            use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
                AWQ_TYPE_MAP[weight_bits], group_size, not sym
            )
            if isinstance(layer, FusedMoE):
                use_marlin = use_marlin and check_moe_marlin_supports_layer(
                    layer, group_size
                )

        else:
            use_marlin = False
        if use_marlin:
            from sglang.srt.layers.quantization.awq import (
                AWQMarlinConfig,
                AWQMarlinLinearMethod,
                AWQMoEMethod,
            )

            quant_args_marlin = AWQMarlinConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                zero_point=not sym,
                lm_head_quantized=False,
                full_config={},
                modules_to_not_convert=[],
            )
        else:
            from sglang.srt.layers.quantization.awq import AWQConfig, AWQLinearMethod

            quant_args = AWQConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                zero_point=not sym,
            )

        if isinstance(layer, FusedMoE):
            if use_marlin:
                return AWQMoEMethod(quant_args_marlin)
            from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

            config = {
                "quant_method": "awq",
                "bits": weight_bits,
                "group_size": group_size,
                "zero_point": not sym,
                "lm_head": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)

        if isinstance(layer, (LinearBase, ParallelLMHead)):
            if use_marlin:
                return AWQMarlinLinearMethod(quant_args_marlin)
            else:
                return AWQLinearMethod(quant_args)
        return None

    def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization.marlin_utils import (
            check_marlin_supported,
            check_moe_marlin_supports_layer,
        )

        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            if isinstance(layer, (LinearBase, ParallelLMHead)):
                return UnquantizedLinearMethod()
            else:
                return None

        logger.debug(
            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
            prefix,
            layer.__class__.__name__,
            weight_bits,
            group_size,
            sym,
        )
        if backend == "auto" or "marlin" in backend:
            GPTQ_TYPE_MAP = {
                (4, True): scalar_types.uint4b8,
                (8, True): scalar_types.uint8b128,
            }
            use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
                GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
            )
            if isinstance(layer, FusedMoE):
                use_marlin = use_marlin and check_moe_marlin_supports_layer(
                    layer, group_size
                )
        else:
            use_marlin = False
        if use_marlin:
            from sglang.srt.layers.quantization.gptq import (
                GPTQMarlinConfig,
                GPTQMarlinLinearMethod,
                GPTQMarlinMoEMethod,
            )

            quant_args_marlin = GPTQMarlinConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                is_sym=sym,
                lm_head_quantized=False,
                desc_act=False,
                dynamic={},
                full_config={},
            )
        else:
            from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQLinearMethod

            quant_args = GPTQConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                lm_head_quantized=False,
                desc_act=False,
                dynamic={},
            )

        if isinstance(layer, FusedMoE):
            if use_marlin:
                from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

                config = {
                    "quant_method": "gptq",
                    "bits": weight_bits,
                    "group_size": group_size,
                    "sym": sym,
                    "lm_head": False,
                }
                return MoeWNA16Config.from_config(config).get_quant_method(
                    layer, prefix
                )
            return GPTQMarlinMoEMethod(quant_args_marlin)

        if isinstance(layer, (LinearBase, ParallelLMHead)):
            if use_marlin:
                return GPTQMarlinLinearMethod(quant_args_marlin)
            else:
                return GPTQLinearMethod(quant_args)

        return None

    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
        # TODO enable CPU quant method later
        if "gptq" in self.packing_format or "gptq" in self.backend:
            return self.apply_gptq_quant_layer(layer, prefix)
        if "awq" in self.packing_format or "awq" in self.backend:
            return self.apply_awq_quant_layer(layer, prefix)
@@ -80,7 +80,6 @@ QUANTIZATION_CHOICES = [
    "qoq",
    "w4afp8",
    "mxfp4",
    "auto-round",
]

ATTENTION_BACKEND_CHOICES = [
@@ -87,11 +87,6 @@ DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-V3-0324"
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
)
DEFAULT_AUTOROUND_MODEL_NAME_FOR_TEST = (
    "OPEA/Llama-3.2-11B-Vision-Instruct-qvision-int4-sym-inc", ## mllm auto_round:auto_gptq
    "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ## auto_round:auto_gptq
    "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound", ## auto_round:auto_awq
)
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST = "Qwen/Qwen3-30B-A3B"
DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST = "Barrrrry/DeepSeek-R1-W4AFP8"