Reapply "[MoE] [Refactor] Remove manual memory cleanup (#3365)" (#3483) (#3512)

### What this PR does / why we need it?
1. Replace manual memory cleanup with passing parameter.
2. FusedMoEPrepareAndFinalizeWithMC2 inherits All2All avoid duplicated
code.
3. Fix MC2 bug introduced in
https://github.com/vllm-project/vllm-ascend/pull/3365
4. Unify aclgraph & eager in W8A8_dynamic.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-10-22 11:41:30 +08:00
committed by GitHub
parent 6ef62cb427
commit 2f1b9a7a64
13 changed files with 608 additions and 522 deletions

View File

@@ -77,9 +77,10 @@ class TestTokenDispatcherWithMC2(TestBase):
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map)
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8)
@@ -123,36 +124,68 @@ class TestTokenDispatcherWithMC2(TestBase):
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None
assist_info_for_combine = torch.arange(10)
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": mc2_mask,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
"tp_recv_counts": tp_recv_counts
}
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.output = torch.randint(0, 8, (10, 1))
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
context_metadata)
self.assertIn("tp_send_counts", kwargs)
def test_token_combine_with_shared_experts(self):
self.dispatcher.shared_experts = MagicMock()
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
10, 128), torch.tensor(1.0))
self.dispatcher.shared_act = torch.randn(10, 128)
shared_experts = MagicMock()
shared_experts.down_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
assist_info_for_combine = torch.arange(10)
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": None,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
"shared_experts": shared_experts,
"shared_act": torch.randn(10, 128),
"swiglu_out_scale": torch.randn(10, 1),
"tp_recv_counts": tp_recv_counts
}
self.dispatcher.with_quant = True
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
self.dispatcher.output = torch.randint(0, 8, (10, 1))
self.hidden_states = torch.randn(10, 128)
hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
self.dispatcher.token_combine(self.hidden_states)
result = self.dispatcher.token_combine(hidden_states,
context_metadata)
self.assertIsInstance(result, tuple)
class TestTokenDispatcherWithAllGather(TestBase):
@@ -264,35 +297,26 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.assertEqual(results["group_list_type"], 1)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
context_metadata = {
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_combine_without_expert_map(self):
self.dispatcher.with_quant = False
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify npu_moe_finalize_routing is called
context_metadata = {
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
self.mock_npu_moe_token_unpermute.assert_called_once()
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_dispatch_with_router_weight(self):
@@ -418,25 +442,21 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.assertEqual(result["group_list_type"], 1)
def test_token_combine(self):
hidden_states = torch.randn(16, 16)
context_metadata = {
"input_splits": [4, 4],
"output_splits": [4, 4],
"topk_weights": torch.rand(8, 4),
"reversed_local_input_permutation_mapping": torch.arange(8),
"reversed_global_input_permutation_mapping": torch.arange(16),
}
self.dispatcher.hidden_shape = (8, 16)
self.dispatcher.hidden_shape_before_permute = (8, 16)
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
8)
self.dispatcher.topk_weights = torch.rand(8, 4)
self.dispatcher.input_splits = [4, 4]
self.dispatcher.output_splits = [4, 4]
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
16)
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
[[2, 2], [2, 2]], dtype=torch.int64)
expert_output = torch.randn(16, 16)
output = self.dispatcher.token_combine(expert_output)
output = self.dispatcher.token_combine(hidden_states, context_metadata)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))