# File: enginex-mlu590-vllm/vllm_mlu/model_executor/layers/quantization/fp8.py
# Snapshot: 2026-04-24 09:58:03 +08:00 (754 lines, 26 KiB, Python)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import functools
from functools import partial
import importlib.util
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from typing import Any, Dict, List, Optional, Callable
from vllm import envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.fp8 import (
get_flashinfer_moe_backend,
ACTIVATION_SCHEMES,
Fp8Config,
Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
validate_fp8_block_shape
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
maybe_create_device_identity, Fp8LinearOp)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter, PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """Select the primary FP8 MoE backend for the current platform.

    Note: Shape-specific fallbacks may still occur at runtime.
    """
    # FlashInfer is preferred on SM90/SM100 CUDA devices when the env
    # flag is set and the FlashInfer MoE extension is importable.
    flashinfer_eligible = (
        current_platform.is_cuda()
        and (current_platform.is_device_capability(100)
             or current_platform.is_device_capability(90))
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_eligible:
        flashinfer_backend = get_flashinfer_moe_backend()
        if flashinfer_backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # Throughput (CUTLASS) backend cannot handle block quant on SM100.
        if block_quant and current_platform.is_device_capability(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only Marlin path for older GPUs without native FP8 support.
    prefer_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: disable marlin for MLU backend (out-of-tree) and ROCm.
    if current_platform.is_rocm() or current_platform.is_out_of_tree():
        prefer_marlin = False
    # ==================
    # End of MLU Hijack
    # ==================
    if prefer_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM on supported platforms with block-quantized weights.
    if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
        if not has_deep_gemm():
            logger.warning_once("DeepGEMM backend requested but not available.")
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights.
    if (current_platform.is_cuda()
            and current_platform.is_device_capability(100)
            and block_quant):
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # Default to Triton.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON
# Keep a handle on the stock constructor so the hijack can delegate to it.
Fp8Config____init____org = Fp8Config.__init__


def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: list[str] | None = None,
    weight_block_size: list[int] | None = None,
    activation_quant_method: Optional[str] = None,
    weight_quant_method: Optional[str] = None,
) -> None:
    """Hijacked ``Fp8Config.__init__``.

    Delegates the upstream fields to the original constructor, then records
    the MLU-specific quantization granularity and validates that the
    combination is one the MLU kernels support.

    Raises:
        ValueError: if the configuration is neither block-wise quantization
            nor dynamic per-token activation with per-channel weights.
    """
    super(Fp8Config, self).__init__()
    Fp8Config____init____org(
        self,
        is_checkpoint_fp8_serialized,
        activation_scheme,
        ignored_layers,
        weight_block_size,
    )
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: Add class members activation_quant_method and
    # weight_quant_method to indicate the granularity of quantization.
    self.activation_quant_method = activation_quant_method
    self.weight_quant_method = weight_quant_method
    # Explicit parentheses (the original relied on and/or precedence) and a
    # real exception: `assert` is stripped under `python -O`, which would
    # silently accept unsupported configs.
    supported = (
        bool(self.weight_block_size)
        or (self.activation_quant_method == "per_token"
            and self.weight_quant_method == "per_channel"
            and self.activation_scheme == "dynamic")
    )
    if not supported:
        raise ValueError(
            "Only support block-wise quantization, or "
            "input dynamic per-token weight per-channel quantization yet.")
    # ==================
    # End of MLU Hijack
    # ==================
@classmethod
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
    cls, config: Dict[str, Any]
) -> "Fp8Config":
    """Hijacked ``Fp8Config.from_config``: build a config from a checkpoint
    quantization dict, including the MLU granularity keys."""
    quant_method = cls.get_from_keys(config, ["quant_method"])
    serialized_fp8 = "fp8" in quant_method
    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
    ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
    weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
    # Some checkpoints spell the ignore list "modules_to_not_convert".
    if not ignored_layers:
        ignored_layers = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None)
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: read activation_quant_method and weight_quant_method to
    # indicate the granularity of quantization.
    activation_quant_method = cls.get_from_keys_or(
        config, ["activation_quant_method"], 'per_token')
    weight_quant_method = cls.get_from_keys_or(
        config, ["weight_quant_method"], None)
    return cls(
        is_checkpoint_fp8_serialized=serialized_fp8,
        activation_scheme=activation_scheme,
        ignored_layers=ignored_layers,
        weight_block_size=weight_block_size,
        activation_quant_method=activation_quant_method,
        weight_quant_method=weight_quant_method,
    )
    # ==================
    # End of MLU Hijack
    # ==================
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Hijacked ``Fp8LinearMethod.create_weights``.

    Registers ``weight``, a weight scale (``weight_scale`` or, for block
    quantization, ``weight_scale_inv``) and ``input_scale`` on ``layer``.
    MLU additions: an optional ``tp_group`` attribute for custom sharding,
    and a channel-wise scale parameter when ``weight_per_channel`` is set.
    """
    maybe_create_device_identity()
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    # Record partition geometry; read back by apply()/post-load processing.
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add tp_group.
    '''
    # Tensor-parallel group forwarded by the caller (MLU-specific attribute).
    tp_group = extra_weight_attrs.get("tp_group", None)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if self.block_quant:
        assert self.weight_block_size is not None
        layer.weight_block_size = self.weight_block_size
        # Reject layer shapes that do not divide evenly into FP8 blocks.
        validate_fp8_block_shape(
            layer,
            input_size,
            output_size,
            input_size_per_partition,
            output_partition_sizes,
            self.weight_block_size,
        )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add tp_group.
    '''
    # WEIGHT
    if self.quant_config.is_checkpoint_fp8_serialized:
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
    else:
        # For non-serialized checkpoints, use original dtype
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
            tp_group=tp_group,
        )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    layer.register_parameter("weight", weight)
    # If checkpoint is serialized fp8, load them.
    # Otherwise, wait until process_weights_after_loading.
    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALE
        if not self.block_quant:
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: Support weight per channel quantization.
            @brief: Add tp_group to enable custom split.
            '''
            if self.weight_per_channel:
                # One scale per output channel across all fused partitions.
                scale = ChannelQuantScaleParameter(
                    data=torch.empty(sum(output_partition_sizes),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader,
                    tp_group=tp_group,
                )
            else:
                # One scale per logical (fused) weight partition.
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                # Sentinel value so never-loaded scales are detectable.
                # NOTE(review): placement reconstructed — pre-fill appears to
                # apply to the per-tensor path only; confirm against the
                # original indentation.
                scale[:] = torch.finfo(torch.float32).min
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)
        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)
        else:
            # Dynamic activation quantization: no loadable input scale.
            layer.register_parameter("input_scale", None)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
    self,
    quant_config: Fp8Config
):
    """Hijacked ``Fp8LinearMethod.__init__``.

    Sets up the same backend-selection flags as upstream and additionally
    derives the MLU granularity flags (``weight_per_channel`` /
    ``activation_per_token``) from the quant config.
    """
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # Marlin serves as the weight-only FP8 fallback on GPUs lacking native
    # FP8 hardware (or when forced by the test env flag)...
    marlin = (not current_platform.has_device_capability(89)
              or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # ...but never on ROCm or in batch-invariant mode.
    if current_platform.is_rocm():
        marlin = False
    if vllm_is_batch_invariant():
        marlin = False
    self.use_marlin = marlin

    # AITER is only supported on ROCm, only for FP8_FNUZ,
    # and at the moment only on the MI300 series.
    self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
    self.use_deep_gemm = is_deep_gemm_supported()

    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant = self.weight_block_size is not None
    if self.block_quant:
        # Marlin doesn't support block-wise fp8.
        self.use_marlin = False

    self.act_q_static = self.quant_config.activation_scheme == "static"
    if self.weight_block_size:
        self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
    elif not self.act_q_static and cutlass_fp8_supported():
        # Dynamic scheme + cutlass: per-token quantization performs better.
        self.act_q_group_shape = GroupShape.PER_TOKEN
    else:
        self.act_q_group_shape = GroupShape.PER_TENSOR

    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: Derive activation/weight quantization granularity flags from
    # the MLU-extended quant config.
    self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
    self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
    if self.weight_per_channel and self.activation_per_token:
        self.use_marlin = False
    # ==================
    # End of MLU Hijack
    # ==================

    if self.block_quant:
        assert not self.act_q_static
        assert self.weight_block_size is not None
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
        )
    else:
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape,
        )
# Stock implementation, kept so the hijacked method can delegate to it.
Fp8LinearMethod__process_weights_after_loading__org = (
    Fp8LinearMethod.process_weights_after_loading)


def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
    self,
    layer: Module,
) -> None:
    """Hijacked ``Fp8LinearMethod.process_weights_after_loading``."""
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: For dynamic activation and channel-wise weight quantization,
    # the serialized checkpoint is already in its final layout, so no
    # additional post-load processing is needed.
    skip_processing = (
        self.quant_config.is_checkpoint_fp8_serialized
        and self.weight_per_channel
        and self.quant_config.activation_scheme == "dynamic"
    )
    if skip_processing:
        return
    # ==================
    # End of MLU Hijack
    # ==================
    Fp8LinearMethod__process_weights_after_loading__org(self, layer)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hijacked ``Fp8LinearMethod.apply``: run the FP8 linear forward.

    Dispatch order: batch-invariant path -> Marlin -> MLU block-wise FP8
    kernel -> generic ``Fp8LinearOp`` (with the MLU per-token/per-channel
    flags forwarded).
    """
    assert residual is None, "Fp8Linear residual is not supported yet."
    # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
    # we will use BF16 dequant when DeepGEMM is not supported.
    if vllm_is_batch_invariant():
        if self.block_quant:
            assert self.weight_block_size is not None
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )
        else:
            # per-tensor/channel: dequant to BF16 and run GEMM
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV)
                # Try to infer correct broadcasting
                # weight is [K, N], scale could be [num_logical_weights]
                # Need to figure out how to broadcast - for now just try
                # direct multiplication
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback
                    weight_bf16 = weight_fp8 * weight_scale
            # .t() maps the [K, N] dequantized weight (per the comment above)
            # to the (out, in) layout F.linear expects.
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)
    if self.use_marlin:
        # Weight-only FP8 path for devices without native FP8 GEMM.
        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
    if self.block_quant:
        assert self.weight_block_size is not None
        # MLU-specific block-wise FP8 kernel, imported at call time.
        from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
            apply_w8a8_block_fp8_linear)
        return apply_w8a8_block_fp8_linear(
            input=x,
            weight=layer.weight,
            block_size=self.quant_config.weight_block_size,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Use activation per token quantization based on quantization config.
    '''
    return self.fp8_linear.apply(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        out_dtype=self.out_dtype,
        input_scale=layer.input_scale,
        bias=bias,
        weight_per_channel=self.weight_per_channel,
        activation_per_token=self.activation_per_token)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
    self,
    quant_config: Fp8Config,
    layer: torch.nn.Module
):
    """Hijacked ``Fp8MoEMethod.__init__``.

    Mirrors the upstream backend selection, then forces ``use_marlin`` off
    because the MLU backend has no Marlin kernels.
    """
    super(Fp8MoEMethod, self).__init__(layer.moe_config)
    self.layer = layer
    self.quant_config = quant_config
    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant: bool = self.weight_block_size is not None
    self.fp8_backend = get_fp8_moe_backend(self.block_quant)
    self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
    self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
    if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
    elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
        # NOTE(review): nesting reconstructed — the [128, 128] check and
        # flashinfer_moe_fn appear to belong to the CUTLASS branch only;
        # confirm against the original indentation.
        if self.block_quant:
            assert self.weight_block_size == [128, 128], (
                f"Only support weight_block_size == [128, 128], "
                f"got {self.weight_block_size}"
            )
        self.flashinfer_moe_fn = partial(
            flashinfer_cutlass_moe_fp8,
            moe=self.moe,
            use_deepseek_fp8_block_scale=self.block_quant,
        )
    self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
    self.allow_cutlass_block_scaled_grouped_gemm = (
        self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: In mlu, always set self.use_marlin as False.
    '''
    self.use_marlin = False
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: torch.Tensor | None = None,
    logical_to_physical_map: torch.Tensor | None = None,
    logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
    """Hijacked ``Fp8MoEMethod.apply``: FP8 fused-MoE forward on MLU.

    Pipeline: mlu_ops top-k routing -> token expand -> FP8 quantize ->
    grouped GEMM (w13) -> gated activation -> FP8 quantize -> grouped GEMM
    (w2) -> weighted combine back to the original token order.

    Raises:
        ValueError: if ``scoring_func`` is neither "softmax" nor "sigmoid".
    """
    if enable_eplb:
        # Fix: the original referenced FusedMoE without importing it, so
        # enabling EPLB raised NameError. Import locally to avoid a
        # module-load import cycle with the fused_moe layer.
        from vllm.model_executor.layers.fused_moe.layer import FusedMoE
        assert expert_load_view is not None
        assert logical_to_physical_map is not None
        assert logical_replica_count is not None
        assert isinstance(layer, FusedMoE)
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to
    # implement FusedMoE.select_experts.
    # (Removed an unused function-local import of fused_experts.)
    if scoring_func == "softmax":
        topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
            router_logits,
            top_k,
            renormalize,
            num_expert_group,
            topk_group,
            route_scale=routed_scaling_factor,
        )
    elif scoring_func == "sigmoid":
        topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
            router_logits,
            top_k,
            renormalize,
            num_expert_group,
            topk_group,
            routed_scaling_factor,
            e_score_correction_bias,
        )
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Flatten to (tokens, hidden) for the grouped GEMMs; restore at the end.
    ori_input_shape = x.shape
    x = x.reshape(-1, x.size(-1))
    router_logits = router_logits.reshape(-1, router_logits.size(-1))
    expert_num = router_logits.size(-1)
    tokens_num = x.size(0)
    # Number of experts materialized on this rank (w13_weight dim 0).
    expert_size = layer.w13_weight.size(0)

    # gen_idx: scatter/gather indices plus per-expert token counts.
    expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
        topk_ids, expert_num
    )
    # Duplicate each token once per selected expert, grouped by expert.
    expand_hidden_states = mlu_ops.moe_expand_input(
        x, expand_idx, cumsum_token_count, 0, expert_size
    )
    # First grouped GEMM (fused gate/up projection) on FP8 inputs.
    quant_input, input_scale = _fp8_quantize(
        expand_hidden_states, A_scale=None,
        block_shape=self.quant_config.weight_block_size
    )
    gemm1_out = mlu_ops.smooth_quant_group_gemm(
        quant_input,
        layer.w13_weight,
        token_count,
        expand_idx=None,
        c=None,
        alpha=None,
        beta=None,
        a_scale=input_scale.T.contiguous(),
        b_scale=layer.w13_weight_scale_inv,
        dtype=x.dtype,
        max_m=tokens_num,
    )
    # Gated activation (e.g. SiLU-and-mul), then re-quantize for the
    # second grouped GEMM (down projection).
    act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
    act_out_quantize, act_out_scale = _fp8_quantize(
        act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
    )
    gemm2_out = mlu_ops.smooth_quant_group_gemm(
        act_out_quantize,
        layer.w2_weight,
        token_count,
        expand_idx=None,
        c=None,
        alpha=None,
        beta=None,
        a_scale=act_out_scale.T.contiguous(),
        b_scale=layer.w2_weight_scale_inv,
        dtype=x.dtype,
        max_m=tokens_num,
    )
    # Weighted sum of expert outputs, restored to original token order.
    output = mlu_ops.moe_combine_result(
        gemm2_out,
        topk_weights,
        combine_idx,
        residual=None,
        cusum_token_count=cumsum_token_count,
        start_expert_id=0,
        expert_size=expert_size,
        bias=None,
    )
    return output.view(ori_input_shape)
# ==================
# End of MLU Hijack
# ==================
# Install the MLU overrides onto the upstream vLLM FP8 classes.
# Each entry: (target class, original attribute, replacement).
_MLU_FP8_HIJACKS = (
    (Fp8LinearMethod, Fp8LinearMethod.apply,
     vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply),
    (Fp8Config, Fp8Config.__init__,
     vllm__model_executor__layers__quantization__fp8__Fp8Config____init__),
    (Fp8Config, Fp8Config.from_config,
     vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config),
    (Fp8LinearMethod, Fp8LinearMethod.create_weights,
     vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights),
    (Fp8LinearMethod, Fp8LinearMethod.__init__,
     vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__),
    (Fp8LinearMethod, Fp8LinearMethod.process_weights_after_loading,
     vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading),
    (Fp8MoEMethod, Fp8MoEMethod.__init__,
     vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__),
    (Fp8MoEMethod, Fp8MoEMethod.apply,
     vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply),
)
for _cls, _orig_fn, _new_fn in _MLU_FP8_HIJACKS:
    MluHijackObject.apply_hijack(_cls, _orig_fn, _new_fn)