Refactor: move all quantization-related code to srt/layer/quantization (#7989)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
@@ -6,14 +7,11 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
is_sm100_supported,
|
||||
)
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
convert_to_channelwise,
|
||||
is_layer_skipped,
|
||||
@@ -86,7 +85,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
|
||||
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
||||
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
||||
"kv_cache_quant_algo"
|
||||
@@ -109,7 +108,11 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if self.exclude_modules and any(
|
||||
module in prefix
|
||||
or (
|
||||
@@ -125,9 +128,6 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
|
||||
# Add MoE support
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return ModelOptFp8MoEMethod(self)
|
||||
|
||||
@@ -246,7 +246,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
super().__init__(quant_config)
|
||||
|
||||
|
||||
class ModelOptFp8MoEMethod:
|
||||
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for ModelOpt FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and activation scale.
|
||||
|
||||
@@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod:
|
||||
quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
Dynamic class composition pattern.
|
||||
|
||||
This allows us to effectively "inject" FusedMoEMethodBase as a parent class
|
||||
at runtime while avoiding circular import issues.
|
||||
"""
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
@@ -514,7 +490,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
if not quant_method in ["FP8", "NVFP4"]:
|
||||
@@ -559,7 +535,8 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
@@ -740,31 +717,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoEMethod:
|
||||
class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
"""
|
||||
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
|
||||
Args:
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp4Config):
|
||||
self.quant_config = quant_config
|
||||
if not is_sm100_supported():
|
||||
|
||||
Reference in New Issue
Block a user