Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -22,12 +22,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
|
||||
FusedMoERouter,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
|
||||
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
@@ -61,9 +62,10 @@ __all__ = [
|
||||
"MoEActivation",
|
||||
"UnquantizedFusedMoEMethod",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEExpertsModular",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"FusedMoEPrepareAndFinalizeModular",
|
||||
"GateLinear",
|
||||
"RoutingMethodType",
|
||||
"SharedFusedMoE",
|
||||
"ZeroExpertFusedMoE",
|
||||
@@ -137,4 +139,4 @@ else:
|
||||
raise NotImplementedError(f"{method} is not implemented as lack of triton.")
|
||||
|
||||
fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
|
||||
fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
|
||||
fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
|
||||
@@ -6,8 +6,7 @@ from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm._custom_ops import silu_and_mul, gelu_and_mul, swigluoai_and_mul
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
class MoEActivation(Enum):
|
||||
@@ -114,14 +113,11 @@ def apply_moe_activation(
|
||||
|
||||
# Activations with gated multiplication (gate × activation(up))
|
||||
if activation == MoEActivation.SILU:
|
||||
# torch.ops._C.silu_and_mul(output, input)
|
||||
silu_and_mul(output, input)
|
||||
ops.silu_and_mul(output, input)
|
||||
elif activation == MoEActivation.GELU:
|
||||
# torch.ops._C.gelu_and_mul(output, input)
|
||||
gelu_and_mul(output, input)
|
||||
ops.gelu_and_mul(output, input)
|
||||
elif activation == MoEActivation.SWIGLUOAI:
|
||||
# torch.ops._C.swigluoai_and_mul(output, input)
|
||||
swigluoai_and_mul(output, input)
|
||||
ops.swigluoai_and_mul(output, input)
|
||||
elif activation == MoEActivation.SWIGLUSTEP:
|
||||
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,20 +21,15 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNaiveEP,
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
make_moe_prepare_and_finalize_naive_dp_ep,
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
if has_pplx():
|
||||
from .pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize,
|
||||
pplx_hidden_dim_scale_bytes,
|
||||
)
|
||||
if has_deep_ep():
|
||||
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
|
||||
from .deepep_ll_prepare_finalize import (
|
||||
@@ -81,6 +77,7 @@ def maybe_make_prepare_finalize(
|
||||
quant_config: FusedMoEQuantConfig | None,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
allow_new_interface: bool = False,
|
||||
use_monolithic: bool = False,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
# NOTE(rob): we are migrating each quant_method to hold the MK
|
||||
# in all cases. The allow_new_interface=False flag allow us to fall
|
||||
@@ -106,65 +103,25 @@ def maybe_make_prepare_finalize(
|
||||
"Detected DP deployment with no --enable-expert-parallel. "
|
||||
"Falling back to AllGather+ReduceScatter dispatch/combine."
|
||||
)
|
||||
return MoEPrepareAndFinalizeNaiveEP(
|
||||
return make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=(
|
||||
get_ep_group().device_communicator.all2all_manager.world_size
|
||||
),
|
||||
use_monolithic=use_monolithic,
|
||||
)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
assert quant_config is not None
|
||||
|
||||
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||
moe.max_num_tokens,
|
||||
moe.hidden_dim,
|
||||
moe.in_dtype,
|
||||
quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
)
|
||||
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
experts_per_token=moe.experts_per_token, # topk
|
||||
rank=all2all_manager.rank,
|
||||
world_size=all2all_manager.world_size,
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim_bytes,
|
||||
hidden_dim_scale_bytes=hidden_scale_bytes,
|
||||
)
|
||||
|
||||
num_dispatchers = (
|
||||
all2all_manager.world_size // all2all_manager.tp_group.world_size
|
||||
)
|
||||
|
||||
# Intranode pplx a2a takes a group name while internode does not.
|
||||
if not all2all_manager.internode:
|
||||
all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
|
||||
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
handle,
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
num_local_experts=moe.num_local_experts,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
elif moe.use_deepep_ht_kernels:
|
||||
if moe.use_deepep_ht_kernels:
|
||||
assert moe.dp_size == all2all_manager.dp_world_size
|
||||
|
||||
all_to_all_args = dict()
|
||||
all_to_all_args: dict[str, Any] = dict()
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
prepare_finalize = DeepEPHTPrepareAndFinalize(
|
||||
handle,
|
||||
@@ -246,8 +203,9 @@ def maybe_make_prepare_finalize(
|
||||
)
|
||||
|
||||
elif moe.use_naive_all2all_kernels and allow_new_interface:
|
||||
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
|
||||
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
|
||||
prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
use_monolithic=use_monolithic,
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant(
|
||||
return y_q, y_s
|
||||
|
||||
|
||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -228,6 +228,7 @@ class FusedMoEQuantConfig:
|
||||
_a2: FusedMoEQuantDesc
|
||||
_w1: FusedMoEQuantDesc
|
||||
_w2: FusedMoEQuantDesc
|
||||
is_nvfp4_scale_swizzled: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.per_act_token_quant or self.block_shape is None, (
|
||||
@@ -475,6 +476,7 @@ class FusedMoEQuantConfig:
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
weight_dtype: torch.dtype | str | None = None,
|
||||
is_nvfp4_scale_swizzled: bool = True,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
"""
|
||||
General builder function for a FusedMoEQuantConfig.
|
||||
@@ -504,6 +506,7 @@ class FusedMoEQuantConfig:
|
||||
- w2_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w1_zp: Optional w1 zero points for int4/int8 quantization.
|
||||
- w2_zp: Optional w2 zero points for int4/int8 quantization.
|
||||
- is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling.
|
||||
"""
|
||||
assert not isinstance(quant_dtype, str) or quant_dtype in {
|
||||
"nvfp4",
|
||||
@@ -536,6 +539,7 @@ class FusedMoEQuantConfig:
|
||||
_w2=FusedMoEQuantDesc(
|
||||
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
|
||||
),
|
||||
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
|
||||
)
|
||||
assert quant_config.per_act_token_quant == per_act_token_quant
|
||||
assert quant_config.per_out_ch_quant == per_out_ch_quant
|
||||
@@ -737,6 +741,7 @@ def nvfp4_moe_quant_config(
|
||||
w2_scale: torch.Tensor,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
is_nvfp4_scale_swizzled: bool = True,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and nvp4 weights.
|
||||
@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config(
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=None,
|
||||
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
|
||||
@@ -939,10 +945,6 @@ class FusedMoEParallelConfig:
|
||||
def use_all2all_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.use_all2all_kernels and self.all2all_backend == "pplx"
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return (
|
||||
@@ -962,7 +964,7 @@ class FusedMoEParallelConfig:
|
||||
|
||||
@property
|
||||
def use_batched_activation_format(self):
|
||||
return self.use_deepep_ll_kernels or self.use_pplx_kernels
|
||||
return self.use_deepep_ll_kernels
|
||||
|
||||
@property
|
||||
def use_naive_all2all_kernels(self):
|
||||
@@ -1221,10 +1223,6 @@ class FusedMoEConfig:
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
{
|
||||
"triton_version": "3.6.0",
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
{
|
||||
"triton_version": "3.6.0",
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
@@ -166,7 +166,7 @@ def run_cutlass_moe_fp8(
|
||||
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
|
||||
ops.get_cutlass_pplx_moe_mm_data(
|
||||
ops.get_cutlass_batched_moe_mm_data(
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
@@ -262,7 +262,7 @@ def run_cutlass_moe_fp8(
|
||||
)
|
||||
|
||||
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
@@ -661,7 +661,7 @@ def run_cutlass_moe_fp4(
|
||||
return
|
||||
|
||||
|
||||
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
"""CUTLASS FP4 fused MoE expert implementation."""
|
||||
|
||||
@property
|
||||
@@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
)
|
||||
|
||||
|
||||
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: torch.dtype | None,
|
||||
@@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8(
|
||||
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
fn = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
CutlassExpertsW4A8Fp8(
|
||||
out_dtype=a.dtype,
|
||||
a_strides1=a_strides1,
|
||||
@@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8(
|
||||
quant_config=quant_config,
|
||||
group_size=group_size,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fn(
|
||||
return fn.apply(
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
|
||||
@@ -113,7 +113,7 @@ def _valid_deep_gemm(
|
||||
return True
|
||||
|
||||
|
||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
"""DeepGemm-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import (
|
||||
)
|
||||
|
||||
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using DeepEP High-Throughput kernels.
|
||||
"""
|
||||
@@ -123,7 +123,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
is_token_in_rank,
|
||||
event,
|
||||
) = self.buffer.get_dispatch_layout(
|
||||
topk_idx=rank_topk_ids,
|
||||
topk_idx=rank_topk_ids.long(),
|
||||
num_experts=num_experts,
|
||||
previous_event=previous_event,
|
||||
async_finish=False,
|
||||
@@ -148,7 +148,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
||||
is_token_in_rank=is_token_in_rank,
|
||||
num_tokens_per_expert=dispatch_expert_num_tokens,
|
||||
topk_idx=rank_topk_ids,
|
||||
topk_idx=rank_topk_ids.long(),
|
||||
topk_weights=rank_topk_weights,
|
||||
# expert_alignment rounds the number of tokens per expert
|
||||
# to this value.
|
||||
@@ -169,7 +169,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
event,
|
||||
has_scales,
|
||||
token_data,
|
||||
expert_topk_ids,
|
||||
expert_topk_ids.int(),
|
||||
num_experts,
|
||||
expert_num_tokens_per_expert_list,
|
||||
expert_topk_weights,
|
||||
@@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -49,7 +49,7 @@ def dequant_fp8(
|
||||
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
|
||||
|
||||
|
||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using DeepEP low-latency kernels.
|
||||
"""
|
||||
@@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# time. This setting is handled by post_init_setup.
|
||||
self.use_ue8m0_dispatch = False
|
||||
|
||||
def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
|
||||
if not fused_experts.supports_packed_ue8m0_act_scales():
|
||||
# Early exit.
|
||||
return
|
||||
@@ -297,12 +297,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
|
||||
a1,
|
||||
dispatch_topk_ids,
|
||||
dispatch_topk_ids.long(),
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
round_scale=self.use_ue8m0_dispatch,
|
||||
use_ue8m0=self.use_ue8m0_dispatch,
|
||||
# round_scale=self.use_ue8m0_dispatch,
|
||||
# use_ue8m0=self.use_ue8m0_dispatch,
|
||||
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
|
||||
**(
|
||||
dict(x_global_scale=qc_a1_gscale_or_scale)
|
||||
@@ -398,7 +398,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
dbo_maybe_run_recv_hook()
|
||||
_, _, recv_hook = self.buffer.low_latency_combine(
|
||||
fused_expert_output,
|
||||
combine_topk_ids,
|
||||
combine_topk_ids.long(),
|
||||
combine_topk_weights,
|
||||
handle,
|
||||
async_finish=False,
|
||||
|
||||
335
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
Normal file
335
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
"""
|
||||
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
|
||||
if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
|
||||
raise NotImplementedError(
|
||||
"EP parallelism is not supported with TRTLLM"
|
||||
"per-tensor FP8 quantization."
|
||||
)
|
||||
|
||||
self.routing_method_type = moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
# Make additional scales for per-tensor interface.
|
||||
if self.quant_config.is_per_tensor:
|
||||
w1_scale = self.quant_config.w1_scale
|
||||
assert w1_scale is not None
|
||||
a1_scale = self.quant_config.a1_scale
|
||||
assert a1_scale is not None
|
||||
w2_scale = self.quant_config.w2_scale
|
||||
assert w2_scale is not None
|
||||
a2_scale = self.quant_config.a2_scale
|
||||
assert a2_scale is not None
|
||||
|
||||
self._g1_alphas = (w1_scale * a1_scale).squeeze()
|
||||
self._g2_alphas = (w2_scale * a2_scale).squeeze()
|
||||
self._g1_scale_c = (
|
||||
self._g1_alphas / self.quant_config.a2_scale
|
||||
if moe_config.is_act_and_mul
|
||||
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
# Add check flashinfer trtllm is available
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
or moe_parallel_config.use_naive_all2all_kernels
|
||||
) and not moe_parallel_config.enable_eplb
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def _apply_per_block(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
|
||||
assert not apply_router_weight_on_input
|
||||
assert activation == MoEActivation.SILU
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)
|
||||
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
assert self.topk <= global_num_experts
|
||||
assert self.topk <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert self.quant_config.block_shape == [128, 128]
|
||||
# Routing kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# Kernel requires transposed hidden state scales
|
||||
# TODO: fuse into the quant kernel.
|
||||
assert a1q_scale is not None
|
||||
a1q_scale_t = a1q_scale.t().contiguous()
|
||||
|
||||
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale_t,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=(num_expert_group or 0),
|
||||
topk_group=(topk_group or 0),
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
def _apply_per_tensor(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# Confirm supported activation function.
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
activation_type = ActivationType(activation_to_flashinfer_int(activation))
|
||||
|
||||
# Confirm Llama-4 routing is proper.
|
||||
if self.routing_method_type == RoutingMethodType.Llama4:
|
||||
assert apply_router_weight_on_input
|
||||
else:
|
||||
assert not apply_router_weight_on_input
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
gemm1_weights=w1,
|
||||
output1_scales_scalar=self._g1_scale_c,
|
||||
output1_scales_gate_scalar=self._g1_alphas,
|
||||
gemm2_weights=w2,
|
||||
output2_scales_scalar=self._g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=num_expert_group or 0,
|
||||
topk_group=topk_group or 0,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=apply_router_weight_on_input,
|
||||
routing_method_type=self.routing_method_type,
|
||||
activation_type=activation_type,
|
||||
)
|
||||
return out
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.block_shape is not None:
|
||||
return self._apply_per_block(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
elif self.quant_config.is_per_tensor:
|
||||
return self._apply_per_tensor(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only per-block and per-tensor quantization are supported in "
|
||||
f"{self.__class__.__name__}."
|
||||
)
|
||||
326
vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
Normal file
326
vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsBase:
|
||||
"""
|
||||
NvFp4 TRTLLM-Gen MoE kernels. Supports modular and monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
self.moe_config = moe_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.routing_method_type = self.moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
assert self.quant_config.g1_alphas is not None
|
||||
assert self.quant_config.a2_gscale is not None
|
||||
if moe_config.is_act_and_mul:
|
||||
# g1_alpha_s = a13_scale * w13_scale_2
|
||||
# a2_gscale = (1 / a2_scale)
|
||||
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
||||
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||
else:
|
||||
self.g1_scale_c = (
|
||||
torch.ones_like(self.quant_config.a1_gscale)
|
||||
* self.quant_config.a2_gscale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Supports non-gated MoE (i.e. Nemotron-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Nvfp4 quantization."""
|
||||
SUPPORTED_W_A = [
|
||||
(kNvfp4Static, kNvfp4Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_shape(hidden_dim: int) -> bool:
|
||||
"""Requires hidden dim to be multiple of 512."""
|
||||
return hidden_dim % 512 == 0
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
Modular version of the implementation (just the experts).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""The modular implementation supports all parallel configs."""
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
workspace1 = (0,)
|
||||
workspace2 = (0,)
|
||||
|
||||
# Hidden states are Nvfp4, packed into int8 dtype, so we
|
||||
# need to multiply K by 2 to get the output shape right.
|
||||
assert self.hidden_dim == K * 2
|
||||
output = (M, self.hidden_dim)
|
||||
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert a1q_scale is not None
|
||||
assert self.quant_config.w1_scale is not None
|
||||
assert self.quant_config.w2_scale is not None
|
||||
|
||||
# Pack topk ids and weights into format expected by the kernel.
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
|
||||
# trtllm_fp4_block_scale_routed_moe does not support autotuning
|
||||
# so skip this kernel during dummy run for autotuning.
|
||||
import vllm.utils.flashinfer as fi_utils
|
||||
|
||||
if fi_utils._is_fi_autotuning:
|
||||
return hidden_states
|
||||
|
||||
# Invoke kernel.
|
||||
flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
||||
topk_ids=packed_tensor,
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
),
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=self.g1_scale_c,
|
||||
output1_scale_gate_scalar=self.quant_config.g1_alphas,
|
||||
output2_scale_scalar=self.quant_config.g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=0,
|
||||
topk_group=0,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
do_finalize=True,
|
||||
activation_type=activation_to_flashinfer_int(activation),
|
||||
output=output,
|
||||
)
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsMonolithic(
|
||||
TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
|
||||
):
|
||||
"""
|
||||
Monolithic version of the kernel (router + experts).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""The modular implementation should be used for the Dp/Ep or EPLB case."""
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
and not moe_parallel_config.enable_eplb
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method_type: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
# NOTE(rob): this is a conservative list.
|
||||
return routing_method_type in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Llama4,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert a1q_scale is not None
|
||||
assert self.quant_config.w1_scale is not None
|
||||
assert self.quant_config.w2_scale is not None
|
||||
assert (
|
||||
apply_router_weight_on_input
|
||||
and self.routing_method_type == RoutingMethodType.Llama4
|
||||
) or (
|
||||
not apply_router_weight_on_input
|
||||
and self.routing_method_type != RoutingMethodType.Llama4
|
||||
)
|
||||
|
||||
# Prepare routing bias into kernel format.
|
||||
routing_bias = e_score_correction_bias
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(torch.bfloat16)
|
||||
router_logits = (
|
||||
router_logits.to(torch.float32)
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits
|
||||
)
|
||||
|
||||
# Invoke kernel.
|
||||
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
),
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=self.g1_scale_c,
|
||||
output1_scale_gate_scalar=self.quant_config.g1_alphas,
|
||||
output2_scale_scalar=self.quant_config.g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=(num_expert_group or 0),
|
||||
topk_group=(topk_group or 0),
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
@@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
|
||||
|
||||
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
|
||||
"""Base class for runtime dispatching of expert implementations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
experts: mk.FusedMoEExpertsModular,
|
||||
fallback_experts: mk.FusedMoEExpertsModular,
|
||||
):
|
||||
super().__init__(
|
||||
moe_config=experts.moe_config, quant_config=experts.quant_config
|
||||
@@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
"""
|
||||
Get the cls for the experts and fallback experts.
|
||||
@@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise NotImplementedError
|
||||
|
||||
def apply(
|
||||
|
||||
@@ -18,7 +18,7 @@ def get_local_sizes():
|
||||
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""Base class for FlashInfer MoE prepare and finalize operations."""
|
||||
|
||||
def __init__(
|
||||
@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch(
|
||||
ep_size,
|
||||
)
|
||||
|
||||
# Swizzle after the A2A if nvfp4.
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
# Swizzle after the A2A if MoE kernel expects swizzled scales.
|
||||
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
|
||||
if x_sf.element_size() == 1:
|
||||
x_sf = x_sf.view(torch.uint8)
|
||||
x_sf = nvfp4_block_scale_interleave(x_sf)
|
||||
|
||||
@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe(
|
||||
return True
|
||||
|
||||
|
||||
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class FlashInferExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: mk.FusedMoEConfig,
|
||||
|
||||
@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
|
||||
def _supports_routing_method(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
|
||||
def _supports_routing_method_bf16(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
|
||||
return not moe_parallel_config.enable_eplb
|
||||
|
||||
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
|
||||
def is_supported_config_trtllm_fp8(
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
|
||||
"""
|
||||
|
||||
def _make_reason(reason: str) -> str:
|
||||
return f"kernel does not support {reason}"
|
||||
|
||||
if not _supports_current_device():
|
||||
return False, _make_reason(f"current device {current_platform.device_name}")
|
||||
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
|
||||
return False, _make_reason("no act_and_mul MLP layer")
|
||||
elif not _supports_activation(moe_config.activation):
|
||||
return False, _make_reason(f"{moe_config.activation} activation")
|
||||
elif not _supports_quant_scheme(weight_key, activation_key):
|
||||
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
|
||||
elif not _supports_parallel_config(moe_config.moe_parallel_config):
|
||||
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
|
||||
elif not _supports_routing_method(
|
||||
weight_key, activation_key, moe_config.routing_method
|
||||
):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif activation_format != mk.FusedMoEActivationFormat.Standard:
|
||||
return False, _make_reason(f"activation format {activation_format}")
|
||||
elif not _supports_router_logits_dtype(
|
||||
moe_config.router_logits_dtype, moe_config.routing_method
|
||||
):
|
||||
return False, _make_reason(
|
||||
"float32 router_logits with non-DeepSeekV3 routing "
|
||||
f"{moe_config.router_logits_dtype}x{moe_config.routing_method}"
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def is_supported_config_trtllm_bf16(
|
||||
moe_config: FusedMoEConfig,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16(
|
||||
return True, None
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routing_method_type: int,
|
||||
routed_scaling: float | None = 1.0,
|
||||
) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
assert top_k <= global_num_experts
|
||||
assert top_k <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert block_shape == [128, 128]
|
||||
# Routing kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
routing_logits = routing_logits.to(torch.float32)
|
||||
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(x.dtype)
|
||||
|
||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
return flashinfer_trtllm_fp8_block_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=a_q,
|
||||
hidden_states_scale=a_sf_t,
|
||||
gemm1_weights=w13_weight,
|
||||
gemm1_weights_scale=w13_weight_scale_inv,
|
||||
gemm2_weights=w2_weight,
|
||||
gemm2_weights_scale=w2_weight_scale_inv,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling,
|
||||
routing_method_type=routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routing_method_type: int,
|
||||
routed_scaling: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
def fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
activation_type: int,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
|
||||
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||
hidden_states,
|
||||
input_scale,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
routing_logits = routing_logits.to(torch.float32)
|
||||
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=quant_hidden_states,
|
||||
gemm1_weights=gemm1_weights,
|
||||
output1_scales_scalar=output1_scales_scalar,
|
||||
output1_scales_gate_scalar=output1_scales_gate_scalar,
|
||||
gemm2_weights=gemm2_weights,
|
||||
output2_scales_scalar=output2_scales_scalar,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=local_expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||
routing_method_type=routing_method_type,
|
||||
# TODO: enum type Required for flashinfer==0.6.3, remove with update
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/2508
|
||||
activation_type=ActivationType(activation_type),
|
||||
)
|
||||
|
||||
|
||||
def fi_trtllm_fp8_per_tensor_moe_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
activation_type: int,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="fi_trtllm_fp8_per_tensor_moe",
|
||||
op_func=fi_trtllm_fp8_per_tensor_moe,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_bf16(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
|
||||
@@ -489,11 +489,11 @@ def invoke_moe_batched_triton_kernel(
|
||||
)
|
||||
|
||||
|
||||
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
A reference prepare/finalize class that reorganizes the tokens into
|
||||
expert batched format, i.e. E x max_num_tokens x K. This is the format
|
||||
that the PPLX dispatch/combine kernels use.
|
||||
that the batched dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -645,10 +645,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
)
|
||||
|
||||
|
||||
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
A reference MoE expert class that operates on expert batched format,
|
||||
i.e. E x max_num_tokens x K. This is the format that the pplx
|
||||
i.e. E x max_num_tokens x K. This is the format that the batched
|
||||
dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
@@ -877,10 +877,10 @@ def batched_moe_kernel_quantize_input(
|
||||
return A_q, A_q_scale
|
||||
|
||||
|
||||
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
A Triton based MoE expert class that operates on expert batched format,
|
||||
i.e. E x max_num_tokens x K. This is the format that the pplx
|
||||
i.e. E x max_num_tokens x K. This is the format that the batched
|
||||
dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
|
||||
@@ -526,7 +526,7 @@ def batched_fused_marlin_moe(
|
||||
return output
|
||||
|
||||
|
||||
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -53,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import ixformer.inference.functions as ixfops
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -575,56 +578,6 @@ def fused_moe_kernel(
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
ops.invoke_fused_moe_kernel(
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
mul_routed_weight,
|
||||
top_k,
|
||||
config,
|
||||
compute_type,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
# ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16,block_shape,B_bias)
|
||||
return
|
||||
|
||||
|
||||
# NOTE(zyongye): we can remove all the wna16 kernel
|
||||
# once we drop off sm75 support
|
||||
def invoke_fused_moe_wna16_cuda_kernel(
|
||||
@@ -782,6 +735,7 @@ def invoke_fused_moe_triton_kernel(
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
@@ -799,7 +753,9 @@ def invoke_fused_moe_triton_kernel(
|
||||
):
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16,block_shape,B_bias)
|
||||
return
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
@@ -910,32 +866,6 @@ def dispatch_fused_moe_kernel(
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
invoke_fused_moe_kernel(
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
mul_routed_weight,
|
||||
top_k,
|
||||
config,
|
||||
compute_type,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
B_bias
|
||||
)
|
||||
return
|
||||
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
|
||||
@@ -999,6 +929,7 @@ def dispatch_fused_moe_kernel(
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
@@ -1397,14 +1328,13 @@ def get_default_config(
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
# TODO
|
||||
numel = M * topk
|
||||
if numel <= 64:
|
||||
config["BLOCK_SIZE_M"] = 32
|
||||
config['BLOCK_SIZE_M'] = 32
|
||||
elif numel <= 1024:
|
||||
config["BLOCK_SIZE_M"] = 64
|
||||
config['BLOCK_SIZE_M'] = 64
|
||||
else:
|
||||
config["BLOCK_SIZE_M"] = 256
|
||||
config['BLOCK_SIZE_M'] = 256
|
||||
return config
|
||||
|
||||
|
||||
@@ -1424,14 +1354,12 @@ def try_get_optimal_moe_config(
|
||||
else:
|
||||
# First try to load optimal config from the file
|
||||
E, _, N = w2_shape
|
||||
if dtype == "int4_w4a16":
|
||||
N = N * 2
|
||||
block_n = block_shape[0] if block_shape else 0
|
||||
block_k = block_shape[1] if block_shape else 0
|
||||
configs = get_moe_configs(E, N, dtype, block_n, block_k)
|
||||
# block_n = block_shape[0] if block_shape else 0
|
||||
# block_k = block_shape[1] if block_shape else 0
|
||||
# configs = get_moe_configs(E, N, dtype, block_n, block_k)
|
||||
|
||||
configs = None
|
||||
|
||||
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
# optimal config
|
||||
@@ -1560,13 +1488,12 @@ def outplace_fused_experts(
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
return fused_experts_impl_opt(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
False,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
@@ -1626,14 +1553,12 @@ direct_register_custom_op(
|
||||
|
||||
|
||||
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
# torch.ops.vllm.inplace_fused_experts(**kwargs)
|
||||
inplace_fused_experts(**kwargs)
|
||||
hidden_states = kwargs["hidden_states"]
|
||||
hidden_states = kwargs['hidden_states']
|
||||
return hidden_states
|
||||
|
||||
|
||||
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
# return torch.ops.vllm.outplace_fused_experts(**kwargs)
|
||||
return outplace_fused_experts(**kwargs)
|
||||
|
||||
|
||||
@@ -1661,7 +1586,6 @@ def fused_experts(
|
||||
"""Run fused MoE expert computation using Triton kernels."""
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
assert not inplace or not disable_inplace()
|
||||
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
@@ -1691,6 +1615,245 @@ def fused_experts(
|
||||
w2_bias=quant_config.w2_bias,
|
||||
)
|
||||
|
||||
def fused_experts_impl_opt(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: torch.Tensor | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
output: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
# check constraints
|
||||
if use_fp8_w8a8 or use_int8_w8a8 or use_int8_w8a16 or use_int4_w4a16 or w1_scale or \
|
||||
w2_scale or w1_zp or w2_zp or a1_scale or a2_scale:
|
||||
raise ValueError("Quantized MoE is not supported")
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
use_ep = expert_map is not None
|
||||
|
||||
# unsupported ep now
|
||||
if attn_metadata:
|
||||
only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values())))
|
||||
else:
|
||||
only_decode = False
|
||||
|
||||
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
|
||||
num_tokens = hidden_states.size(0)
|
||||
num_experts = w1.size(0)
|
||||
top_k = topk_weights.size(1)
|
||||
|
||||
if use_ep:
|
||||
local_num_experts = w1.size(0)
|
||||
start_eid = get_ep_group().device_group.rank() * local_num_experts
|
||||
end_eid = min((get_ep_group().device_group.rank() + 1) * local_num_experts, global_num_experts)
|
||||
hidden_size = hidden_states.shape[1]
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
expand_tokens,
|
||||
) = ixfops.moe_compute_token_index_ep(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=global_num_experts,
|
||||
start_expert_id=start_eid,
|
||||
end_expert_id=end_eid,
|
||||
)
|
||||
if expert_sizes_cpu.sum() == 0:
|
||||
return torch.zeros(
|
||||
(num_tokens, hidden_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
) = ixfops.moe_compute_token_index(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
|
||||
if only_decode:
|
||||
# expand + reorder
|
||||
hidden_states = ixfops.moe_expand_input(
|
||||
hidden_states=hidden_states,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# group gemm 1
|
||||
pt_output_1 = ixfops.moe_w16a16_group_gemv(
|
||||
input=hidden_states,
|
||||
weight=w1,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
dst_to_src=None,
|
||||
bias=w1_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# act
|
||||
if activation == "silu":
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
elif activation == "gelu":
|
||||
pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
|
||||
elif activation == "swigluoai":
|
||||
pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
|
||||
elif activation == "swiglustep":
|
||||
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
|
||||
output_dim = pt_output_1.shape[1]
|
||||
pt_output_2 = torch.empty(
|
||||
(num_tokens * top_k, output_dim//2),
|
||||
device=pt_output_1.device,
|
||||
dtype=pt_output_1.dtype,
|
||||
)
|
||||
swiglustep_and_mul_triton(pt_output_2, pt_output_1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
|
||||
# group gemm 2 + reorder
|
||||
pt_output_3 = ixfops.moe_w16a16_group_gemv(
|
||||
input=pt_output_2,
|
||||
weight=w2,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
dst_to_src=sorted_token_ids,
|
||||
bias=w2_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
)
|
||||
|
||||
else:
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
# expand + reorder
|
||||
hidden_states = ixfops.moe_expand_input(
|
||||
hidden_states=hidden_states,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
# group gemm 1
|
||||
pt_output_1 = ixfops.moe_w16a16_group_gemm(
|
||||
input=hidden_states,
|
||||
weight=w1,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
dst_to_src=None,
|
||||
bias=w1_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# act
|
||||
if activation == "silu":
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
elif activation == "gelu":
|
||||
pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
|
||||
elif activation == "swigluoai":
|
||||
pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
|
||||
elif activation == "swiglustep":
|
||||
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
|
||||
output_dim = pt_output_1.shape[1]
|
||||
pt_output_2 = torch.empty(
|
||||
(num_tokens * top_k, output_dim//2),
|
||||
device=pt_output_1.device,
|
||||
dtype=pt_output_1.dtype,
|
||||
)
|
||||
swiglustep_and_mul_triton(pt_output_2, pt_output_1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * top_k, hidden_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
# group gemm 2 + reorder
|
||||
pt_output_3 = ixfops.moe_w16a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=w2,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="TN",
|
||||
bias=w2_bias,
|
||||
output=pt_output_3,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
reduce_mask = src_to_dst == -1
|
||||
if output != None:
|
||||
ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
output=output,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
# group gemm 2 + reorder
|
||||
pt_output_3 = ixfops.moe_w16a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=w2,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
dst_to_src=sorted_token_ids,
|
||||
bias=w2_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
)
|
||||
|
||||
if output == None:
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def _get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
@@ -1825,7 +1988,7 @@ def fused_experts_impl(
|
||||
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
|
||||
activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation(
|
||||
N, activation_enum
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
@@ -1910,28 +2073,28 @@ def fused_experts_impl(
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
|
||||
# activates only a small fraction of total experts
|
||||
SPARSITY_FACTOR = 4
|
||||
# block quantized code path is not implemented yet.
|
||||
naive_block_assignment = (
|
||||
expert_map is None
|
||||
and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
|
||||
and not (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
)
|
||||
)
|
||||
# # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
|
||||
# # activates only a small fraction of total experts
|
||||
# SPARSITY_FACTOR = 4
|
||||
# # block quantized code path is not implemented yet.
|
||||
# naive_block_assignment = (
|
||||
# expert_map is None
|
||||
# and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
|
||||
# and not (
|
||||
# (use_int8_w8a16 or use_int4_w4a16)
|
||||
# and block_shape is not None
|
||||
# and block_shape[1] > 0
|
||||
# )
|
||||
# )
|
||||
|
||||
# if not naive_block_assignment:
|
||||
# sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
# curr_topk_ids,
|
||||
# config["BLOCK_SIZE_M"],
|
||||
# global_num_experts,
|
||||
# expert_map,
|
||||
# ignore_invalid_experts=True,
|
||||
# )
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
# else:
|
||||
# max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
|
||||
# expert_ids = curr_topk_ids.view(-1)
|
||||
@@ -1941,14 +2104,6 @@ def fused_experts_impl(
|
||||
# num_tokens_post_padded.fill_(max_num_tokens_padded)
|
||||
# sorted_token_ids = None
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
|
||||
dispatch_fused_moe_kernel(
|
||||
qcurr_hidden_states,
|
||||
w1,
|
||||
@@ -2015,20 +2170,14 @@ def fused_experts_impl(
|
||||
B_bias=w2_bias,
|
||||
)
|
||||
|
||||
# ops.moe_sum(
|
||||
# intermediate_cache3.view(*intermediate_cache3.size()),
|
||||
# out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
# )
|
||||
torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class TritonExperts(mk.FusedMoEExpertsModular):
|
||||
"""Triton-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(
|
||||
@@ -2091,8 +2240,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
# return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
return True
|
||||
return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
@@ -2138,157 +2286,31 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
if self.quant_config.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
|
||||
else:
|
||||
assert hidden_states.size(-1) == w1.size(2), (
|
||||
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
|
||||
)
|
||||
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert hidden_states.dim() == 2
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
]
|
||||
|
||||
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
|
||||
hidden_states, w1, w2, topk_ids
|
||||
)
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
w1.size(),
|
||||
w2.size(),
|
||||
top_k_num,
|
||||
self.quant_config.config_name(hidden_states.dtype),
|
||||
num_tokens,
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
elif hidden_states.dtype == torch.float16:
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
elif (
|
||||
hidden_states.dtype == torch.float8_e4m3fn
|
||||
or hidden_states.dtype == torch.float8_e4m3fnuz
|
||||
):
|
||||
compute_type = tl.bfloat16
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
# Note that the output tensor might be in workspace1
|
||||
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
||||
cache2_dim = self.adjust_N_for_activation(N, activation)
|
||||
intermediate_cache2 = _resize_cache(
|
||||
workspace13, (num_tokens * top_k_num, cache2_dim)
|
||||
)
|
||||
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
|
||||
)
|
||||
|
||||
invoke_fused_moe_triton_kernel(
|
||||
hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
self.w1_scale,
|
||||
None, # topk_weights
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False, # mul_routed_weights
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
B_bias=self.w1_bias,
|
||||
)
|
||||
|
||||
self.activation(
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
a2q_scale: torch.Tensor | None = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2,
|
||||
a2_scale,
|
||||
self.quant_dtype,
|
||||
self.per_act_token_quant,
|
||||
self.block_shape,
|
||||
)
|
||||
|
||||
# invoke_fused_moe_triton_kernel(
|
||||
# qintermediate_cache2,
|
||||
# w2,
|
||||
# intermediate_cache3,
|
||||
# a2q_scale,
|
||||
# self.w2_scale,
|
||||
# topk_weights,
|
||||
# sorted_token_ids,
|
||||
# expert_ids,
|
||||
# num_tokens_post_padded,
|
||||
# not apply_router_weight_on_input,
|
||||
# 1,
|
||||
# config,
|
||||
# compute_type=compute_type,
|
||||
# use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
# use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
# use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
# use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
# per_channel_quant=self.per_act_token_quant,
|
||||
# block_shape=self.block_shape,
|
||||
# B_bias=self.w2_bias,
|
||||
# )
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
self.w2_scale,
|
||||
self.w2_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
B_bias=self.w2_bias,
|
||||
)
|
||||
|
||||
# separate function is required for MoE + LoRA
|
||||
self.moe_sum(intermediate_cache3, output)
|
||||
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
ops.moe_sum(input, output)
|
||||
fused_experts_impl_opt(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
self.quant_config.use_fp8_w8a8,
|
||||
self.quant_config.use_int8_w8a8,
|
||||
self.quant_config.use_int8_w8a16,
|
||||
self.quant_config.use_int4_w4a16,
|
||||
self.quant_config.ocp_mx_scheme,
|
||||
self.quant_config.per_act_token_quant,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
self.quant_config.w1_scale,
|
||||
self.quant_config.w2_scale,
|
||||
self.quant_config.w1_zp,
|
||||
self.quant_config.w2_zp,
|
||||
self.quant_config.a1_scale,
|
||||
self.quant_config.a2_scale,
|
||||
self.quant_config.block_shape,
|
||||
self.quant_config.w1_bias,
|
||||
self.quant_config.w2_bias,
|
||||
output)
|
||||
|
||||
|
||||
class TritonWNA16Experts(TritonExperts):
|
||||
|
||||
@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
super().__init__()
|
||||
self.moe: FusedMoEConfig = moe
|
||||
self.moe_quant_config: FusedMoEQuantConfig | None = None
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
@property
|
||||
def supports_internal_mk(self) -> bool:
|
||||
# NOTE(rob): temporary attribute to indicate support for
|
||||
# completed migration to the new internal MK interface.
|
||||
return self.moe_mk is not None
|
||||
return self.moe_kernel is not None
|
||||
|
||||
@property
|
||||
def mk_owns_shared_expert(self) -> bool:
|
||||
# NOTE(rob): temporary attribute to indicate support for
|
||||
# completed migration to the new internal MK interface.
|
||||
return self.moe_mk is not None and self.moe_mk.shared_experts is not None
|
||||
return (
|
||||
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
from .all2all_utils import maybe_make_prepare_finalize
|
||||
|
||||
return maybe_make_prepare_finalize(
|
||||
pf = maybe_make_prepare_finalize(
|
||||
self.moe, self.moe_quant_config, routing_tables
|
||||
)
|
||||
assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular)
|
||||
return pf
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
) -> FusedMoEExpertsModular:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
# gemm implementation
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} must select appropriate gemm "
|
||||
"implementation based on the prepare_finalize"
|
||||
)
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
|
||||
raise NotImplementedError(
|
||||
"Method 'prepare_dp_allgather_tensor' is not implemented in "
|
||||
f"{self.__class__.__name__}."
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.moe_mk is not None:
|
||||
return self.moe_mk.prepare_finalize.topk_indices_dtype()
|
||||
if self.moe_kernel is not None:
|
||||
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return False
|
||||
if self.moe_kernel is None:
|
||||
if hasattr(self, "experts_cls"):
|
||||
return self.experts_cls.is_monolithic()
|
||||
else:
|
||||
return False
|
||||
return self.moe_kernel.is_monolithic
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEKernel,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
# --8<-- [end:modular_fused_moe]
|
||||
|
||||
def __init__(
|
||||
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
|
||||
self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel
|
||||
):
|
||||
super().__init__(old_quant_method.moe)
|
||||
self.moe_quant_config = old_quant_method.moe_quant_config
|
||||
self.moe_mk = experts
|
||||
self.moe_kernel = moe_kernel
|
||||
self.disable_expert_map = getattr(
|
||||
old_quant_method,
|
||||
"disable_expert_map",
|
||||
not self.moe_mk.supports_expert_map(),
|
||||
not self.moe_kernel.supports_expert_map(),
|
||||
)
|
||||
self.old_quant_method = old_quant_method
|
||||
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
|
||||
@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
def make(
|
||||
moe_layer: torch.nn.Module,
|
||||
old_quant_method: FusedMoEMethodBase,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
inplace: bool = False,
|
||||
) -> "FusedMoEModularMethod":
|
||||
return FusedMoEModularMethod(
|
||||
old_quant_method,
|
||||
FusedMoEModularKernel(
|
||||
FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
|
||||
shared_experts,
|
||||
@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -178,7 +179,40 @@ def triton_kernel_moe_forward(
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
unpadded_N_w1=None,
|
||||
unpadded_K_w1=None,
|
||||
unpadded_N_w2=None,
|
||||
unpadded_K_w2=None,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
quant_config is not None
|
||||
and quant_config.use_mxfp4_w4a8
|
||||
and rocm_aiter_ops.is_enabled()
|
||||
):
|
||||
from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
|
||||
|
||||
routing_data, gather_idx, scatter_idx = aiter_routing(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
)
|
||||
return triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
None,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
activation=activation.value,
|
||||
quant_config=quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
unpadded_N_w1=unpadded_N_w1,
|
||||
unpadded_K_w1=unpadded_K_w1,
|
||||
unpadded_N_w2=unpadded_N_w2,
|
||||
unpadded_K_w2=unpadded_K_w2,
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
# With expert parallelism, legacy_routing produces routing data
|
||||
# using global expert IDs which don't correspond to local weight
|
||||
@@ -210,6 +244,9 @@ def triton_kernel_moe_forward(
|
||||
effective_global_num_experts = global_num_experts
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
effective_quant_config = (
|
||||
quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
)
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
output,
|
||||
@@ -221,7 +258,7 @@ def triton_kernel_moe_forward(
|
||||
scatter_idx,
|
||||
topk=topk,
|
||||
activation=activation,
|
||||
quant_config=quant_config,
|
||||
quant_config=effective_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=effective_global_num_experts,
|
||||
expert_map=effective_expert_map,
|
||||
@@ -252,8 +289,7 @@ def triton_kernel_fused_experts(
|
||||
assert activation == MoEActivation.SWIGLUOAI, (
|
||||
"Only SWIGLUOAI activation is supported"
|
||||
)
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
assert quant_config is not None
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
@@ -330,6 +366,98 @@ def triton_kernel_fused_experts(
|
||||
return output_tensor
|
||||
|
||||
|
||||
# This is a triton implementation of the fused_experts function
|
||||
def triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
output_tensor: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
routing_data, # RoutingData
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
activation: str = "silu",
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
a1q_scale: torch.Tensor | None = None,
|
||||
unpadded_N_w1=None,
|
||||
unpadded_K_w1=None,
|
||||
unpadded_N_w2=None,
|
||||
unpadded_K_w2=None,
|
||||
) -> torch.Tensor:
|
||||
assert quant_config is not None
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
|
||||
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
E, _, N = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
|
||||
from aiter.ops.triton.quant_moe import downcast_to_static_fp8
|
||||
|
||||
assert quant_config.w1_precision is not None, (
|
||||
"w1_precision in quant config can't be None"
|
||||
)
|
||||
assert quant_config.w2_precision is not None, (
|
||||
"w2_precision in quant config can't be None"
|
||||
)
|
||||
|
||||
hidden_states = downcast_to_static_fp8(
|
||||
hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
|
||||
)
|
||||
|
||||
intermediate_cache1 = moe_gemm_a8w4(
|
||||
hidden_states,
|
||||
w1.storage.data,
|
||||
None,
|
||||
quant_config.w1_precision.weight_scale.storage.data,
|
||||
quant_config.w1_precision.flex_ctx.lhs_data.scale,
|
||||
quant_config.w2_precision.flex_ctx.lhs_data.scale,
|
||||
quant_config.w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
swizzle_mx_scale="CDNA4_SCALE",
|
||||
out_dtype=torch.float8_e4m3fn,
|
||||
apply_swiglu=True,
|
||||
alpha=swiglu_alpha,
|
||||
limit=swiglu_limit,
|
||||
unpadded_N=unpadded_N_w1,
|
||||
unpadded_K=unpadded_K_w1,
|
||||
)
|
||||
|
||||
intermediate_cache3 = moe_gemm_a8w4(
|
||||
intermediate_cache1,
|
||||
w2.storage.data,
|
||||
None,
|
||||
quant_config.w2_precision.weight_scale.storage.data,
|
||||
quant_config.w2_precision.flex_ctx.lhs_data.scale,
|
||||
None,
|
||||
quant_config.w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
swizzle_mx_scale="CDNA4_SCALE",
|
||||
unpadded_N=unpadded_N_w2,
|
||||
unpadded_K=unpadded_K_w2,
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def make_routing_data(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
@@ -383,7 +511,7 @@ def make_routing_data(
|
||||
return routing_data, gather_indx, scatter_indx
|
||||
|
||||
|
||||
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
raise NotImplementedError(
|
||||
@@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
if self.quant_config is None:
|
||||
self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ from collections.abc import Callable, Iterable
|
||||
from enum import Enum
|
||||
from typing import Literal, cast, get_args, overload
|
||||
|
||||
import ast, re
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -54,10 +54,14 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.model_executor.layers.utils import (
|
||||
parse_opt_exclude_layers,
|
||||
weight_quant_l1,
|
||||
weight_quant_l2,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
@@ -333,6 +337,7 @@ class FusedMoE(CustomOp):
|
||||
gate: torch.nn.Module | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
routed_input_transform: torch.nn.Module | None = None,
|
||||
fused_shared_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -483,6 +488,8 @@ class FusedMoE(CustomOp):
|
||||
(expert_mask == 0) | (expert_mask == 1)
|
||||
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
@@ -526,16 +533,18 @@ class FusedMoE(CustomOp):
|
||||
|
||||
# Round up hidden size before creating moe_config.
|
||||
# This way moe_config is created with the correct hidden_size from the start.
|
||||
unpadded_hidden_size = hidden_size
|
||||
self.model_type = (
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
)
|
||||
hidden_size = maybe_roundup_hidden_size(
|
||||
hidden_size=hidden_size,
|
||||
act_dtype=moe_in_dtype,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
is_lora_enabled=vllm_config.lora_config is not None,
|
||||
model_type=(
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
),
|
||||
model_type=self.model_type,
|
||||
is_mxfp4_quant=(
|
||||
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
|
||||
),
|
||||
@@ -581,14 +590,27 @@ class FusedMoE(CustomOp):
|
||||
"""
|
||||
quant_method = None
|
||||
if self.quant_config is not None:
|
||||
self.opt_level = 0
|
||||
quant_method = self.quant_config.get_quant_method(self, prefix)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
CompressedTensorsL1OptMoEMethod, CompressedTensorsL2OptMoEMethod)
|
||||
if self.opt_level == 1:
|
||||
quant_method = CompressedTensorsL1OptMoEMethod(self.moe_config)
|
||||
elif self.opt_level == 2:
|
||||
quant_method = CompressedTensorsL2OptMoEMethod(self.moe_config)
|
||||
else:
|
||||
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
|
||||
assert isinstance(quant_method, FusedMoEMethodBase)
|
||||
return quant_method
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
self.opt_level = envs.VLLM_MOE_OPT_LEVEL
|
||||
if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, prefix):
|
||||
self.opt_flag = False
|
||||
logger.info(f"Excluding layer {prefix} from optimization")
|
||||
|
||||
self.quant_method: FusedMoEMethodBase = _get_quant_method()
|
||||
|
||||
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
|
||||
@@ -611,6 +633,7 @@ class FusedMoE(CustomOp):
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"unpadded_hidden_size": unpadded_hidden_size,
|
||||
"intermediate_size_per_partition": self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
@@ -625,6 +648,7 @@ class FusedMoE(CustomOp):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.base_quant_method = self.quant_method
|
||||
|
||||
# Disable shared expert overlap if:
|
||||
# - we are using eplb with non-default backend, because of correctness issues
|
||||
@@ -638,7 +662,10 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
and self._shared_experts is not None
|
||||
)
|
||||
|
||||
if fused_shared_output:
|
||||
assert self.use_ep == False, "Fused shared output is only supported when EP is disabled."
|
||||
assert shared_experts is not None, "Shared experts must be provided when fused_shared_output is True."
|
||||
self.fused_shared_output = fused_shared_output
|
||||
self.runner = self._init_runner()
|
||||
|
||||
def _init_runner(self):
|
||||
@@ -655,6 +682,7 @@ class FusedMoE(CustomOp):
|
||||
quant_method=self.quant_method,
|
||||
reduce_results=self.reduce_results,
|
||||
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
|
||||
fused_shared_output=self.fused_shared_output,
|
||||
)
|
||||
|
||||
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
|
||||
@@ -681,7 +709,7 @@ class FusedMoE(CustomOp):
|
||||
# routing_tables only needed for round-robin expert placement with
|
||||
# DeepEP all2all backend.
|
||||
routing_tables = self._maybe_init_expert_routing_tables()
|
||||
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
|
||||
prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
|
||||
routing_tables=routing_tables
|
||||
)
|
||||
if prepare_finalize is not None:
|
||||
@@ -691,7 +719,7 @@ class FusedMoE(CustomOp):
|
||||
self._replace_quant_method(
|
||||
FusedMoEModularMethod.make(
|
||||
self,
|
||||
self.quant_method,
|
||||
self.base_quant_method,
|
||||
prepare_finalize,
|
||||
self.shared_experts,
|
||||
inplace=not self.moe_config.disable_inplace,
|
||||
@@ -959,11 +987,7 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
try:
|
||||
expert_data.copy_(loaded_weight)
|
||||
except Exception as e:
|
||||
print(expert_data.shape, expert_data.dtype, loaded_weight.shape, loaded_weight.dtype)
|
||||
raise e
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(
|
||||
self,
|
||||
@@ -976,7 +1000,7 @@ class FusedMoE(CustomOp):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
shard_size = loaded_weight.shape[shard_dim] // self.tp_size
|
||||
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
|
||||
# and we're not loading the full weight
|
||||
if not load_full and loaded_weight.ndim > 0:
|
||||
@@ -984,7 +1008,55 @@ class FusedMoE(CustomOp):
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
expert_data.narrow(shard_dim, 0, shard_size).copy_(loaded_weight)
|
||||
|
||||
def _load_model_opt_weight_or_group_weight_scale(self,
|
||||
shard_dim: int,
|
||||
shard_dim_scale: int,
|
||||
expert_data: torch.Tensor,
|
||||
scale_data: torch.Tensor,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
opt_level: int,
|
||||
load_full_w2: bool = False):
|
||||
"""
|
||||
Load grouped weight scales for group quantization or model weights
|
||||
:param shard_dim: dimension to shard
|
||||
:param expert_data: parameter for a particular expert
|
||||
:param shard_id: either w1, w2, or w3
|
||||
:param loaded_weight: checkpoint weight to load into the param
|
||||
:param tp_rank: tensor parallel rank
|
||||
:param load_full_w2: whether or not the w2 loaded should be sharded.
|
||||
"""
|
||||
|
||||
assert opt_level in [1, 2]
|
||||
if opt_level == 1:
|
||||
weight, scale = weight_quant_l1(loaded_weight)
|
||||
else:
|
||||
weight, scale = weight_quant_l2(loaded_weight)
|
||||
scale = scale.view(1, -1)
|
||||
|
||||
if shard_id == "w2":
|
||||
# In the case where we have actorder/g_idx, we do not partition the
|
||||
# w2 scales, as indicated by `load_full` argument, for all tp cases
|
||||
self._load_w2(shard_dim=shard_dim,
|
||||
loaded_weight=weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
load_full=load_full_w2)
|
||||
scale_data.copy_(scale)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim_scale,
|
||||
loaded_weight=scale,
|
||||
expert_data=scale_data,
|
||||
tp_rank=tp_rank)
|
||||
|
||||
def _load_single_value(
|
||||
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
|
||||
@@ -1147,7 +1219,6 @@ class FusedMoE(CustomOp):
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if is_transposed:
|
||||
shard_dim = int(not shard_dim)
|
||||
|
||||
shard_dim_force = getattr(param, "shard_dim", None)
|
||||
shard_dim = shard_dim_force if shard_dim_force is not None else shard_dim
|
||||
|
||||
@@ -1309,13 +1380,28 @@ class FusedMoE(CustomOp):
|
||||
|
||||
# Case model weights
|
||||
if "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank,
|
||||
)
|
||||
if self.opt_level != 0:
|
||||
scale_name = weight_name.split('.')[-1] + "_scale"
|
||||
params_dict = dict(self.named_parameters())
|
||||
scale_param = params_dict[scale_name]
|
||||
shard_dim_scale = getattr(scale_param, "shard_dim", None)
|
||||
scale_expert_data = scale_param.data if full_load else scale_param.data[expert_id]
|
||||
self._load_model_opt_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
shard_dim_scale=shard_dim_scale,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
scale_data=scale_expert_data,
|
||||
opt_level=self.opt_level,
|
||||
tp_rank=self.tp_rank)
|
||||
else:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank)
|
||||
return True if return_success else None
|
||||
|
||||
return False if return_success else None
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
@@ -56,25 +57,25 @@ logger = init_logger(__name__)
|
||||
# MoE kernel implementations.
|
||||
#
|
||||
# The following main classes are defined:
|
||||
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
|
||||
# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE
|
||||
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
|
||||
# The prepare method must take care of any needed quantization and the
|
||||
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
|
||||
# finalize method, informed by the FusedMoEExpertsModular method,
|
||||
# may apply weights and/or do the final reduction of the output.
|
||||
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
|
||||
# * FusedMoEExpertsModular - an abstract base class for the main fused
|
||||
# MoE operation, i.e matmul + act_mul + optionally quant + matmul.
|
||||
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
|
||||
# Some FusedMoEExpertsModular implementations may choose to do
|
||||
# the weight application and/or reduction. The class communicates this
|
||||
# to [Finalize] via a TopKWeightAndReduce object.
|
||||
# * FusedMoEModularKernel - an interface class that combines a
|
||||
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
|
||||
# FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to
|
||||
# provide the standard fused MoE kernel interface.
|
||||
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
|
||||
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
|
||||
# by the FusedMoEExpertsModular implementation that is passed
|
||||
# on to [Finalize].
|
||||
#
|
||||
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
|
||||
# class `FusedMoEPrepareAndFinalize` since they could use collective
|
||||
# class `FusedMoEPrepareAndFinalizeModular` since they could use collective
|
||||
# communication mechanisms that need to be consistent.
|
||||
#
|
||||
|
||||
@@ -155,25 +156,96 @@ PrepareResultType = tuple[
|
||||
torch.Tensor | None,
|
||||
]
|
||||
|
||||
#
|
||||
# PrepareResultType is a tuple of:
|
||||
# - quantized + dispatched a.
|
||||
# - quantized + dispatched a1_scales.
|
||||
# - dispatched router logits.
|
||||
#
|
||||
# See `prepare_monolithic` method below.
|
||||
#
|
||||
PrepareMonolithicResultType = tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor,
|
||||
]
|
||||
|
||||
ReceiverType = Callable[[], PrepareResultType]
|
||||
|
||||
################################################################################
|
||||
# Prepare/Finalize
|
||||
################################################################################
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above.
|
||||
|
||||
There are two variants of this class:
|
||||
* FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights
|
||||
* FusedMoEPrepareAndFinalizeMonolithic - the operates on router_logits
|
||||
"""
|
||||
|
||||
def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
|
||||
def post_init_setup(self, fused_experts: "FusedMoEExperts"):
|
||||
"""
|
||||
Initialize FusedMoEPrepareAndFinalize settings that depend on
|
||||
FusedMoEPermuteExpertsUnpermute experts object.
|
||||
The FusedMoEPrepareAndFinalize implementations that have such
|
||||
Initialize FusedMoEPrepareAndFinalizeModular settings that depend on
|
||||
FusedMoEExpertsModular experts object.
|
||||
The FusedMoEPrepareAndFinalizeModular implementations that have such
|
||||
dependencies may choose to override this function.
|
||||
"""
|
||||
return
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
"""
|
||||
The PrepareFinalize All2All implementations generally constrain the
|
||||
dtype of the topk_ids they support. This function returns the
|
||||
required topk indices dtype so it can be respected.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
"""
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can process only as set of tokens at a time. This
|
||||
function returns the batch size i.e the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of finalize is reduced across all
|
||||
ranks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above for the Modular case.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
activations, before quantization + dispatching.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a tuple of:
|
||||
@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
|
||||
class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above for the monolithic case.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> PrepareMonolithicResultType:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
Optional method for subclasses compatible with monolithic
|
||||
FusedMoEExpertsModular kernels.
|
||||
|
||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- Optional quantized + dispatched a1_scales.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
The PrepareFinalize All2All implementations generally constrain the
|
||||
dtype of the topk_ids they support. This function returns the
|
||||
required topk indices dtype so it can be respected.
|
||||
Return None if there are no such restrictions.
|
||||
Optional method for subclasses compatible with monolithic
|
||||
FusedMoEExpertsModular kernels.
|
||||
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output.
|
||||
- fused_expert_output: The unweighted, unreduced output of the fused
|
||||
experts, it will have (M, topk, K) shape.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
"""
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can process only as set of tokens at a time. This
|
||||
function returns the batch size i.e the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of finalize is reduced across all
|
||||
ranks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
################################################################################
|
||||
# Experts
|
||||
################################################################################
|
||||
|
||||
|
||||
# TODO: add supported activations method (return string)
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above.
|
||||
"""
|
||||
|
||||
class FusedMoEExperts(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
raise NotImplementedError("Implemented by subclasses.")
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
"""
|
||||
@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
|
||||
Note: extracting the problem shape from the weight and activation
|
||||
tensors is not obvious. It needs to be done this way specifically
|
||||
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||
divide the trailing dimension by two, so it's not "correct" to
|
||||
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||
some kernels transpose the weights, so this needs to be kept in mind.
|
||||
|
||||
Note: This implementation covers most cases. However, if experts
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
else:
|
||||
assert a1.dim() == 3
|
||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||
M = a1.size(1) # This is max_num_tokens
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
|
||||
|
||||
#
|
||||
# Various helpers for registering support for various features.
|
||||
# Used by the oracle to select a particular kernel for a deployment.
|
||||
@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
|
||||
@staticmethod
|
||||
def is_supported_config(
|
||||
cls: type["FusedMoEPermuteExpertsUnpermute"],
|
||||
cls: type["FusedMoEExperts"],
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
return False, _make_reason(
|
||||
f"parallel config {moe_config.moe_parallel_config}"
|
||||
)
|
||||
elif not cls._supports_routing_method(
|
||||
moe_config.routing_method, weight_key, activation_key
|
||||
):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif not cls._supports_router_logits_dtype(
|
||||
moe_config.router_logits_dtype,
|
||||
moe_config.routing_method,
|
||||
):
|
||||
return False, _make_reason(
|
||||
f"router logits dtype {moe_config.router_logits_dtype}"
|
||||
)
|
||||
elif not cls._supports_shape(moe_config.hidden_dim):
|
||||
return False, _make_reason(
|
||||
f"{moe_config.hidden_dim} hidden dim is not supported"
|
||||
)
|
||||
elif activation_format != cls.activation_format():
|
||||
return False, _make_reason(f"{activation_format.value} activation format")
|
||||
return True, None
|
||||
@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
@abstractmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""
|
||||
Whether the kernel supports deployment in expert parallel.
|
||||
Whether the kernel supports deployment in particular parallel config.
|
||||
|
||||
Can be overriden if a kernel does not support EP, SP or some other
|
||||
configuration.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a routing method (e.g. GroupedTopK).
|
||||
|
||||
Can be overriden by monolithic kernels that execute the router
|
||||
in addition to the experts if certain routers are not supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether a kernel supports a particular dtype for router logits input.
|
||||
|
||||
Can be overriden by monolithic kernels that execute the router
|
||||
in addition to the experts if certain dtypes are not supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_shape(hidden_dim: int) -> bool:
|
||||
"""
|
||||
Whether a kernel supports a particular shape. Can be overridden if a kernel
|
||||
has specific shape requirements.
|
||||
"""
|
||||
return True
|
||||
|
||||
#
|
||||
# Various helpers for accessing quantization parameters from the
|
||||
# quant_config.
|
||||
@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
|
||||
)
|
||||
|
||||
|
||||
class FusedMoEExpertsModular(FusedMoEExperts):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
return False
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
|
||||
Note: extracting the problem shape from the weight and activation
|
||||
tensors is not obvious. It needs to be done this way specifically
|
||||
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||
divide the trailing dimension by two, so it's not "correct" to
|
||||
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||
some kernels transpose the weights, so this needs to be kept in mind.
|
||||
|
||||
Note: This implementation covers most cases. However, if experts
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
else:
|
||||
assert a1.dim() == 3
|
||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||
M = a1.size(1) # This is max_num_tokens
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
Workspace type: The dtype to use for the workspace tensors.
|
||||
@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
) -> None:
|
||||
apply_moe_activation(activation, output, input)
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FusedMoEExpertsMonolithic(FusedMoEExperts):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above, but with the monolithic interface (accepts router logits
|
||||
rather than topk ids and weights).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a routing method (e.g. GroupedTopK).
|
||||
|
||||
Monolithic kernels should explicitly opt-in to support.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a dtype for router logits.
|
||||
|
||||
Modular kernels should opt-in to support.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Same as apply(), except uses router_logits as opposed
|
||||
to the topk_ids and topk_weights. This is useful for kernels
|
||||
with fused router and fused_experts (e.g. FLASHINFER_TRTLLM).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _slice_scales(
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
@@ -802,75 +1006,32 @@ def _slice_scales(
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Kernel
|
||||
################################################################################
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
a FusedMoEPermuteExpertsUnpermute to provide an interface that
|
||||
is compatible with the `fused_experts` function in fused_moe.py.
|
||||
|
||||
It takes care of managing any required scratch space.
|
||||
|
||||
Note: Instances of this class should only be used for a single model
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
|
||||
class FusedMoEKernelModularImpl:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
fused_experts: FusedMoEExpertsModular,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
self.moe_parallel_config = moe_parallel_config
|
||||
self.inplace = inplace
|
||||
|
||||
# prefer an explicit FusedMoEParallelConfig when available (from
|
||||
# FusedMoE layers / tests).
|
||||
# if not provided, assume this kernel is
|
||||
# running in a non-DP+EP context
|
||||
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
|
||||
self.is_dp_ep = (
|
||||
moe_parallel_config is not None
|
||||
and moe_parallel_config.dp_size > 1
|
||||
and moe_parallel_config.use_ep
|
||||
)
|
||||
|
||||
self._post_init_setup()
|
||||
assert (
|
||||
prepare_finalize.activation_format == fused_experts.activation_format()
|
||||
), (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
f"{prepare_finalize.activation_format} == "
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_format()}"
|
||||
)
|
||||
|
||||
def _post_init_setup(self):
|
||||
"""
|
||||
Resolve any leftover setup dependencies between self.prepare_finalize
|
||||
and self.fused_experts here.
|
||||
"""
|
||||
self.prepare_finalize.post_init_setup(self.fused_experts)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps.
|
||||
"""
|
||||
return self.fused_experts.supports_expert_map()
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
is reduced across all ranks.
|
||||
"""
|
||||
return self.prepare_finalize.output_is_reduced()
|
||||
|
||||
def _chunk_info(self, M: int) -> tuple[int, int]:
|
||||
"""
|
||||
Compute number of chunks and chunk size for given M.
|
||||
@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
# "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
|
||||
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
|
||||
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
|
||||
# DP+EP due to the random token routing.
|
||||
is_profile_run = (
|
||||
@@ -1172,9 +1333,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
# incompatible all2all kernels like the DeepEP high-throughput
|
||||
# kernels. CUDAGraph compatible all2all kernels like the pplx
|
||||
# kernels and the DeepEP low-latency kernels are always batched
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
# kernels. CUDAGraph compatible all2all kernels like the DeepEP
|
||||
# low-latency kernels are always batched and can never run into
|
||||
# the tensor.numel() == 0 case.
|
||||
if M_full == 0:
|
||||
assert num_chunks == 0
|
||||
workspace13 = None
|
||||
@@ -1313,19 +1474,18 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
|
||||
def forward(
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
@@ -1335,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||
the layer.
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (MoEActivation): The activation function to apply after the first
|
||||
MoE layer.
|
||||
@@ -1355,23 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
from .fused_moe import fused_experts as fused_experts_kernel
|
||||
|
||||
result = fused_experts_kernel(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
quant_config=kwargs.get("quant_config", None),
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
return result
|
||||
if self.inplace:
|
||||
assert self.shared_experts is None
|
||||
assert not disable_inplace()
|
||||
@@ -1417,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEKernelMonolithicImpl:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic,
|
||||
fused_experts: FusedMoEExpertsMonolithic,
|
||||
):
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Same as forward(), except uses router_logits as opposed
|
||||
to the topk_ids and topk_weights. This is used for kernels
|
||||
that have fused router + experts (e.g. FLASHINFER_TRTLLM).
|
||||
"""
|
||||
|
||||
# TODO(rob): add inplace support.
|
||||
a1q, a1q_scale, router_logits = self.prepare_finalize.prepare(
|
||||
hidden_states,
|
||||
router_logits=router_logits,
|
||||
quant_config=self.fused_experts.quant_config,
|
||||
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
|
||||
)
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
hidden_states=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=router_logits,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
a1q_scale=a1q_scale,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
output = self.prepare_finalize.finalize(fused_out)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEKernel:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEExperts,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = shared_experts # NOTE: check if we can remove
|
||||
|
||||
# Initialize the implementation (monolithic or modular).
|
||||
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
|
||||
if isinstance(
|
||||
prepare_finalize, FusedMoEPrepareAndFinalizeModular
|
||||
) and isinstance(fused_experts, FusedMoEExpertsModular):
|
||||
self.impl = FusedMoEKernelModularImpl(
|
||||
prepare_finalize,
|
||||
fused_experts,
|
||||
shared_experts,
|
||||
moe_parallel_config,
|
||||
inplace,
|
||||
)
|
||||
|
||||
elif isinstance(
|
||||
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
|
||||
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
|
||||
assert shared_experts is None
|
||||
assert not inplace
|
||||
self.impl = FusedMoEKernelMonolithicImpl(
|
||||
prepare_finalize,
|
||||
fused_experts,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"prepare_finalize and fused_experts must both be either monolithic "
|
||||
f"or non-monolithic but got {prepare_finalize.__class__.__name__} "
|
||||
f"and {fused_experts.__class__.__name__}"
|
||||
)
|
||||
|
||||
self._post_init_setup()
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
||||
|
||||
@property
|
||||
def prepare_finalize(self) -> FusedMoEPrepareAndFinalize:
|
||||
return self.impl.prepare_finalize
|
||||
|
||||
@property
|
||||
def fused_experts(self) -> FusedMoEExperts:
|
||||
return self.impl.fused_experts
|
||||
|
||||
def _post_init_setup(self):
|
||||
"""
|
||||
Resolve any leftover setup dependencies between self.prepare_finalize
|
||||
and self.fused_experts here.
|
||||
"""
|
||||
self.prepare_finalize.post_init_setup(self.impl.fused_experts)
|
||||
assert (
|
||||
self.prepare_finalize.activation_format
|
||||
== self.fused_experts.activation_format()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps.
|
||||
"""
|
||||
return self.fused_experts.supports_expert_map()
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
is reduced across all ranks.
|
||||
"""
|
||||
return self.prepare_finalize.output_is_reduced()
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
||||
return self.impl.apply(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=router_logits,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(self.impl, FusedMoEKernelModularImpl)
|
||||
return self.impl.apply(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using MoRI kernels.
|
||||
"""
|
||||
|
||||
@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
fp8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
get_flashinfer_moe_backend,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
prepare_fp8_moe_layer_for_fi,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -103,9 +99,13 @@ def _get_priority_backends(
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: Fp8MoeBackend,
|
||||
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
|
||||
) -> type[mk.FusedMoEExperts]:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise NotImplementedError
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
|
||||
TrtLlmFp8Experts,
|
||||
)
|
||||
|
||||
return TrtLlmFp8Experts
|
||||
|
||||
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
@@ -205,13 +205,11 @@ def select_fp8_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
allow_vllm_cutlass: bool = False,
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]:
|
||||
"""
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
|
||||
|
||||
if config.is_lora_enabled:
|
||||
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
|
||||
|
||||
@@ -252,7 +250,7 @@ def select_fp8_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
@@ -287,16 +285,6 @@ def select_fp8_moe_backend(
|
||||
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
|
||||
)
|
||||
|
||||
# Handle FLASHINFER_TRTLLM specially (no kernel class).
|
||||
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -311,51 +299,32 @@ def select_fp8_moe_backend(
|
||||
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
|
||||
# If user is explicit about backend, validate it.
|
||||
fi_backend = get_flashinfer_moe_backend()
|
||||
|
||||
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, None
|
||||
else:
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
elif fi_backend == FlashinferMoeBackend.CUTLASS:
|
||||
if fi_backend == FlashinferMoeBackend.CUTLASS:
|
||||
backend = Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
|
||||
elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
assert fi_backend == FlashinferMoeBackend.CUTEDSL
|
||||
raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
|
||||
|
||||
raise ValueError(
|
||||
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
else:
|
||||
# If the user is not explicit about the backend, try both.
|
||||
for backend in [
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
]:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
@@ -408,23 +377,14 @@ def select_fp8_moe_backend(
|
||||
|
||||
# Select kernels in order of backend.
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config(
|
||||
block_shape: list[int] | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Create FusedMoEQuantConfig for the specified FP8 Backend.
|
||||
The FusedMoEQuantConfig holds the scales that are used
|
||||
@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config(
|
||||
In a future PR, we will have this function should be
|
||||
a method of the modular kernel itself.
|
||||
"""
|
||||
# TRTLLM does not use Modular Kernel abstraction yet.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
# MARLIN is mixed precision W8A16 config.
|
||||
if fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config(
|
||||
# (alpha = w_scale * a_scale) and inverse a2 scale.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
|
||||
assert a1_scale is not None and a2_scale is not None
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w1_scale,
|
||||
a1_scale,
|
||||
w2_scale,
|
||||
a2_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config(
|
||||
a2_scale=a2_scale,
|
||||
a1_gscale=(1.0 / a1_scale),
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
g1_alphas=(w1_scale * a1_scale).squeeze(),
|
||||
g2_alphas=(w2_scale * a2_scale).squeeze(),
|
||||
)
|
||||
# All other backends use normal config.
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config(
|
||||
def make_fp8_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
experts_cls: type[mk.FusedMoEExperts],
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
@@ -603,9 +555,9 @@ def make_fp8_moe_kernel(
|
||||
)
|
||||
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explict in
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=(
|
||||
|
||||
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
is_supported_config_trtllm,
|
||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: NvFp4MoeBackend,
|
||||
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
|
||||
) -> list[type[mk.FusedMoEExperts]]:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise NotImplementedError(
|
||||
"FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
|
||||
TrtLlmNvFp4ExpertsModular,
|
||||
TrtLlmNvFp4ExpertsMonolithic,
|
||||
)
|
||||
|
||||
# NOTE: prefer Monolthic > Modular, so return Monolithic first.
|
||||
return [
|
||||
TrtLlmNvFp4ExpertsMonolithic,
|
||||
TrtLlmNvFp4ExpertsModular,
|
||||
]
|
||||
|
||||
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
return FlashInferExperts
|
||||
return [FlashInferExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
|
||||
FlashInferCuteDSLExperts,
|
||||
)
|
||||
|
||||
return FlashInferCuteDSLExperts
|
||||
return [FlashInferCuteDSLExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
)
|
||||
|
||||
return CutlassExpertsFp4
|
||||
return [CutlassExpertsFp4]
|
||||
|
||||
elif backend == NvFp4MoeBackend.MARLIN:
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
)
|
||||
|
||||
return MarlinExperts
|
||||
return [MarlinExperts]
|
||||
else:
|
||||
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
|
||||
|
||||
@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
"""
|
||||
Select the primary NvFP4 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
@@ -143,10 +149,7 @@ def select_nvfp4_moe_backend(
|
||||
# NOTE(rob): this is kind of a hack. We need to peak into
|
||||
# the prepare-finalize selection to determine if we are using
|
||||
# the batched or standard expert format.
|
||||
use_batched = (
|
||||
config.moe_parallel_config.use_deepep_ll_kernels
|
||||
or config.moe_parallel_config.use_pplx_kernels
|
||||
)
|
||||
use_batched = config.moe_parallel_config.use_deepep_ll_kernels
|
||||
activation_format = (
|
||||
mk.FusedMoEActivationFormat.BatchedExperts
|
||||
if use_batched
|
||||
@@ -178,29 +181,21 @@ def select_nvfp4_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, k_cls
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, k_cls
|
||||
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_nvfp4_backend(runner_backend)
|
||||
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -213,36 +208,14 @@ def select_nvfp4_moe_backend(
|
||||
|
||||
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
|
||||
# If user is explicit about backend, validate it.
|
||||
fi_backend = get_flashinfer_moe_backend()
|
||||
|
||||
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, None
|
||||
else:
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
else:
|
||||
backend = fi_2_vllm_backend_map[fi_backend]
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
else:
|
||||
# If the user is not explicit about the backend, try each.
|
||||
for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
@@ -250,13 +223,13 @@ def select_nvfp4_moe_backend(
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, None
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
|
||||
@@ -271,16 +244,7 @@ def select_nvfp4_moe_backend(
|
||||
|
||||
# Select kernels in order of backend.
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None # type: ignore[assignment]
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
@@ -289,11 +253,11 @@ def select_nvfp4_moe_backend(
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
|
||||
raise NotImplementedError(
|
||||
"No NvFp4 MoE backend supports the deployment configuration."
|
||||
@@ -401,12 +365,8 @@ def make_nvfp4_moe_quant_config(
|
||||
w2_scale_2: torch.Tensor,
|
||||
a13_scale: torch.Tensor,
|
||||
a2_scale: torch.Tensor,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
|
||||
if backend in UNSUPPORTED:
|
||||
return None
|
||||
|
||||
elif backend == NvFp4MoeBackend.MARLIN:
|
||||
) -> FusedMoEQuantConfig:
|
||||
if backend == NvFp4MoeBackend.MARLIN:
|
||||
return nvfp4_w4a16_moe_quant_config(
|
||||
g1_alphas=w13_scale_2,
|
||||
g2_alphas=w2_scale_2,
|
||||
@@ -423,22 +383,27 @@ def make_nvfp4_moe_quant_config(
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
w1_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
# NOTE(rob): this is a hack until the MoE kernels
|
||||
# create their own quant configs. TRTLLM kernel
|
||||
# does not accept swizzled input quant scales.
|
||||
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
|
||||
)
|
||||
|
||||
|
||||
def make_nvfp4_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
experts_cls: type[mk.FusedMoEExperts],
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
@@ -461,9 +426,9 @@ def make_nvfp4_moe_kernel(
|
||||
)
|
||||
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explict in
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=(
|
||||
|
||||
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm_bf16,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
swap_w13_to_w31,
|
||||
@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel(
|
||||
backend: UnquantizedMoeBackend,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
) -> mk.FusedMoEModularKernel | None:
|
||||
) -> mk.FusedMoEKernel | None:
|
||||
if backend in UNSUPPORTED_BACKEND:
|
||||
return None
|
||||
|
||||
@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel(
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
FlashInferExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel(
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
AiterExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -241,25 +241,6 @@ def make_unquantized_moe_kernel(
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.XPU:
|
||||
from vllm.model_executor.layers.fused_moe import XPUExperts
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
XPUExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
kernel = fused_experts
|
||||
return kernel
|
||||
|
||||
@@ -1,373 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import pplx_kernels as pplx
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_validate_scale_shape,
|
||||
moe_kernel_quantize_input,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def pplx_hidden_dim_scale_bytes(
|
||||
max_num_tokens: int,
|
||||
hidden_dim: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype | str | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None,
|
||||
):
|
||||
# All pplx byte sizes must be 16-byte aligned.
|
||||
align = 16
|
||||
|
||||
# For blocked per token: set to
|
||||
# cdiv(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
|
||||
if quant_dtype is not None:
|
||||
assert isinstance(quant_dtype, torch.dtype)
|
||||
assert quant_dtype.itemsize == 1
|
||||
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
|
||||
elem_size = torch.float32.itemsize
|
||||
|
||||
if per_act_token_quant:
|
||||
# per-token (M x 1)
|
||||
assert block_shape is None
|
||||
hidden_scale_bytes = elem_size
|
||||
elif block_shape is not None:
|
||||
# per-group (M x K_tiles)
|
||||
block_size = block_shape[1]
|
||||
num_blocks = cdiv(hidden_dim, block_size)
|
||||
hidden_scale_bytes = num_blocks * elem_size
|
||||
else:
|
||||
# per-tensor (1 x 1)
|
||||
hidden_scale_bytes = elem_size
|
||||
else:
|
||||
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
|
||||
hidden_scale_bytes = 0
|
||||
|
||||
return (
|
||||
round_up(hidden_dim_bytes, align),
|
||||
round_up(hidden_scale_bytes, align),
|
||||
)
|
||||
|
||||
|
||||
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
"""PPLX-based prepare and finalize for expert parallelism."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
a2a: pplx.AllToAll,
|
||||
max_num_tokens: int,
|
||||
num_local_experts: int,
|
||||
num_dispatchers: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert max_num_tokens > 0
|
||||
assert num_local_experts > 0
|
||||
self.a2a = a2a
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_local_experts = num_local_experts
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return self.max_num_tokens
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return torch.uint32
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
assert topk_ids.size(0) == num_tokens
|
||||
# expert_map should be None because with expert map, -1 id is used for
|
||||
# non-local token; this causes error when casting ids to the
|
||||
# topk_indices_dtype() int32
|
||||
#
|
||||
if expert_map is not None:
|
||||
logger.warning_once(
|
||||
"The PPLX backend does not support expert mapping. "
|
||||
"The provided `expert_map` will be ignored."
|
||||
)
|
||||
expert_map = None # noqa: F841
|
||||
|
||||
# Is this always going to be a1.device?
|
||||
device = a1.device
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
repeat_cols = 4
|
||||
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
|
||||
# TODO(bnell): always pass quant_config.a1_scale?
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
(None if quant_config.per_act_token_quant else quant_config.a1_scale),
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
)
|
||||
|
||||
_validate_scale_shape(
|
||||
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
|
||||
)
|
||||
|
||||
orig_a_scale_block_shape: int | None = None
|
||||
|
||||
if a1q_scale is not None:
|
||||
scalar_scales = a1q_scale.numel() == 1
|
||||
|
||||
# pplx requires 2-d scales even for scalar scales
|
||||
if a1q_scale.dim() <= 1:
|
||||
assert scalar_scales
|
||||
a1q_scale = a1q_scale.view(1, 1)
|
||||
|
||||
orig_a_scale_block_shape = a1q_scale.shape[-1]
|
||||
|
||||
if not quant_config.is_block_quantized:
|
||||
# TODO (bnell): use group_broadcast instead?
|
||||
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
|
||||
|
||||
assert a1q_scale is None or a1q_scale.ndim == 2, (
|
||||
f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
|
||||
)
|
||||
|
||||
expert_num_tokens = torch.empty(
|
||||
self.num_local_experts,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expert_x = torch.empty(
|
||||
(
|
||||
self.num_local_experts,
|
||||
self.max_num_tokens * self.num_dispatchers(),
|
||||
hidden_dim,
|
||||
),
|
||||
dtype=a1q.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expert_x_scale: torch.Tensor | None = None
|
||||
if a1q.dtype.itemsize == 1:
|
||||
if quant_config.is_per_act_token:
|
||||
# (M x 1) -> (E x M x K)
|
||||
final_dim = expert_x.size(2)
|
||||
elif quant_config.is_per_tensor:
|
||||
# (1 x 1) -> (E x 1 x 1)
|
||||
final_dim = 1
|
||||
else:
|
||||
# (M x K_tiles) -> (E x M x K_tiles)
|
||||
assert quant_config.block_shape is not None
|
||||
num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
|
||||
final_dim = num_blocks
|
||||
|
||||
expert_x_scale_shape = (
|
||||
self.num_local_experts,
|
||||
expert_x.size(1),
|
||||
round_up(final_dim, 4), # round up for alignment
|
||||
)
|
||||
|
||||
expert_x_scale = torch.empty(
|
||||
expert_x_scale_shape,
|
||||
dtype=torch.float32,
|
||||
device=expert_x.device,
|
||||
)
|
||||
|
||||
# This argument is optional, defaults to indices.size(0)
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: torch.Tensor | None = None
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=True,
|
||||
do_recv=False,
|
||||
)
|
||||
|
||||
hook = lambda: self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=False,
|
||||
do_recv=True,
|
||||
)
|
||||
|
||||
return (
|
||||
hook,
|
||||
lambda: self._receiver(
|
||||
expert_num_tokens,
|
||||
expert_x,
|
||||
expert_x_scale,
|
||||
orig_a_scale_block_shape,
|
||||
),
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
expert_x: torch.Tensor,
|
||||
expert_x_scale: torch.Tensor | None,
|
||||
orig_a_scale_block_shape: int | None,
|
||||
) -> mk.PrepareResultType:
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
|
||||
assert expert_x_scale.ndim == 3
|
||||
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
|
||||
)
|
||||
|
||||
return expert_x, expert_x_scale, expert_tokens_meta, None, None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
hook, receiver = self.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
defer_input_quant=defer_input_quant,
|
||||
)
|
||||
hook()
|
||||
return receiver()
|
||||
|
||||
def finalize_async(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
|
||||
"Weight application and reduction happens in the combine kernel."
|
||||
)
|
||||
|
||||
# This argument is optional
|
||||
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||
bound_m: torch.Tensor | None = None
|
||||
|
||||
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
|
||||
# num_tokens = output.size(0) # M
|
||||
# assert topk_ids.size(0) == num_tokens, (
|
||||
# f"{topk_ids.size(0)} == {num_tokens}")
|
||||
assert topk_ids.size() == topk_weights.size(), (
|
||||
f"{topk_ids.size()} == {topk_weights.size()}"
|
||||
)
|
||||
assert output.size(0) <= self.max_num_tokens, (
|
||||
f"{output.size(0)} <= {self.max_num_tokens}"
|
||||
)
|
||||
assert output.size(1) == fused_expert_output.size(-1)
|
||||
|
||||
# Set weights to 1 if we did them in dispatch. This is hacky.
|
||||
if apply_router_weight_on_input:
|
||||
topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
|
||||
|
||||
self.a2a.combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids_u32,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
do_send=True,
|
||||
do_recv=False,
|
||||
)
|
||||
|
||||
return lambda: self.a2a.combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids_u32,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
do_send=False,
|
||||
do_recv=True,
|
||||
)
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
receiver = self.finalize_async(
|
||||
output,
|
||||
fused_expert_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
weight_and_reduce_impl,
|
||||
)
|
||||
receiver()
|
||||
@@ -1,209 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Defer input quantization to the MoE kernel.
|
||||
use_nvfp4 = quant_config.use_nvfp4_w4a4
|
||||
if defer_input_quant:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
# NOTE: swizzling pads the scales to multiple of 128
|
||||
# which makes the scales tensor different shape than
|
||||
# the hidden states, breaking the A2A kernel. So, we
|
||||
# delay the swizzling until after the A2A.
|
||||
is_fp4_scale_swizzled=False,
|
||||
)
|
||||
|
||||
# Skip gathering scales if we have static quantization
|
||||
# (the scale is a scalar, replicated on all ranks) or
|
||||
# if quantization is deferred.
|
||||
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
|
||||
scales = None if skip_gather_scales else [a1q_scale]
|
||||
|
||||
res = get_ep_group().dispatch(
|
||||
a1q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
if skip_gather_scales:
|
||||
a1q, topk_weights, topk_ids = res
|
||||
else:
|
||||
a1q, topk_weights, topk_ids, scales = res
|
||||
assert scales is not None and len(scales) == 1
|
||||
a1q_scale = scales[0]
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
assert a1q_scale is not None
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
|
||||
out = weight_and_reduce_impl.apply(
|
||||
output=None,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
output.copy_(
|
||||
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
"""MoE prepare and finalize without expert parallelism."""
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if defer_input_quant:
|
||||
return a1, None, None, None, None
|
||||
|
||||
input_sf = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.use_nvfp4_w4a4
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
weight_and_reduce_impl.apply(
|
||||
output=output,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
|
||||
MoEPrepareAndFinalizeNaiveDPEPModular,
|
||||
MoEPrepareAndFinalizeNaiveDPEPMonolithic,
|
||||
make_moe_prepare_and_finalize_naive_dp_ep,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
MoEPrepareAndFinalizeNoDPEPMonolithic,
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MoEPrepareAndFinalizeNaiveDPEPMonolithic",
|
||||
"MoEPrepareAndFinalizeNaiveDPEPModular",
|
||||
"make_moe_prepare_and_finalize_naive_dp_ep",
|
||||
"MoEPrepareAndFinalizeNoDPEPMonolithic",
|
||||
"MoEPrepareAndFinalizeNoDPEPModular",
|
||||
"make_moe_prepare_and_finalize_no_dp_ep",
|
||||
]
|
||||
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def _quantize_and_setup_dispatch(
|
||||
a1: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
|
||||
# Defer input quantization to the MoE kernel.
|
||||
if defer_input_quant:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
else:
|
||||
input_sf = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.use_nvfp4_w4a4
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
|
||||
# NOTE: swizzling pads the scales to multiple of 128
|
||||
# which makes the scales tensor different shape than
|
||||
# the hidden states, breaking the A2A kernel. So, we
|
||||
# delay the swizzling until after the A2A.
|
||||
a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=False,
|
||||
)
|
||||
|
||||
# Skip gathering scales if we have static quantization
|
||||
# (the scale is a scalar, replicated on all ranks) or
|
||||
# if quantization is deferred.
|
||||
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
|
||||
scales = None if skip_gather_scales else [a1q_scale]
|
||||
|
||||
return a1q, scales
|
||||
|
||||
|
||||
def _unwrap_scale_and_prepare_for_moe(
|
||||
scales: list[torch.Tensor] | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> torch.Tensor:
|
||||
assert scales is not None and len(scales) == 1
|
||||
a1q_scale = scales[0]
|
||||
# Apply swizzling after a2a if the MoE kernel needs it.
|
||||
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
|
||||
assert a1q_scale is not None
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q_scale
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Naive Prepare/Finalize for Dp/Ep case for Modular Kernels.
|
||||
|
||||
Uses Torch AR/RS or AR for dispatch/combine operations, applied
|
||||
to the topk weights and ids.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
"""Quantize and Dispatch Topk Weights and Topk Ids."""
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
|
||||
res = get_ep_group().dispatch(
|
||||
a1q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
a1q, topk_weights, topk_ids = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, topk_weights, topk_ids, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
|
||||
out = weight_and_reduce_impl.apply(
|
||||
output=None,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
output.copy_(
|
||||
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
|
||||
"""
|
||||
Naive Prepare/Finalize for Dp/Ep case for Modular Kernels.
|
||||
|
||||
Uses Torch AR/RS or AR for dispatch/combine operations, applied
|
||||
to the router logits (the MoE kernel runs the router internally).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareMonolithicResultType:
|
||||
"""Quantize and Dispatch Router Logits."""
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
|
||||
res = get_ep_group().dispatch_router_logits(
|
||||
a1q,
|
||||
router_logits,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
a1q, router_logits = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, router_logits, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
return a1q, a1q_scale, router_logits
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
fused_expert_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out = get_ep_group().combine(
|
||||
fused_expert_output, is_sequence_parallel=self.is_sequence_parallel
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
use_monolithic: bool,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic:
|
||||
return (
|
||||
MoEPrepareAndFinalizeNaiveDPEPMonolithic(
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
if use_monolithic
|
||||
else MoEPrepareAndFinalizeNaiveDPEPModular(
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
|
||||
|
||||
def _quantize_input(
|
||||
a1: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if defer_input_quant:
|
||||
return a1, None
|
||||
|
||||
input_sf = (
|
||||
quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale
|
||||
)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
return a1q, a1q_scale
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
weight_and_reduce_impl.apply(
|
||||
output=output,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareMonolithicResultType:
|
||||
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
|
||||
return a1q, a1q_scale, router_logits
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
fused_expert_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return fused_expert_output
|
||||
|
||||
|
||||
def make_moe_prepare_and_finalize_no_dp_ep(
|
||||
use_monolithic: bool,
|
||||
) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic:
|
||||
return (
|
||||
MoEPrepareAndFinalizeNoDPEPMonolithic()
|
||||
if use_monolithic
|
||||
else MoEPrepareAndFinalizeNoDPEPModular()
|
||||
)
|
||||
@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts(
|
||||
)
|
||||
|
||||
|
||||
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
|
||||
self._device_buffer = torch.zeros(
|
||||
(max_num_batched_tokens, num_layers, num_experts_per_tok),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
device=current_platform.device_type,
|
||||
)
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
|
||||
|
||||
# TODO(bowen): When using `FusedMoEModularKernel`, this
|
||||
# can be done in a more unified way, since
|
||||
# `FusedMoEPrepareAndFinalize` will return the expert
|
||||
# `FusedMoEPrepareAndFinalizeModular` will return the expert
|
||||
# token count, in some cases directly from the kernel.
|
||||
# However, now there are many code paths not using
|
||||
# the modular kernel, e.g. calling `fused_experts`,
|
||||
@@ -175,6 +175,7 @@ class BaseRouter(FusedMoERouter):
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
|
||||
assert topk_ids.dtype == indices_type or indices_type is None
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_ids
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -31,7 +31,7 @@ def vllm_topk_softmax(
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
e_score_correction_bias
|
||||
)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
@@ -85,13 +85,14 @@ def fused_topk_bias(
|
||||
token_expert_indices = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_weights, topk_ids = vllm_topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
gating_output_float,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
@@ -186,7 +187,7 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
indices_type=indices_type,
|
||||
)
|
||||
|
||||
if self.routed_scaling_factor != 1.0:
|
||||
topk_weights *= self.routed_scaling_factor
|
||||
# if self.routed_scaling_factor != 1.0:
|
||||
# topk_weights *= self.routed_scaling_factor
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -26,8 +26,9 @@ def vllm_topk_softmax(
|
||||
topk_indices,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
|
||||
@@ -90,13 +91,14 @@ def fused_topk(
|
||||
token_expert_indices = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
gating_output_float = gating_output.float()
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_func = dispatch_topk_softmax_func(
|
||||
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
@@ -105,7 +107,7 @@ def fused_topk(
|
||||
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
|
||||
115
vllm/model_executor/layers/fused_moe/router/gate_linear.py
Normal file
115
vllm/model_executor/layers/fused_moe/router/gate_linear.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@PluggableLayer.register("gate_linear")
|
||||
class GateLinear(ReplicatedLinear):
|
||||
"""MoE gate linear layer with three-tier GEMM dispatch:
|
||||
|
||||
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
|
||||
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
|
||||
3. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
|
||||
The ``out_dtype`` attribute is mutable and can be set after init
|
||||
(e.g. when the required dtype depends on the expert quantization
|
||||
method which is only known later).
|
||||
"""
|
||||
|
||||
# Dimensions supported by the DSV3 specialized kernel
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = False,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
force_fp32_compute: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
is_hopper_or_blackwell = current_platform.is_device_capability(
|
||||
(9, 0)
|
||||
) or current_platform.is_device_capability_family(100)
|
||||
can_use_specialized_kernels = False
|
||||
|
||||
# If fp32 compute is required and no specialized kernel is available,
|
||||
# store weights in fp32 so Tier 3 computes in fp32 natively.
|
||||
if force_fp32_compute and not can_use_specialized_kernels:
|
||||
params_dtype = torch.float32
|
||||
|
||||
super().__init__(
|
||||
input_size,
|
||||
output_size,
|
||||
bias=bias,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None,
|
||||
prefix=prefix,
|
||||
)
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
# DSV3 specialized kernel eligibility (SM90+, exact dims)
|
||||
self.allow_specialized_router_gemm = can_use_specialized_kernels
|
||||
self.allow_dsv3_router_gemm = (
|
||||
self.allow_specialized_router_gemm
|
||||
and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS
|
||||
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# cuBLAS bf16→fp32 eligibility
|
||||
self.allow_cublas_router_gemm = (
|
||||
self.allow_specialized_router_gemm
|
||||
and self.weight.dtype == torch.bfloat16
|
||||
and self.out_dtype == torch.float32
|
||||
)
|
||||
|
||||
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
|
||||
"""Set output dtype for the router logits after init.
|
||||
|
||||
Useful when the required dtype depends on the expert quantization
|
||||
method which is only known after the gate is constructed.
|
||||
"""
|
||||
if self.out_dtype is not None:
|
||||
raise ValueError("out_dtype has already been set")
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
if (
|
||||
not self.allow_cublas_router_gemm
|
||||
and self.allow_specialized_router_gemm
|
||||
and out_dtype == torch.float32
|
||||
):
|
||||
self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
# Tier 1: DSV3 specialized kernel
|
||||
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
|
||||
output = ops.dsv3_router_gemm(
|
||||
hidden_states=x,
|
||||
router_weight=self.weight,
|
||||
output_dtype=self.out_dtype,
|
||||
)
|
||||
return output, None
|
||||
|
||||
# Tier 2: cuBLAS bf16→fp32
|
||||
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
|
||||
output = ops.router_gemm_bf16_fp32(x, self.weight)
|
||||
return output, None
|
||||
|
||||
# Tier 3: F.linear (ReplicatedLinear)
|
||||
if self.out_dtype is not None and x.dtype != self.weight.dtype:
|
||||
x = x.to(self.weight.dtype)
|
||||
output, output_bias = super().forward(x)
|
||||
if self.out_dtype is not None and output.dtype != self.out_dtype:
|
||||
output = output.to(self.out_dtype)
|
||||
return output, output_bias
|
||||
@@ -92,77 +92,9 @@ def grouped_topk(
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
|
||||
and current_platform.is_cuda()
|
||||
and num_expert_group <= 32
|
||||
and topk <= 32
|
||||
and e_score_correction_bias is not None
|
||||
):
|
||||
return fused_grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
|
||||
|
||||
if scoring_func == "softmax":
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
scores = gating_output.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
num_token = scores.size(0)
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_scores = scores
|
||||
scores = scores + e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
||||
)
|
||||
else:
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_is_batch_invariant()
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_scores.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(
|
||||
tmp_scores, k=topk, dim=-1, sorted=use_sorted
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights = topk_weights * routed_scaling_factor
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
from ixformer.inference.functions import moe_grouped_topk as grouped_topk
|
||||
topk_weights, topk_ids = grouped_topk(gating_output, topk, num_expert_group, topk_group, scoring_func, e_score_correction_bias,renormalize = renormalize)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
# --8<-- [start:grouped_topk]
|
||||
@@ -246,7 +178,6 @@ class GroupedTopk(CustomOp):
|
||||
hidden_states, gating_output, e_score_correction_bias
|
||||
)
|
||||
|
||||
from ixformer.inference.functions import moe_grouped_topk as grouped_topk
|
||||
|
||||
class GroupedTopKRouter(BaseRouter):
|
||||
"""Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""
|
||||
@@ -316,8 +247,8 @@ class GroupedTopKRouter(BaseRouter):
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
)
|
||||
if self.routed_scaling_factor != 1.0:
|
||||
topk_weights *= self.routed_scaling_factor
|
||||
# if self.routed_scaling_factor != 1.0:
|
||||
# topk_weights *= self.routed_scaling_factor
|
||||
else:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
@@ -340,14 +271,14 @@ class GroupedTopKRouter(BaseRouter):
|
||||
grouped_topk_impl = grouped_topk
|
||||
|
||||
topk_weights, topk_ids = grouped_topk_impl(
|
||||
# hidden_states=hidden_states,
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=self.num_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
scoring_func=self.scoring_func,
|
||||
# routed_scaling_factor=self.routed_scaling_factor,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ def create_fused_moe_router(
|
||||
# grouped topk + fused topk bias parameters
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
# custom routing paramaters
|
||||
# custom routing parameters
|
||||
custom_routing_function: Callable | None = None,
|
||||
# eplb parameters
|
||||
enable_eplb: bool = False,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import (
|
||||
HAS_OPAQUE_TYPE,
|
||||
ModuleName,
|
||||
aux_stream,
|
||||
current_stream,
|
||||
direct_register_custom_op,
|
||||
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
|
||||
return forward_context.no_compile_layers[layer_name]
|
||||
|
||||
|
||||
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
|
||||
# on older versions it remains a plain str.
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
_layer_name_type: TypeAlias = str | ModuleName
|
||||
else:
|
||||
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
|
||||
|
||||
|
||||
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
|
||||
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
|
||||
|
||||
|
||||
def _moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
@@ -74,7 +91,7 @@ def _moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -83,9 +100,9 @@ def _moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
@@ -97,7 +114,7 @@ def _moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Output shapes:
|
||||
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||
@@ -105,12 +122,10 @@ def _moe_forward_shared_fake(
|
||||
# hidden_states
|
||||
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
|
||||
if shared_experts_input is not None:
|
||||
shared_out = torch.empty_like(shared_experts_input)
|
||||
else:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
@@ -165,6 +180,7 @@ class DefaultMoERunner(MoERunner):
|
||||
quant_method: FusedMoEMethodBase,
|
||||
reduce_results: bool,
|
||||
enable_dbo: bool,
|
||||
fused_shared_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.moe_config = moe_config
|
||||
@@ -175,6 +191,9 @@ class DefaultMoERunner(MoERunner):
|
||||
self.quant_method = quant_method
|
||||
self.reduce_results = reduce_results
|
||||
self.enable_dbo = enable_dbo
|
||||
self.fused_shared_output = fused_shared_output
|
||||
if self.fused_shared_output:
|
||||
assert self.shared_experts is not None, "Shared experts must be provided when fused_shared_output is True."
|
||||
|
||||
# Allow disabling of the separate shared experts stream for
|
||||
# debug purposes.
|
||||
@@ -195,19 +214,19 @@ class DefaultMoERunner(MoERunner):
|
||||
# Needed for string -> FusedMoE layer lookup in custom ops.
|
||||
self.layer_name = layer.layer_name
|
||||
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
# Note: CPU doesn't require wrapped forward_impl.
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = _moe_forward
|
||||
else:
|
||||
self.moe_forward = _moe_forward_shared
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = _moe_forward
|
||||
else:
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward
|
||||
else:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward_shared
|
||||
self.moe_forward = _moe_forward_shared
|
||||
# else:
|
||||
# if self.shared_experts is None:
|
||||
# self.moe_forward = torch.ops.vllm.moe_forward
|
||||
# else:
|
||||
# self.moe_forward = torch.ops.vllm.moe_forward_shared
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: torch.Tensor | None = None
|
||||
@@ -216,8 +235,7 @@ class DefaultMoERunner(MoERunner):
|
||||
@property
|
||||
def use_dp_chunking(self) -> bool:
|
||||
return (
|
||||
self.moe_config.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.moe_parallel_config.use_mori_kernels
|
||||
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
|
||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||
@@ -306,8 +324,8 @@ class DefaultMoERunner(MoERunner):
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
self.quant_method.moe_mk is not None
|
||||
and self.quant_method.moe_mk.output_is_reduced()
|
||||
self.quant_method.moe_kernel is not None
|
||||
and self.quant_method.moe_kernel.output_is_reduced()
|
||||
)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
@@ -362,13 +380,15 @@ class DefaultMoERunner(MoERunner):
|
||||
|
||||
if isinstance(states, tuple):
|
||||
return tuple(
|
||||
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
|
||||
[None if s is None else func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
|
||||
)
|
||||
else:
|
||||
assert len(trunc_sizes) == 1
|
||||
return func(states, trunc_sizes[0])
|
||||
|
||||
def _encode_layer_name(self) -> str:
|
||||
def _encode_layer_name(self) -> str | ModuleName:
|
||||
if HAS_OPAQUE_TYPE:
|
||||
return ModuleName(self.layer_name)
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
@@ -624,53 +644,27 @@ class DefaultMoERunner(MoERunner):
|
||||
)
|
||||
|
||||
with sp_ctx:
|
||||
extra_tensors = None
|
||||
if do_naive_dispatch_combine:
|
||||
post_quant_allgather = (
|
||||
self.quant_method is not None
|
||||
and self.moe_config.dp_size > 1
|
||||
and self.moe_config.use_ep
|
||||
and getattr(self.quant_method, "do_post_quant_allgather", False)
|
||||
)
|
||||
if post_quant_allgather:
|
||||
hidden_states_to_dispatch, extra_tensors = (
|
||||
self.quant_method.prepare_dp_allgather_tensor(
|
||||
layer, hidden_states, router_logits
|
||||
)
|
||||
)
|
||||
else:
|
||||
hidden_states_to_dispatch = hidden_states
|
||||
|
||||
dispatch_res = get_ep_group().dispatch_router_logits(
|
||||
hidden_states_to_dispatch,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
if extra_tensors is not None:
|
||||
(
|
||||
orig_hidden_states,
|
||||
router_logits,
|
||||
extra_tensors_combined,
|
||||
) = dispatch_res
|
||||
hidden_states_combined = (
|
||||
orig_hidden_states,
|
||||
extra_tensors_combined[0],
|
||||
)
|
||||
else:
|
||||
hidden_states_combined, router_logits = dispatch_res
|
||||
orig_hidden_states = hidden_states_combined
|
||||
else:
|
||||
orig_hidden_states = hidden_states
|
||||
|
||||
# Run shared experts before matrix multiply.
|
||||
# because matrix multiply maybe modify the hidden_states.
|
||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||
if has_separate_shared_experts: # and not use_shared_experts_stream:
|
||||
assert self.shared_experts is not None
|
||||
shared_input = (
|
||||
shared_input if shared_input is not None else hidden_states
|
||||
)
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
else:
|
||||
assert self.fused_shared_output == False, "fused_shared_output is only supported when has_separate_shared_experts is True"
|
||||
shared_output = None
|
||||
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
|
||||
# router logits to all experts.
|
||||
# NOTE: this will be removed once all kernels are migrated into the
|
||||
# MoEKernel framework.
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
@@ -685,42 +679,33 @@ class DefaultMoERunner(MoERunner):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
|
||||
# Figure out nicer way to do this.
|
||||
if do_naive_dispatch_combine:
|
||||
x = hidden_states_combined
|
||||
x_orig = orig_hidden_states
|
||||
else:
|
||||
x = hidden_states
|
||||
x_orig = hidden_states
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=x,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=x_orig,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=x, # The type signture of this is wrong due to the hack.
|
||||
x=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
router_logits=router_logits,
|
||||
top_k=topk_ids.shape[-1]
|
||||
# Assign the value of shared_experts_output to variable shared_experts_input for fusion
|
||||
shared_experts_input=shared_output if self.fused_shared_output else None,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert self.shared_experts is not None
|
||||
|
||||
if use_shared_experts_stream:
|
||||
assert use_shared_experts_stream == False, "Running shared experts in parallel with the main MoE execution is currently not supported!"
|
||||
# Run shared experts in parallel on a separate stream
|
||||
# NOTE: We start the separate stream here and mark the
|
||||
# sync end point immediately after it is done. This is
|
||||
@@ -733,7 +718,7 @@ class DefaultMoERunner(MoERunner):
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
None if self.fused_shared_output else shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,14 +10,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
|
||||
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
Useful in the case when some FusedMoEPermuteExpertsUnpermute
|
||||
Useful in the case when some FusedMoEExpertsModular
|
||||
implementation does not perform weight application and reduction
|
||||
but cannot address the needs of all the compatible PrepareAndFinalize
|
||||
implementations.
|
||||
For example, BatchedTritonExperts is compatible with both
|
||||
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
|
||||
does the weight-application + reduction as part of the pplx combine kernel.
|
||||
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
|
||||
For example, BatchedTritonExperts is compatible with both batched
|
||||
PrepareAndFinalize implementations like DeepEPLLPrepareAndFinalize and
|
||||
BatchedPrepareAndFinalize. Some PrepareAndFinalize implementations do
|
||||
the weight-application + reduction as part of the combine kernel, while
|
||||
BatchedPrepareAndFinalize needs an explicit implementation. To facilitate
|
||||
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
|
||||
so the PrepareAndFinalize implementations could choose how to
|
||||
weight + reduce.
|
||||
@@ -61,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
|
||||
if output is None:
|
||||
return fused_expert_output
|
||||
|
||||
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
|
||||
# MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
|
||||
# tensor.
|
||||
assert output.size() == fused_expert_output.size(), (
|
||||
"output shape is expected to match the fused_expert_output shape. "
|
||||
|
||||
@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
return (CutlassExpertsFp8, TritonExperts)
|
||||
|
||||
@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
# Small batch fallback for sm100.
|
||||
if self.is_sm100 and hidden_states.shape[0] <= 8:
|
||||
return self.fallback_experts
|
||||
|
||||
@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
return (DeepGemmExperts, TritonExperts)
|
||||
|
||||
@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
|
||||
return self.experts
|
||||
else:
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
|
||||
|
||||
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
|
||||
"""TensorRT-LLM-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
|
||||
UnquantizedMoeBackend,
|
||||
@@ -42,9 +42,9 @@ from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
if current_platform.is_cuda_alike() or current_platform.is_xpu():
|
||||
from .fused_batched_moe import BatchedTritonExperts
|
||||
from .fused_moe import TritonExperts
|
||||
else:
|
||||
TritonExperts = None # type: ignore
|
||||
fused_experts = None
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.rocm_aiter_moe_enabled = (
|
||||
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
|
||||
)
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
self.kernel: mk.FusedMoEKernel | None = None
|
||||
self._is_monolithic = (
|
||||
current_platform.is_cpu()
|
||||
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
|
||||
return None
|
||||
else:
|
||||
@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
) -> FusedMoEExpertsModular:
|
||||
assert self.moe_quant_config is not None
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
@@ -296,16 +296,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# Assign the value of shared_experts_output to variable shared_experts_input for fusion
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
**kwargs
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward(
|
||||
result = self.forward(
|
||||
layer=layer,
|
||||
x=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
# not used
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
) * layer.routed_scaling_factor
|
||||
if shared_experts_input is not None:
|
||||
result += shared_experts_input
|
||||
return result
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
if self.moe.has_bias:
|
||||
@@ -333,10 +337,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts_input=shared_experts_input,
|
||||
expert_map=layer.expert_map
|
||||
)
|
||||
|
||||
def forward_monolithic_cuda(
|
||||
|
||||
@@ -23,7 +23,7 @@ if current_platform.is_xpu():
|
||||
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
|
||||
|
||||
|
||||
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
Reference in New Issue
Block a user