diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index d001bb646..ff3c2b148 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -16,7 +16,6 @@ try:
     )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -37,9 +36,9 @@ except ImportError as e:
 
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
-        Int8TpuConfig
-    ) = DummyConfig
+    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
+        DummyConfig
+    )
 
 
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
@@ -85,6 +85,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
 }
 
 
@@ -109,7 +110,6 @@ VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
-    "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
diff --git a/python/sglang/srt/layers/quantization/fpgemm_fp8.py b/python/sglang/srt/layers/quantization/fpgemm_fp8.py
index fcfba7b09..5a78626ff 100644
--- a/python/sglang/srt/layers/quantization/fpgemm_fp8.py
+++ b/python/sglang/srt/layers/quantization/fpgemm_fp8.py
@@ -8,7 +8,7 @@ import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -16,6 +16,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     can_auto_enable_marlin_fp8,
@@ -28,7 +29,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_fp8_fnuz
+from sglang.srt.utils import get_bool_env_var, is_cuda
 
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
@@ -88,6 +89,9 @@ class FBGEMMFp8Config(QuantizationConfig):
             return FBGEMMFp8LinearMethod(self)
         return None
 
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
 
 class FBGEMMFp8LinearMethod(LinearMethodBase):