feat: fused_moe fp8 monkey patch (#2174)
@@ -1,18 +1,19 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
 
-from typing import Dict, Type
+from typing import Callable, Dict, Optional, Type
 
+import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -30,8 +31,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
-    # The order of gptq methods is important for config.py iteration over
-    # override_quantization_method(..)
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
@@ -47,33 +46,70 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
-        raise ValueError(f"Invalid quantization method: {quantization}")
+        raise ValueError(
+            f"Invalid quantization method: {quantization}. "
+            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
+        )
     return QUANTIZATION_METHODS[quantization]
 
 
-__all__ = [
-    "QuantizationConfig",
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
+def fp8_moe_apply(
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    use_grouped_topk: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+) -> torch.Tensor:
+    """Enhanced apply method for FP8 MoE."""
+    from sglang.srt.layers.fused_moe_triton import FusedMoE
+    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+
+    # Expert selection
+    topk_weights, topk_ids = FusedMoE.select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )
+
+    # Expert fusion with FP8 quantization
+    return fused_experts(
+        x,
+        layer.w13_weight,
+        layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=True,
+        use_fp8_w8a8=True,
+        w1_scale=layer.w13_weight_scale,
+        w2_scale=layer.w2_weight_scale,
+        a1_scale=layer.w13_input_scale,
+        a2_scale=layer.w2_input_scale,
+    )
 
 
 def fp8_get_quant_method(self, layer, prefix):
+    """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.fp8 import (
-        Fp8LinearMethod,
-        Fp8MoEMethod,
-    )
+    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.linear import UnquantizedLinearMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
-            from sglang.srt.layers.linear import UnquantizedLinearMethod
-
             return UnquantizedLinearMethod()
         return Fp8LinearMethod(self)
     elif isinstance(layer, FusedMoE):
@@ -81,4 +117,18 @@ def fp8_get_quant_method(self, layer, prefix):
     return None
 
 
-setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+def apply_monkey_patches():
+    """Apply all monkey patches in one place."""
+    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
+    setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+
+
+# Apply patches when module is imported
+apply_monkey_patches()
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quantization_config",
+    "QUANTIZATION_METHODS",
+]
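
A minimal sketch of how the patch takes effect, not part of the commit: it assumes the patched file is importable as sglang.srt.layers.quantization (the diff does not show the file path) and that the vllm v0.5.5-era modules it imports are installed. Importing the module runs apply_monkey_patches(), which rebinds vllm's FP8 entry points to the functions defined above.

# Hypothetical usage sketch; the module path is an assumption.
import sglang.srt.layers.quantization  # noqa: F401  # importing triggers apply_monkey_patches()

from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

# After the import, the vllm FP8 classes resolve to the patched functions,
# so FP8 MoE layers route through sglang's Triton fused_experts kernel.
assert Fp8MoEMethod.apply.__name__ == "fp8_moe_apply"
assert Fp8Config.get_quant_method.__name__ == "fp8_get_quant_method"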