[3/N][Feat][Graph] Support all-to-all and quantized models with ACL Graph (#2614)
### What this PR does / why we need it?
* **Unify execution paths:** Consolidates the quantized and
non-quantized execution paths into a single `fused_experts` function,
removing duplicated logic and making the control flow clearer and easier
to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic
quantization inside the unified MoE kernel. Communication routines are
updated to correctly handle dynamic quantization scales for activations.
* **Weight pre-processing:** Prae-transpose the `w13` and `w2` weight
matrices (as implemented in PR #2025) so that quantized and
non-quantized models follow the same code path for the MoE gating,
up-projection, and down-projection operations.
* **All-to-all communication:** Adds an `all-to-all` collective
communication pattern. For large token counts on modern hardware,
`all-to-all` is more efficient than the previous `all-gather` strategy.
However, `all-to-all` is not really captured and replayed due to
multiple D2H operations which will trigger synchronization, and thus
raise error when capture graphs. We only use `all-to-all` when fallback
to `compiled_graph_for_general_shape`.
* **Dynamic communication selection:** The model runner now selects the
optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at
runtime based on token count and the Ascend SoC version.
* **Limitation:** `all-gather` is not yet supported for quantized
models, which means there is still something left to do on A2.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No further test cases needed.
- vLLM version: v0.10.1.1
- vLLM main:
d660c98c1b
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from vllm_ascend.distributed.moe_comm_method import ( # isort: skip
|
||||
@pytest.mark.parametrize("top_k_num", [2, 4])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("ep_rank", [0, 1])
|
||||
@pytest.mark.parametrize("apply_a8_quantization", [False])
|
||||
def test_all_gather_comm_impl(
|
||||
num_tokens,
|
||||
hidden_size,
|
||||
@@ -41,6 +42,7 @@ def test_all_gather_comm_impl(
|
||||
top_k_num,
|
||||
dtype,
|
||||
ep_rank,
|
||||
apply_a8_quantization,
|
||||
mocker,
|
||||
):
|
||||
"""
|
||||
@@ -118,8 +120,9 @@ def test_all_gather_comm_impl(
|
||||
native_permuted_hidden,
|
||||
native_expert_tokens,
|
||||
_,
|
||||
_,
|
||||
) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
|
||||
num_experts)
|
||||
num_experts, apply_a8_quantization)
|
||||
# Simulate MLP output
|
||||
native_mlp_output = torch.randn_like(native_permuted_hidden)
|
||||
native_impl.unpermute(native_mlp_output, native_hidden_states_out)
|
||||
@@ -130,8 +133,9 @@ def test_all_gather_comm_impl(
|
||||
all_gather_permuted_hidden,
|
||||
all_gather_expert_tokens,
|
||||
_,
|
||||
_,
|
||||
) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
|
||||
expert_map, num_experts)
|
||||
expert_map, num_experts, apply_a8_quantization)
|
||||
|
||||
# Use the same simulated MLP output for a fair comparison
|
||||
all_gather_mlp_output = native_mlp_output.clone()
|
||||
|
||||
Reference in New Issue
Block a user