fix: support fb fp8 (#9462)
This commit is contained in:
@@ -16,7 +16,6 @@ try:
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQMarlin24Config,
|
||||
@@ -37,9 +36,9 @@ except ImportError as e:
|
||||
|
||||
AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
|
||||
ExpertsInt8Config
|
||||
) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
|
||||
Int8TpuConfig
|
||||
) = DummyConfig
|
||||
) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
|
||||
DummyConfig
|
||||
)
|
||||
|
||||
|
||||
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
|
||||
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
ModelOptFp4Config,
|
||||
@@ -85,6 +85,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"qoq": QoQConfig,
|
||||
"w4afp8": W4AFp8Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
}
|
||||
|
||||
|
||||
@@ -109,7 +110,6 @@ VLLM_QUANTIZATION_METHODS = {
|
||||
"aqlm": AQLMConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"marlin": MarlinConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -16,6 +16,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
apply_fp8_linear,
|
||||
can_auto_enable_marlin_fp8,
|
||||
@@ -28,7 +29,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda, is_fp8_fnuz
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
@@ -88,6 +89,9 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
return FBGEMMFp8LinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user