### What this PR does / why we need it?
* **Unify execution paths:** Consolidates the quantized and
non-quantized execution paths into a single `fused_experts` function,
removing duplicated logic and making the control flow clearer and easier
to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic
quantization inside the unified MoE kernel. Communication routines are
updated to correctly handle dynamic quantization scales for activations
(see the quantization sketch after this list).
* **Weight pre-processing:** Pre-transposes the `w13` and `w2` weight
matrices (as implemented in PR #2025) so that quantized and
non-quantized models follow the same code path for the MoE gating,
up-projection, and down-projection operations.
* **All-to-all communication:** Adds an `all-to-all` collective
communication pattern. For large token counts on modern hardware,
`all-to-all` is more efficient than the previous `all-gather` strategy.
However, `all-to-all` cannot be graph-captured and replayed: it performs
multiple D2H operations that force synchronization, which raises errors
during graph capture. We therefore use `all-to-all` only when falling
back to `compiled_graph_for_general_shape`.
* **Dynamic communication selection:** The model runner now selects the
optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at
runtime based on token count and the Ascend SoC version (see the
selection sketch after this list).
* **Limitation:** `all-gather` is not yet supported for quantized
models, so quantized-model support on A2 remains incomplete.
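
To make the activation-scale handling concrete, here is a minimal sketch of per-token dynamic W8A8 activation quantization in plain PyTorch. The function name and the symmetric int8 scheme are illustrative assumptions; the actual quantization runs inside the NPU kernels. The point is only that every token produces its own scale, which must travel with the token through the communication step:

```python
import torch

def dynamic_per_token_quant(x: torch.Tensor):
    # Hypothetical helper (not the kernel implementation): symmetric
    # per-token int8 quantization. Each token row gets its own scale
    # so that the row maps onto the int8 range [-127, 127].
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    # The per-token scales must be permuted and communicated alongside
    # the quantized activations, which is what this PR wires up.
    return x_q, scale
```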
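Below is a hedged sketch of the selection logic described above. The helper name, the token threshold, and the exact branch order are assumptions for illustration only; the real decision lives in the model runner and also consults the Ascend SoC version:

```python
def select_moe_comm_method(num_tokens: int,
                           capturing_graph: bool,
                           alltoall_token_threshold: int = 512) -> str:
    """Illustrative only: choose an MoE communication pattern.

    `alltoall_token_threshold` is an assumed placeholder; the real
    runner additionally checks the Ascend SoC version.
    """
    if capturing_graph:
        # all-to-all cannot be graph-captured (its D2H operations force
        # synchronization), so captured graphs use mc2 or all-gather.
        return "mc2" if num_tokens <= alltoall_token_threshold else "allgather"
    # General-shape (eager) path: all-to-all wins at large token counts.
    if num_tokens > alltoall_token_threshold:
        return "alltoall"
    return "mc2"
```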
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No further test cases needed. The unit test below compares
`AllGatherCommImpl` against a native reference implementation.
- vLLM version: v0.10.1.1
- vLLM main: d660c98c1b
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
```python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from types import SimpleNamespace

import pytest
import torch

from vllm.model_executor.layers.fused_moe.config import (  # isort: skip
    FusedMoEConfig, FusedMoEParallelConfig)

from vllm_ascend.distributed.moe_comm_method import (  # isort: skip
    AllGatherCommImpl, NativeAllGatherCommImpl)


@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("ep_rank", [0, 1])
@pytest.mark.parametrize("apply_a8_quantization", [False])
def test_all_gather_comm_impl(
    num_tokens,
    hidden_size,
    global_num_experts,
    num_local_experts,
    top_k_num,
    dtype,
    ep_rank,
    apply_a8_quantization,
    mocker,
):
    """
    Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.

    This test compares the outputs of the NPU-optimized AllGatherCommImpl
    with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure
    correctness across various configurations.
    """
    if top_k_num > global_num_experts:
        pytest.skip("top_k_num cannot be greater than global_num_experts")
    if num_local_experts > global_num_experts:
        pytest.skip(
            "num_local_experts cannot be greater than global_num_experts")

    device = torch.device("npu")

    # Mock get_tensor_model_parallel_rank to return ep_rank.
    mocker.patch(
        "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
        return_value=ep_rank,
    )

    # Build the MoE config.
    parallel_config = SimpleNamespace(
        enable_expert_parallel=num_local_experts < global_num_experts)
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=max(2, global_num_experts // num_local_experts),
        dp_size_=1,
        vllm_parallel_config=parallel_config,
    )

    moe_config = FusedMoEConfig(
        num_experts=global_num_experts,
        experts_per_token=top_k_num,
        hidden_dim=hidden_size,
        num_local_experts=num_local_experts,
        moe_parallel_config=moe_parallel_config,
        in_dtype=dtype,
        quant_config=None,  # No quantization in this test
        max_num_tokens=num_tokens,
    )

    # Instantiate implementations.
    native_impl = NativeAllGatherCommImpl(moe_config)
    all_gather_impl = AllGatherCommImpl(moe_config)

    # --- Input Data ---
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=dtype)
    topk_ids = torch.randint(0,
                             global_num_experts, (num_tokens, top_k_num),
                             device=device,
                             dtype=torch.int32)
    topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=1)

    num_experts = global_num_experts

    expert_map = None
    if num_local_experts < global_num_experts:
        # Create a map where some experts are local and some are not.
        expert_map = torch.full((global_num_experts, ), -1, device=device)
        expert_map[ep_rank * num_local_experts:(ep_rank + 1) *
                   num_local_experts] = torch.arange(num_local_experts,
                                                     device=device)
        num_experts = num_local_experts

    # --- Run Native Implementation (Golden Reference) ---
    native_hidden_states_out = hidden_states.clone()
    (
        native_permuted_hidden,
        native_expert_tokens,
        _,
        _,
    ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
                            num_experts, apply_a8_quantization)
    # Simulate MLP output.
    native_mlp_output = torch.randn_like(native_permuted_hidden)
    native_impl.unpermute(native_mlp_output, native_hidden_states_out)

    # --- Run AllGather Implementation ---
    all_gather_hidden_states_out = hidden_states.clone()
    (
        all_gather_permuted_hidden,
        all_gather_expert_tokens,
        _,
        _,
    ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
                                expert_map, num_experts, apply_a8_quantization)

    # Use the same simulated MLP output for a fair comparison.
    all_gather_mlp_output = native_mlp_output.clone()

    all_gather_impl.unpermute(all_gather_mlp_output,
                              all_gather_hidden_states_out)

    # --- Assertions ---
    # Define tolerance based on dtype.
    atol = 1e-3 if dtype == torch.float16 else 1e-2
    rtol = 1e-3 if dtype == torch.float16 else 1e-2

    # 1. Compare expert_tokens from permute.
    assert torch.allclose(native_expert_tokens.to(
        all_gather_expert_tokens.device),
                          all_gather_expert_tokens,
                          atol=atol,
                          rtol=rtol), "Expert tokens do not match."

    # 2. Compare permuted_hidden_states from permute.
    num_valid_tokens = native_expert_tokens.sum()
    assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to(
        all_gather_permuted_hidden.device),
                          all_gather_permuted_hidden[:num_valid_tokens],
                          atol=atol,
                          rtol=rtol), "Permuted hidden states do not match."

    # 3. Compare final hidden_states from unpermute.
    assert torch.allclose(native_hidden_states_out.to(
        all_gather_hidden_states_out.device),
                          all_gather_hidden_states_out,
                          atol=atol,
                          rtol=rtol), "Final hidden states do not match."
```
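
As a quick illustration of the `expert_map` layout built in the test above, here is a standalone snippet with assumed concrete values. `-1` marks experts that are not resident on this EP rank, and local experts are renumbered from zero:

```python
import torch

# Assumed example values: EP rank 1 holds 4 of 8 global experts.
ep_rank, num_local_experts, global_num_experts = 1, 4, 8

expert_map = torch.full((global_num_experts, ), -1)
expert_map[ep_rank * num_local_experts:(ep_rank + 1) *
           num_local_experts] = torch.arange(num_local_experts)

print(expert_map)  # tensor([-1, -1, -1, -1,  0,  1,  2,  3])
```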