add dispatch_gmm_combine kernel (#3532)
### What this PR does / why we need it? This PR introduces the Ascend implementation of the `dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime, together with follow‑up fixes to ensure the kernel builds and runs correctly in CI. - Add full host and device implementation of the `dispatch_ffn_combine` kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MOE routing helpers, and kernel utilities for quantized FFN dispatch. - Integrate the new kernel with the PyTorch binding (csrc/torch_binding.cpp, csrc/torch_binding_meta.cpp) and the Ascend runtime (vllm_ascend/ascend_forward_context.py, vllm_ascend/worker/model_runner_v1.py). - Extend fused MoE communication and token dispatch support in `vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new dispatch path. - Update quantization logic in vllm_ascend/quantization/w8a8_dynamic.py to support the new FFN dispatch flow. - Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake configuration, and include/namespace usage in the new kernel files. - Add an end‑to‑end nightly test `tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper utilities in `vllm_ascend/utils.py` to validate the new kernel. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0 --------- Signed-off-by: mojave2 <chenchen145@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -52,6 +52,7 @@ class MoECommType(Enum):
|
||||
ALLGATHER = 0
|
||||
MC2 = 1
|
||||
ALLTOALL = 2
|
||||
FUSED_ALLTOALL = 3
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -520,7 +520,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
|
||||
and not shared_expert_dp_enabled():
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
return shared_out, fused_output
|
||||
|
||||
@@ -44,6 +44,8 @@ def setup_moe_comm_method(moe_config):
|
||||
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
|
||||
moe_config)
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
|
||||
@@ -243,3 +245,69 @@ class AlltoAllCommImpl(MoECommMethod):
|
||||
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
|
||||
|
||||
class FusedAlltoAllCommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_grouped_matmul` is available.
|
||||
|
||||
This implementation uses all-to-all communication to exchange tokens
|
||||
between data parallel ranks before and after the MLP computation. It should
|
||||
have better performance than AllGatherCommImpl when DP size > 1.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithAll2AllV(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
out = torch.empty_like(hidden_states)
|
||||
|
||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||
x=hidden_states,
|
||||
weight1=w1,
|
||||
weight2=w2,
|
||||
expert_idx=topk_ids,
|
||||
scale1=w1_scale,
|
||||
scale2=w2_scale,
|
||||
probs=topk_weights.to(torch.float32),
|
||||
group=self.token_dispatcher.moe_all_to_all_group_name,
|
||||
max_output_size=65536,
|
||||
out=out,
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -513,6 +513,11 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=self.ep_group)
|
||||
backend = self.ep_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
|
||||
@@ -249,8 +249,9 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
|
||||
} or forward_context.sp_enabled:
|
||||
if moe_comm_type in {
|
||||
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
|
||||
} or forward_context.sp_enabled:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
@@ -232,13 +233,15 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.w2_weight_scale]
|
||||
|
||||
fused_flag = get_forward_context(
|
||||
).moe_comm_type == MoECommType.FUSED_ALLTOALL
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w1=w1[0] if fused_flag else w1,
|
||||
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
|
||||
w2=w2[0] if fused_flag else w2,
|
||||
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
@@ -270,6 +273,12 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||
layer.w2_weight_offset.data.shape[0], -1)
|
||||
|
||||
layer.fused_w1_scale = scale_from_float_to_int64(
|
||||
layer.w13_weight_scale.data)
|
||||
layer.fused_w2_scale = scale_from_float_to_int64(
|
||||
layer.w2_weight_scale.data)
|
||||
|
||||
if self.dynamic_eplb:
|
||||
layer.w13_weight_list = [
|
||||
weight.clone()
|
||||
@@ -292,3 +301,11 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
del layer.w13_weight_scale_fp32
|
||||
del layer.w2_weight_scale
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def scale_from_float_to_int64(scale):
|
||||
import numpy as np
|
||||
scale = torch.from_numpy(
|
||||
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(),
|
||||
dtype=np.int32).astype(np.int64)).to(scale.device)
|
||||
return scale
|
||||
|
||||
@@ -911,6 +911,9 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
|
||||
"dp": {
|
||||
"hccl_buffer_size": calculate_dp_buffer_size()
|
||||
},
|
||||
"ep": {
|
||||
"hccl_buffer_size": calculate_ep_buffer_size()
|
||||
},
|
||||
}
|
||||
return hccl_config_map.get(group_name, get_default_buffer_config())
|
||||
|
||||
@@ -932,6 +935,30 @@ def calculate_dp_buffer_size() -> int:
|
||||
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
|
||||
|
||||
|
||||
def calculate_ep_buffer_size() -> int:
|
||||
"""
|
||||
formula of ep buffer size:
|
||||
batch_size * hidden_size * topk * 4
|
||||
"""
|
||||
ep_buffer_size = _DEFAULT_BUFFER_SIZE
|
||||
try:
|
||||
from vllm.config import get_current_vllm_config
|
||||
vllm_config = get_current_vllm_config()
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
|
||||
hidden_size = hf_config.hidden_size
|
||||
topk = getattr(hf_config, "num_experts_per_token", 1)
|
||||
batch_size = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
int8_size = torch.iinfo(torch.int8).bits // 8
|
||||
bf16_size = torch.finfo(torch.bfloat16).bits // 8
|
||||
ep_buffer_size = math.ceil(
|
||||
(batch_size * hidden_size * topk *
|
||||
(int8_size * 2 + bf16_size)) / (1024 * 1024))
|
||||
except Exception:
|
||||
pass
|
||||
return max(ep_buffer_size, _DEFAULT_BUFFER_SIZE)
|
||||
|
||||
|
||||
# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
|
||||
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
|
||||
# significantly improve communication performance of MC2 ops dispatch/combine.
|
||||
|
||||
@@ -2217,8 +2217,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
return None
|
||||
|
||||
soc_version = get_ascend_device_type()
|
||||
quant_type = getattr(self.vllm_config.model_config.hf_config,
|
||||
'moe_quantize', None)
|
||||
quant_type = getattr(
|
||||
self.vllm_config.model_config.hf_config, 'moe_quantize',
|
||||
getattr(self.vllm_config.model_config.hf_config, 'quantize', None))
|
||||
model_type = self.vllm_config.model_config.hf_config.model_type
|
||||
|
||||
if not self.parallel_config.enable_expert_parallel:
|
||||
@@ -2237,7 +2238,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
elif soc_version in {AscendDeviceType._910_93}:
|
||||
moe_comm_type = (MoECommType.MC2
|
||||
if num_tokens <= self.mc2_tokens_capacity else
|
||||
MoECommType.ALLTOALL)
|
||||
MoECommType.FUSED_ALLTOALL if quant_type
|
||||
== "w8a8_dynamic" else MoECommType.ALLTOALL)
|
||||
else:
|
||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user