Refactor: move all quantization-related code to srt/layer/quantization (#7989)

This commit is contained in:
Cheng Wan
2025-07-17 00:47:07 -07:00
committed by GitHub
parent 02404a1e35
commit 49b8777460
22 changed files with 1095 additions and 1175 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
@@ -5,12 +7,13 @@ import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig):
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig):
return []
class W4AFp8MoEMethod:
class W4AFp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: W4AFp8Config):
self.quant_config = quant_config