Refactor: move all quantization-related code to srt/layer/quantization (#7989)

Author: Cheng Wan
Committed: 2025-07-17 00:47:07 -07:00 (committed by GitHub)
Parent: 02404a1e35
Commit: 49b8777460
22 changed files with 1095 additions and 1175 deletions


@@ -1,4 +1,5 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+from __future__ import annotations
 
 import logging
 from typing import Any, Callable, Dict, List, Optional
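
Note: the added "from __future__ import annotations" enables postponed evaluation of annotations (PEP 563), which is why later hunks in this commit can drop the quotes from forward references such as -> "ModelOptFp8Config". A minimal, self-contained sketch of the effect (the Config class is illustrative, not part of this diff):

from __future__ import annotations

from typing import Any, Dict


class Config:
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> Config:
        # Without the __future__ import, this annotation would be evaluated
        # while the class body is still executing and raise NameError;
        # it would have to be quoted as "Config" instead.
        return cls()
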
@@ -6,14 +7,11 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     is_sm100_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
@@ -86,7 +85,7 @@ class ModelOptFp8Config(QuantizationConfig):
         return ["hf_quant_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
+    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
         kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
             "kv_cache_quant_algo"
@@ -109,7 +108,11 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
         if self.exclude_modules and any(
             module in prefix
             or (
@@ -125,9 +128,6 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
         # Add MoE support
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
         if isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self)
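
Note: together with the hunk above, this moves the FusedMoE import out of the branch in the method body and groups it with a new deferred LinearBase import at the top of get_quant_method. Importing inside the function, at call time, is what breaks the import cycle between the layer modules and the quantization package. A minimal two-module sketch of the pattern, with hypothetical module names:

# layers.py (hypothetical)
from quantizer import get_quant_method  # module-scope import: layers -> quantizer


class LinearLayer:
    def quant_method(self):
        return get_quant_method(self)


# quantizer.py (hypothetical)
def get_quant_method(layer: object):
    # Deferred import: at module scope this would close the
    # layers -> quantizer -> layers cycle and fail; inside the function
    # it runs only at call time, once both modules are initialized.
    from layers import LinearLayer

    if isinstance(layer, LinearLayer):
        return "quantized linear method"
    return None
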
@@ -246,7 +246,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
-class ModelOptFp8MoEMethod:
+class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     """MoE method for ModelOpt FP8.
 
     Supports loading FP8 checkpoints with static weight scale and activation scale.
@@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod:
         quant_config: The ModelOpt quantization config.
     """
 
-    def __new__(cls, *args, **kwargs):
-        """
-        Dynamic class composition pattern.
-
-        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
-        at runtime while avoiding circular import issues.
-        """
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
     def __init__(self, quant_config: ModelOptFp8Config):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
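
Note: the deleted __new__ hook solved the same circular-import problem by brute force: since FusedMoEMethodBase could not be imported at module scope, the class was rebuilt on first instantiation with type(), splicing the base class in at runtime. Now that FusedMoEMethodBase is exported from sglang.srt.layers.quantization.base_config, ordinary inheritance suffices. A self-contained sketch contrasting the removed pattern with its replacement (class names hypothetical):

class Base:
    def shared(self) -> str:
        return "shared behavior"


class RuntimeComposed:
    # Stand-in for the deleted pattern: rebuild the class with Base
    # spliced in as a parent, then instantiate the rebuilt class.
    def __new__(cls, *args, **kwargs):
        namespace = {
            k: v
            for k, v in cls.__dict__.items()
            if k not in ("__dict__", "__weakref__")
        }
        new_cls = type(cls.__name__, (Base,), namespace)
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj


class DirectlyInherited(Base):
    # The replacement: plain inheritance, possible once Base lives in a
    # leaf module that can be imported without creating a cycle.
    pass


assert RuntimeComposed().shared() == "shared behavior"
assert isinstance(DirectlyInherited(), Base)

Beyond readability, direct inheritance fixes two quirks visible in the removed lines: the hasattr(cls, "_initialized") guard is never satisfied because nothing in the removed hunk sets _initialized, and instances built through the rebuilt class do not pass isinstance checks against the original class.
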
@@ -514,7 +490,7 @@ class ModelOptFp4Config(QuantizationConfig):
         return ["hf_quant_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
+    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         quant_config = cls.get_from_keys(config, ["quantization"])
         quant_method = quant_config["quant_algo"]
         if not quant_method in ["FP8", "NVFP4"]:
@@ -559,7 +535,8 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -740,31 +717,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         return out.view(*output_shape)
 
 
-class ModelOptNvFp4FusedMoEMethod:
+class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization with Blockscales and PerTensorScales
 
     Args:
         quant_config: NVFP4 Quant Config
     """
 
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
     def __init__(self, quant_config: ModelOptFp4Config):
         self.quant_config = quant_config
         if not is_sm100_supported():