fix: support fb fp8 (#9462)
This commit is contained in:
@@ -16,7 +16,6 @@ try:
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
|
||||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||||
GPTQMarlin24Config,
|
GPTQMarlin24Config,
|
||||||
@@ -37,9 +36,9 @@ except ImportError as e:
|
|||||||
|
|
||||||
AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
|
AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
|
||||||
ExpertsInt8Config
|
ExpertsInt8Config
|
||||||
) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
|
) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
|
||||||
Int8TpuConfig
|
DummyConfig
|
||||||
) = DummyConfig
|
)
|
||||||
|
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
|
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
|
||||||
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
|
|||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||||
|
from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
|
||||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||||
from sglang.srt.layers.quantization.modelopt_quant import (
|
from sglang.srt.layers.quantization.modelopt_quant import (
|
||||||
ModelOptFp4Config,
|
ModelOptFp4Config,
|
||||||
@@ -85,6 +85,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"qoq": QoQConfig,
|
"qoq": QoQConfig,
|
||||||
"w4afp8": W4AFp8Config,
|
"w4afp8": W4AFp8Config,
|
||||||
"petit_nvfp4": PetitNvFp4Config,
|
"petit_nvfp4": PetitNvFp4Config,
|
||||||
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -109,7 +110,6 @@ VLLM_QUANTIZATION_METHODS = {
|
|||||||
"aqlm": AQLMConfig,
|
"aqlm": AQLMConfig,
|
||||||
"deepspeedfp": DeepSpeedFPConfig,
|
"deepspeedfp": DeepSpeedFPConfig,
|
||||||
"tpu_int8": Int8TpuConfig,
|
"tpu_int8": Int8TpuConfig,
|
||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
|
||||||
"marlin": MarlinConfig,
|
"marlin": MarlinConfig,
|
||||||
"gguf": GGUFConfig,
|
"gguf": GGUFConfig,
|
||||||
"gptq_marlin_24": GPTQMarlin24Config,
|
"gptq_marlin_24": GPTQMarlin24Config,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import torch
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
@@ -16,6 +16,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
can_auto_enable_marlin_fp8,
|
can_auto_enable_marlin_fp8,
|
||||||
@@ -28,7 +29,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
|
from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
|
||||||
from sglang.srt.utils import get_bool_env_var, is_cuda, is_fp8_fnuz
|
from sglang.srt.utils import get_bool_env_var, is_cuda
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
@@ -88,6 +89,9 @@ class FBGEMMFp8Config(QuantizationConfig):
|
|||||||
return FBGEMMFp8LinearMethod(self)
|
return FBGEMMFp8LinearMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user