add dispatch_gmm_combine kernel (#3532)

### What this PR does / why we need it?

This PR introduces the Ascend implementation of the
`dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime,
together with follow‑up fixes to ensure the kernel builds and runs
correctly in CI.

- Add the full host and device implementation of the `dispatch_ffn_combine`
kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MoE
routing helpers, and kernel utilities for quantized FFN dispatch (a rough
sketch of the computation is given after this list).
- Integrate the new kernel with the PyTorch binding
(`csrc/torch_binding.cpp`, `csrc/torch_binding_meta.cpp`) and the Ascend
runtime (`vllm_ascend/ascend_forward_context.py`,
`vllm_ascend/worker/model_runner_v1.py`).
- Extend fused MoE communication and token dispatch support in
`vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new
dispatch path.
- Update quantization logic in `vllm_ascend/quantization/w8a8_dynamic.py`
to support the new FFN dispatch flow.
- Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake
configuration, and include/namespace usage in the new kernel files.
- Add an end‑to‑end nightly test
`tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper
utilities in `vllm_ascend/utils.py` to validate the new kernel.
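To make the kernel's intent concrete, here is a rough single-rank reference of what the fused dispatch → grouped FFN → combine path is assumed to compute. It ignores int8 quantization, expert parallelism, and the HCCL all-to-all exchange; the shapes, the packed gate/up weight layout, and the SwiGLU split are illustrative assumptions, not taken from the kernel code.

```python
import torch

def reference_moe_ffn(x, w13, w2, topk_ids, topk_weights):
    # x:            [num_tokens, hidden]
    # w13:          [num_experts, 2 * intermediate, hidden]  (gate/up packed)
    # w2:           [num_experts, hidden, intermediate]
    # topk_ids:     [num_tokens, topk] expert ids from the router
    # topk_weights: [num_tokens, topk] router probabilities
    out = torch.zeros_like(x)
    num_tokens, topk = topk_ids.shape
    for t in range(num_tokens):
        for k in range(topk):
            e = int(topk_ids[t, k])
            gate, up = (w13[e] @ x[t]).chunk(2)
            h = torch.nn.functional.silu(gate) * up      # SwiGLU activation
            out[t] += topk_weights[t, k] * (w2[e] @ h)   # weighted combine
    return out
```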

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

---------

Signed-off-by: mojave2 <chenchen145@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>

```diff
@@ -52,6 +52,7 @@ class MoECommType(Enum):
     ALLGATHER = 0
     MC2 = 1
     ALLTOALL = 2
+    FUSED_ALLTOALL = 3


 @contextmanager
```


```diff
@@ -520,7 +520,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
                 and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_output
```


```diff
@@ -44,6 +44,8 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
     _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
+    _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
+        moe_config)


 class MoECommMethod(ABC):
@@ -243,3 +245,69 @@ class AlltoAllCommImpl(MoECommMethod):

     def _get_prepare_finalize(self):
         return PrepareAndFinalizeWithAll2All(self.moe_config)
+
+
+class FusedAlltoAllCommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `enable_expert_parallel=True`.
+    2. `npu_grouped_matmul` is available.
+
+    This implementation uses all-to-all communication to exchange tokens
+    between data parallel ranks before and after the MLP computation. It should
+    have better performance than AllGatherCommImpl when DP size > 1.
+    """
+
+    def _get_token_dispatcher(self):
+        return TokenDispatcherWithAll2AllV(
+            top_k=self.moe_config.experts_per_token,
+            num_experts=self.moe_config.num_experts,
+            num_local_experts=self.moe_config.num_local_experts)
+
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithAll2All(self.moe_config)
+
+    def fused_experts(
+            self,
+            hidden_states: torch.Tensor,
+            w1: torch.Tensor,
+            w2: torch.Tensor,
+            topk_weights: torch.Tensor,
+            topk_ids: torch.Tensor,
+            activation: str = "silu",
+            apply_router_weight_on_input: bool = False,
+            use_int8_w8a8: bool = False,
+            use_int4_w4a8: bool = False,
+            global_num_experts: Optional[int] = None,
+            expert_map: Optional[torch.Tensor] = None,
+            w1_scale: Optional[torch.Tensor] = None,
+            w2_scale: Optional[torch.Tensor] = None,
+            w1_scale_bias: torch.Tensor = None,
+            w2_scale_bias: torch.Tensor = None,
+            # For TorchAir graph
+            is_torchair: bool = False,
+            # For Cube/Vector parallel
+            shared_experts: Optional[Any] = None,
+            quantized_x_for_share: Optional[Any] = None,
+            dynamic_scale_for_share: Optional[Any] = None,
+            # For load balance
+            log2phy: torch.Tensor = None,
+            global_redundant_expert_num: int = 0,
+            need_trans: bool = False,
+            dynamic_eplb: bool = False,
+            mc2_mask: torch.Tensor = None,
+            pertoken_scale: Optional[torch.Tensor] = None):
+        out = torch.empty_like(hidden_states)
+        torch.ops._C_ascend.dispatch_ffn_combine(
+            x=hidden_states,
+            weight1=w1,
+            weight2=w2,
+            expert_idx=topk_ids,
+            scale1=w1_scale,
+            scale2=w2_scale,
+            probs=topk_weights.to(torch.float32),
+            group=self.token_dispatcher.moe_all_to_all_group_name,
+            max_output_size=65536,
+            out=out,
+        )
+        return out
```
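Note that `fused_experts` here collapses the usual token-dispatch / grouped-matmul / combine sequence of the other comm implementations into a single call to the new custom op; the `group` argument is the HCCL communicator name exposed by `TokenDispatcherWithAll2AllV` (next hunk), and `max_output_size=65536` appears to cap the size of the combined output buffer.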


```diff
@@ -513,6 +513,11 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                 self.local_expert_indices[i + 1] -
                 1), "local_expert_indices must be continuous"

+        # TODO: Try local_rank = ep_group.rank_in_group
+        local_rank = torch.distributed.get_rank(group=self.ep_group)
+        backend = self.ep_group._get_backend(torch.device("npu"))
+        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
+
     def token_dispatch(self,
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
```


```diff
@@ -249,8 +249,9 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
         final_hidden_states: torch.Tensor) -> torch.Tensor:
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
-    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
-                         } or forward_context.sp_enabled:
+    if moe_comm_type in {
+            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
+    } or forward_context.sp_enabled:
         return final_hidden_states
     else:
         return tensor_model_parallel_all_reduce(final_hidden_states)
```


```diff
@@ -24,6 +24,7 @@ from vllm.distributed import get_ep_group
 from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
@@ -232,13 +233,15 @@ class AscendW8A8DynamicFusedMoEMethod:
             w2 = [layer.w2_weight]
             w2_scale = [layer.w2_weight_scale]

+        fused_flag = get_forward_context(
+        ).moe_comm_type == MoECommType.FUSED_ALLTOALL
         return moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
-            w1=w1,
-            w1_scale=w1_scale,
-            w2=w2,
-            w2_scale=w2_scale,
+            w1=w1[0] if fused_flag else w1,
+            w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
+            w2=w2[0] if fused_flag else w2,
+            w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int8_w8a8=True,
@@ -270,6 +273,12 @@ class AscendW8A8DynamicFusedMoEMethod:
             layer.w2_weight_scale.data.shape[0], -1)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
             layer.w2_weight_offset.data.shape[0], -1)
+
+        layer.fused_w1_scale = scale_from_float_to_int64(
+            layer.w13_weight_scale.data)
+        layer.fused_w2_scale = scale_from_float_to_int64(
+            layer.w2_weight_scale.data)
+
         if self.dynamic_eplb:
             layer.w13_weight_list = [
                 weight.clone()
@@ -292,3 +301,11 @@ class AscendW8A8DynamicFusedMoEMethod:
             del layer.w13_weight_scale_fp32
         del layer.w2_weight_scale
         torch.npu.empty_cache()
+
+
+def scale_from_float_to_int64(scale):
+    import numpy as np
+    scale = torch.from_numpy(
+        np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(),
+                      dtype=np.int32).astype(np.int64)).to(scale.device)
+    return scale
```
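The new `scale_from_float_to_int64` helper does not convert scale values numerically; it reinterprets the fp32 bit patterns as int32 and widens them to int64, presumably the layout the fused kernel expects for its `scale1`/`scale2` inputs. A small self-contained round-trip check of that reinterpretation (illustrative values, not from the PR):

```python
import numpy as np
import torch

scale = torch.tensor([0.0125, 1.5, 3.0e-4], dtype=torch.float32)

# Same reinterpretation as scale_from_float_to_int64: fp32 bits -> int32 -> int64.
packed = torch.from_numpy(
    np.frombuffer(scale.numpy().tobytes(), dtype=np.int32).astype(np.int64))

# The original scales are recoverable by narrowing back to int32 and viewing
# the bits as float32, i.e. no numeric information is lost.
recovered = packed.to(torch.int32).view(torch.float32)
assert torch.equal(recovered, scale)
```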


```diff
@@ -911,6 +911,9 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
         "dp": {
             "hccl_buffer_size": calculate_dp_buffer_size()
         },
+        "ep": {
+            "hccl_buffer_size": calculate_ep_buffer_size()
+        },
     }
     return hccl_config_map.get(group_name, get_default_buffer_config())
@@ -932,6 +935,30 @@ def calculate_dp_buffer_size() -> int:
     return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
+
+
+def calculate_ep_buffer_size() -> int:
+    """
+    formula of ep buffer size:
+    batch_size * hidden_size * topk * 4
+    """
+    ep_buffer_size = _DEFAULT_BUFFER_SIZE
+    try:
+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        hf_config = vllm_config.model_config.hf_config
+        hidden_size = hf_config.hidden_size
+        topk = getattr(hf_config, "num_experts_per_token", 1)
+        batch_size = vllm_config.scheduler_config.max_num_batched_tokens
+        int8_size = torch.iinfo(torch.int8).bits // 8
+        bf16_size = torch.finfo(torch.bfloat16).bits // 8
+        ep_buffer_size = math.ceil(
+            (batch_size * hidden_size * topk *
+             (int8_size * 2 + bf16_size)) / (1024 * 1024))
+    except Exception:
+        pass
+    return max(ep_buffer_size, _DEFAULT_BUFFER_SIZE)
+
+
 # Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
 # and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
 # significantly improve communication performance of MC2 ops dispatch/combine.
```
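As a sanity check of the sizing formula, here is a worked example; the model and scheduler values below are illustrative assumptions, not taken from this PR:

```python
import math

# Assumed values for illustration only.
hidden_size = 7168   # hf_config.hidden_size
topk = 8             # hf_config.num_experts_per_token
batch_size = 16384   # scheduler_config.max_num_batched_tokens
int8_size, bf16_size = 1, 2  # bytes per element

# Mirrors (int8_size * 2 + bf16_size) = 4 bytes per routed token element.
ep_buffer_size = math.ceil(
    batch_size * hidden_size * topk * (int8_size * 2 + bf16_size) / (1024 * 1024))
print(ep_buffer_size)  # 3584, i.e. roughly a 3584 MB HCCL buffer for the "ep" group
```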


```diff
@@ -2217,8 +2217,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             return None
         soc_version = get_ascend_device_type()
-        quant_type = getattr(self.vllm_config.model_config.hf_config,
-                             'moe_quantize', None)
+        quant_type = getattr(
+            self.vllm_config.model_config.hf_config, 'moe_quantize',
+            getattr(self.vllm_config.model_config.hf_config, 'quantize', None))
         model_type = self.vllm_config.model_config.hf_config.model_type

         if not self.parallel_config.enable_expert_parallel:
@@ -2237,7 +2238,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         elif soc_version in {AscendDeviceType._910_93}:
             moe_comm_type = (MoECommType.MC2
                              if num_tokens <= self.mc2_tokens_capacity else
-                             MoECommType.ALLTOALL)
+                             MoECommType.FUSED_ALLTOALL if quant_type
+                             == "w8a8_dynamic" else MoECommType.ALLTOALL)
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")
```
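The nested conditional expression above is dense; a hypothetical standalone restatement of the selection logic added for the 910_93 SoC (the helper name and signature are illustrative, not part of the PR):

```python
from typing import Optional

from vllm_ascend.ascend_forward_context import MoECommType


def select_moe_comm_type_910_93(num_tokens: int, mc2_tokens_capacity: int,
                                quant_type: Optional[str]) -> MoECommType:
    # Small batches keep using MC2 dispatch/combine.
    if num_tokens <= mc2_tokens_capacity:
        return MoECommType.MC2
    # Larger batches take the new fused all-to-all path, but only when the
    # MoE weights are quantized as w8a8_dynamic; otherwise fall back.
    if quant_type == "w8a8_dynamic":
        return MoECommType.FUSED_ALLTOALL
    return MoECommType.ALLTOALL
```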