Upgrade to vllm 0.17.0 corex v4.1 overlay
335
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
Normal file
@@ -0,0 +1,335 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    activation_to_flashinfer_int,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8Dynamic128Sym,
    kFp8Static128BlockSym,
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform


class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
    """
    Fp8 TRTLLM-Gen MoE kernels. Supports the monolithic interface.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)

        if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
            raise NotImplementedError(
                "EP parallelism is not supported with TRTLLM "
                "per-tensor FP8 quantization."
            )

        self.routing_method_type = moe_config.routing_method
        self.topk = moe_config.experts_per_token
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.hidden_dim = moe_config.hidden_dim
        self.local_num_experts = moe_config.num_local_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank

        # Precompute additional fused scales for the per-tensor interface.
        if self.quant_config.is_per_tensor:
            w1_scale = self.quant_config.w1_scale
            assert w1_scale is not None
            a1_scale = self.quant_config.a1_scale
            assert a1_scale is not None
            w2_scale = self.quant_config.w2_scale
            assert w2_scale is not None
            a2_scale = self.quant_config.a2_scale
            assert a2_scale is not None

            self._g1_alphas = (w1_scale * a1_scale).squeeze()
            self._g2_alphas = (w2_scale * a2_scale).squeeze()
            self._g1_scale_c = (
                self._g1_alphas / self.quant_config.a2_scale
                if moe_config.is_act_and_mul
                else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
            )
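            # A sketch of the scale algebra (our reading, not verified against
            # flashinfer docs): GEMM1 output is dequantized by
            # g1 = a1_scale * w1_scale and requantized to fp8 for GEMM2 by
            # 1 / a2_scale, so the fused scalar is g1_scale_c = g1 / a2_scale.
            # In the non-gated case the gate alphas drop out of the scalar
            # path, hence the ones_like(...) above.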

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Supports only Blackwell-family GPUs."""
        p = current_platform
        # TODO: also check that the FlashInfer TRTLLM kernels are available.
        return p.is_cuda() and p.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Supports non-gated MoE (e.g. Nemotron-Nano)."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Supports Fp8 per-tensor and Fp8 block quantization."""
        SUPPORTED_W_A = [
            (kFp8Static128BlockSym, kFp8Dynamic128Sym),
            (kFp8StaticTensorSym, kFp8StaticTensorSym),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Supports only SiLU and non-gated ReLU^2 activations."""
        return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Monolithic kernels need to express router support."""
        # NOTE(dbari): TopK routing could also be enabled, but models need validation.
        # NOTE(dbari): Default is not implemented and should not be enabled until it is.
        if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
            # NOTE(rob): potentially allow others here. This is a conservative list.
            return routing_method in [
                RoutingMethodType.DeepSeekV3,
                RoutingMethodType.Renormalize,
                RoutingMethodType.RenormalizeNaive,
            ]
        elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
            # NOTE(dbari): as above, potentially allow others here.
            return routing_method in [
                RoutingMethodType.DeepSeekV3,
                RoutingMethodType.Llama4,
                RoutingMethodType.Renormalize,
                RoutingMethodType.RenormalizeNaive,
            ]
        else:
            raise ValueError("Unsupported quantization scheme.")

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """Monolithic kernel, so only use with naive DP/EP and TP."""
        return (
            not moe_parallel_config.use_all2all_kernels
            or moe_parallel_config.use_naive_all2all_kernels
        ) and not moe_parallel_config.enable_eplb

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """
        The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
        Only DeepSeekV3 routing supports float32 router_logits (which is converted
        internally in the kernel).
        """
        if router_logits_dtype == torch.float32:
            # Only DeepSeekV3 routing handles float32 logits:
            # https://github.com/flashinfer-ai/flashinfer/issues/2469
            return routing_method == RoutingMethodType.DeepSeekV3
        return True

    def supports_chunking(self) -> bool:
        return False

    def supports_expert_map(self) -> bool:
        return False

    def _apply_per_block(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        # Delay the import so non-CUDA platforms do not require flashinfer.
        import flashinfer

        assert not apply_router_weight_on_input
        assert activation == MoEActivation.SILU

        if e_score_correction_bias is not None:
            e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)

        if self.routing_method_type == RoutingMethodType.DeepSeekV3:
            router_logits = router_logits.to(torch.float32)

        assert self.topk <= global_num_experts
        assert self.topk <= 10
        assert global_num_experts % 4 == 0
        assert self.quant_config.block_shape == [128, 128]
        # The routing kernel expects num_experts <= num_threads (512).
        assert global_num_experts <= 512

        # The kernel requires transposed hidden-state scales.
        # TODO: fuse the transpose into the quant kernel.
        assert a1q_scale is not None
        a1q_scale_t = a1q_scale.t().contiguous()
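        # E.g. (hypothetical shapes): group-128 quantization of (M, K) hidden
        # states yields a1q_scale of shape (M, K // 128); the kernel consumes
        # the transposed (K // 128, M) layout.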

        return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale_t,
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=(num_expert_group or 0),
            topk_group=(topk_group or 0),
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=routed_scaling_factor,
            routing_method_type=self.routing_method_type,
            use_shuffled_weight=False,
        )

    def _apply_per_tensor(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        # Delay the import so non-CUDA platforms do not require flashinfer.
        import flashinfer
        from flashinfer.fused_moe.core import ActivationType

        # Confirm the activation function is supported.
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]

        activation_type = ActivationType(activation_to_flashinfer_int(activation))

        # Llama-4 routing applies the router weights on the input.
        if self.routing_method_type == RoutingMethodType.Llama4:
            assert apply_router_weight_on_input
        else:
            assert not apply_router_weight_on_input

        # The DeepSeekV3 routing method requires float32 router logits.
        if self.routing_method_type == RoutingMethodType.DeepSeekV3:
            router_logits = router_logits.to(torch.float32)

        out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
            routing_logits=router_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states,
            gemm1_weights=w1,
            output1_scales_scalar=self._g1_scale_c,
            output1_scales_gate_scalar=self._g1_alphas,
            gemm2_weights=w2,
            output2_scales_scalar=self._g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=num_expert_group or 0,
            topk_group=topk_group or 0,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=routed_scaling_factor,
            use_routing_scales_on_input=apply_router_weight_on_input,
            routing_method_type=self.routing_method_type,
            activation_type=activation_type,
        )
        return out

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        if self.quant_config.block_shape is not None:
            return self._apply_per_block(
                hidden_states,
                w1,
                w2,
                router_logits,
                activation,
                global_num_experts,
                expert_map,
                a1q_scale,
                apply_router_weight_on_input,
                num_expert_group=num_expert_group,
                e_score_correction_bias=e_score_correction_bias,
                routed_scaling_factor=routed_scaling_factor,
                topk_group=topk_group,
            )
        elif self.quant_config.is_per_tensor:
            return self._apply_per_tensor(
                hidden_states,
                w1,
                w2,
                router_logits,
                activation,
                global_num_experts,
                expert_map,
                a1q_scale,
                apply_router_weight_on_input,
                num_expert_group=num_expert_group,
                e_score_correction_bias=e_score_correction_bias,
                routed_scaling_factor=routed_scaling_factor,
                topk_group=topk_group,
            )
        else:
            raise NotImplementedError(
                "Only per-block and per-tensor quantization are supported in "
                f"{self.__class__.__name__}."
            )
326
vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
Normal file
@@ -0,0 +1,326 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import flashinfer
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    activation_to_flashinfer_int,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kNvfp4Dynamic,
    kNvfp4Static,
)
from vllm.platforms import current_platform


class TrtLlmNvFp4ExpertsBase:
    """
    NvFp4 TRTLLM-Gen MoE kernels. Supports the modular and monolithic interfaces.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        self.moe_config = moe_config
        self.quant_config = quant_config

        self.routing_method_type = self.moe_config.routing_method
        self.topk = moe_config.experts_per_token
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.hidden_dim = moe_config.hidden_dim
        self.local_num_experts = moe_config.num_local_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank

        assert self.quant_config.g1_alphas is not None
        assert self.quant_config.a2_gscale is not None
        if moe_config.is_act_and_mul:
            # g1_alphas = a13_scale * w13_scale_2
            # a2_gscale = 1 / a2_scale
            # g1_scale_c = a13_scale * w13_scale_2 / a2_scale
            self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
        else:
            self.g1_scale_c = (
                torch.ones_like(self.quant_config.a1_gscale)
                * self.quant_config.a2_gscale
            )

    @staticmethod
    def _supports_current_device() -> bool:
        """Supports only Blackwell-family GPUs."""
        p = current_platform
        return p.is_cuda() and p.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Supports non-gated MoE (e.g. Nemotron-Nano)."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Supports Nvfp4 quantization."""
        SUPPORTED_W_A = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Supports only SiLU and non-gated ReLU^2 activations."""
        return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]

    @staticmethod
    def _supports_shape(hidden_dim: int) -> bool:
        """Requires the hidden dim to be a multiple of 512."""
        return hidden_dim % 512 == 0

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def supports_chunking(self) -> bool:
        return False

    def supports_expert_map(self) -> bool:
        return False


class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
    """
    Modular version of the implementation (just the experts).
    """

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """The modular implementation supports all parallel configs."""
        return True

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # The workspaces for this implementation are managed by flashinfer.
        workspace1 = (0,)
        workspace2 = (0,)

        # Hidden states are Nvfp4, packed into int8 dtype, so we
        # need to multiply K by 2 to get the output shape right.
        assert self.hidden_dim == K * 2
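        # E.g. (hypothetical sizes): hidden_dim == 4096 is packed into
        # K == 2048 int8 elements, two fp4 values per byte.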
        output = (M, self.hidden_dim)

        return (workspace1, workspace2, output)

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert a1q_scale is not None
        assert self.quant_config.w1_scale is not None
        assert self.quant_config.w2_scale is not None

        # Pack topk ids and weights into the format expected by the kernel.
        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
            torch.bfloat16
        ).view(torch.int16)
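        # Layout of each packed int32, inferred from the shifts above:
        #   bits [31:16] -> expert id, bits [15:0] -> bf16 weight bit pattern.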

        # trtllm_fp4_block_scale_routed_moe does not support autotuning,
        # so skip this kernel during the dummy run for autotuning.
        import vllm.utils.flashinfer as fi_utils

        if fi_utils._is_fi_autotuning:
            return hidden_states

        # Invoke the kernel.
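        # The kernel writes its result into `output` in place; this method
        # returns nothing. routing_method_type=1 below is assumed to map to
        # flashinfer's Renormalize routing.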
        flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
            topk_ids=packed_tensor,
            routing_bias=None,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
                *hidden_states.shape[:-1], -1
            ),
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=self.g1_scale_c,
            output1_scale_gate_scalar=self.quant_config.g1_alphas,
            output2_scale_scalar=self.quant_config.g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=0,
            topk_group=0,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=None,
            routing_method_type=1,
            do_finalize=True,
            activation_type=activation_to_flashinfer_int(activation),
            output=output,
        )


class TrtLlmNvFp4ExpertsMonolithic(
    TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
):
    """
    Monolithic version of the kernel (router + experts).
    """

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """The modular implementation should be used for the DP/EP or EPLB case."""
        return (
            not moe_parallel_config.use_all2all_kernels
            and not moe_parallel_config.enable_eplb
        )

    @staticmethod
    def _supports_routing_method(
        routing_method_type: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        # NOTE(rob): this is a conservative list.
        return routing_method_type in [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
            RoutingMethodType.Llama4,
        ]

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """
        The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
        Only DeepSeekV3 routing supports float32 router_logits (which is converted
        internally in the kernel).
        """
        if router_logits_dtype == torch.float32:
            # Only DeepSeekV3 routing handles float32 logits:
            # https://github.com/flashinfer-ai/flashinfer/issues/2469
            return routing_method == RoutingMethodType.DeepSeekV3
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert a1q_scale is not None
        assert self.quant_config.w1_scale is not None
        assert self.quant_config.w2_scale is not None
        assert (
            apply_router_weight_on_input
            and self.routing_method_type == RoutingMethodType.Llama4
        ) or (
            not apply_router_weight_on_input
            and self.routing_method_type != RoutingMethodType.Llama4
        )
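        # i.e. router weights are applied on the input iff routing is Llama4.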

        # Prepare the routing bias in the kernel's expected format.
        routing_bias = e_score_correction_bias
        if routing_bias is not None:
            routing_bias = routing_bias.to(torch.bfloat16)
        router_logits = (
            router_logits.to(torch.float32)
            if self.routing_method_type == RoutingMethodType.DeepSeekV3
            else router_logits
        )

        # Invoke the kernel.
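        # With do_finalize=True the kernel returns a list of tensors; the
        # first element is the finalized MoE output, hence the trailing [0].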
        return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=routing_bias,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
                *hidden_states.shape[:-1], -1
            ),
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=self.g1_scale_c,
            output1_scale_gate_scalar=self.quant_config.g1_alphas,
            output2_scale_scalar=self.quant_config.g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=(num_expert_group or 0),
            topk_group=(topk_group or 0),
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=routed_scaling_factor,
            routing_method_type=self.routing_method_type,
            do_finalize=True,
        )[0]