This reverts commit 4f937f561d. ### What this PR does / why we need it? This reverts commit 4f937f561d. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? e2e & ut - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -137,7 +137,6 @@ def test_token_dispatcher_with_all_gather(
|
||||
sorted_hidden_states = dispatch_output["hidden_states"]
|
||||
group_list = dispatch_output["group_list"]
|
||||
group_list_type = dispatch_output.get("group_list_type", 1)
|
||||
context_metadata = dispatch_output["context_metadata"]
|
||||
|
||||
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1_local,
|
||||
@@ -145,10 +144,8 @@ def test_token_dispatcher_with_all_gather(
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
combined_output = dispatcher.token_combine(
|
||||
hidden_states=expert_output,
|
||||
context_metadata=context_metadata,
|
||||
bias=None)
|
||||
combined_output = dispatcher.token_combine(hidden_states=expert_output,
|
||||
bias=None)
|
||||
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
|
||||
expert_map)
|
||||
@@ -218,7 +215,6 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
group_list = dispatch_output["group_list"]
|
||||
group_list_type = dispatch_output.get("group_list_type", 1)
|
||||
dynamic_scale = dispatch_output["dynamic_scale"]
|
||||
context_metadata = dispatch_output["context_metadata"]
|
||||
|
||||
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1,
|
||||
@@ -229,10 +225,8 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
group_list_type=group_list_type,
|
||||
dynamic_scale=dynamic_scale,
|
||||
with_quant=True)
|
||||
combined_output = dispatcher.token_combine(
|
||||
hidden_states=expert_output,
|
||||
context_metadata=context_metadata,
|
||||
bias=None)
|
||||
combined_output = dispatcher.token_combine(hidden_states=expert_output,
|
||||
bias=None)
|
||||
assert combined_output.shape == (m, k)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
||||
@@ -44,8 +44,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Check padding and split
|
||||
self.assertEqual(h_out.shape[0], 4)
|
||||
@@ -53,9 +52,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(mask.tolist(), [1, 0, 1])
|
||||
|
||||
# Finalize
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
@@ -80,11 +77,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(4, 8)
|
||||
router_logits = torch.randn(4, 2)
|
||||
|
||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
h_out, r_out, mask = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# With TP=2, should split into 2 parts
|
||||
self.assertEqual(h_out.shape[0], 2)
|
||||
@@ -100,9 +96,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back to original size
|
||||
self.assertEqual(final_result.shape[0], 4)
|
||||
@@ -118,15 +112,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, _, context_metadata = layer.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Pad to tp_size=1, so no change
|
||||
self.assertEqual(h_out.shape[0], 3)
|
||||
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
@@ -142,11 +133,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(2, 8)
|
||||
router_logits = torch.randn(2, 2)
|
||||
|
||||
h_out, r_out, _, context_metadata = layer.prepare(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# Split due to TP=2
|
||||
self.assertEqual(h_out.shape[0], 1)
|
||||
@@ -162,9 +152,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
@@ -207,9 +195,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (router_logits.repeat(2, 1), None)
|
||||
|
||||
h_out, r_out, _, context_metadata = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
|
||||
# After all-gather with DP=2, should double the batch size
|
||||
self.assertEqual(h_out.shape[0], 12)
|
||||
@@ -221,9 +209,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
return tensor[:3]
|
||||
|
||||
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@@ -277,9 +263,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
||||
|
||||
# Run prepare
|
||||
h_out, r_out, _, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
self.assertEqual(h_out.shape, (7, 8))
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestMoECommMethod(TestBase):
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2), None, None)
|
||||
torch.randn(4, 2), None)
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
@@ -59,18 +59,15 @@ class TestMoECommMethod(TestBase):
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out,
|
||||
reduce_results=True,
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@@ -93,8 +90,7 @@ class TestMoECommMethod(TestBase):
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2),
|
||||
torch.tensor([1, 0, 1,
|
||||
0]), None)
|
||||
torch.tensor([1, 0, 1, 0]))
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
@@ -108,18 +104,15 @@ class TestMoECommMethod(TestBase):
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out,
|
||||
reduce_results=True,
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@@ -142,7 +135,7 @@ class TestMoECommMethod(TestBase):
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2), None, None)
|
||||
torch.randn(4, 2), None)
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
@@ -156,8 +149,7 @@ class TestMoECommMethod(TestBase):
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
|
||||
@@ -77,10 +77,9 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
mc2_mask = None
|
||||
|
||||
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
|
||||
hidden_states, topk_weights, topk_ids, expert_map)
|
||||
self.assertIn("x", kwargs)
|
||||
self.assertIn("expert_ids", kwargs)
|
||||
self.assertEqual(kwargs["moe_expert_num"], 8)
|
||||
@@ -124,64 +123,36 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
def test_get_combine_mc_kwargs_with_quant(self):
|
||||
self.dispatcher.with_quant = True
|
||||
hidden_states = torch.randn(10, 128)
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1) # 注意:应为 float,不是 int
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
mc2_mask = None
|
||||
assist_info_for_combine = torch.arange(10) # mock 值
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"mc2_mask": mc2_mask,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"expand_scales": None,
|
||||
}
|
||||
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
|
||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
|
||||
context_metadata)
|
||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
|
||||
self.assertIn("tp_send_counts", kwargs)
|
||||
|
||||
def test_token_combine_with_shared_experts(self):
|
||||
shared_experts = MagicMock()
|
||||
shared_experts.down_proj.return_value = (torch.randn(10, 128),
|
||||
torch.tensor(1.0))
|
||||
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
assist_info_for_combine = torch.arange(10)
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"mc2_mask": None,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"expand_scales": None,
|
||||
"shared_experts": shared_experts,
|
||||
"shared_act": torch.randn(10, 128),
|
||||
"swiglu_out_scale": torch.randn(10, 1),
|
||||
}
|
||||
|
||||
self.dispatcher.shared_experts = MagicMock()
|
||||
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
|
||||
10, 128), torch.tensor(1.0))
|
||||
self.dispatcher.shared_act = torch.randn(10, 128)
|
||||
self.dispatcher.with_quant = True
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
|
||||
hidden_states = torch.randn(10, 128)
|
||||
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
result = self.dispatcher.token_combine(hidden_states,
|
||||
context_metadata)
|
||||
self.assertIsInstance(result, tuple)
|
||||
self.dispatcher.token_combine(self.hidden_states)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(TestBase):
|
||||
@@ -293,26 +264,35 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||
self.dispatcher.original_shape = (3, 128)
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
context_metadata = {
|
||||
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
|
||||
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
||||
}
|
||||
self.dispatcher.original_shape = (6, 128)
|
||||
final_hidden_states = self.dispatcher.token_combine(
|
||||
hidden_states, context_metadata)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_combine_without_expert_map(self):
|
||||
self.dispatcher.with_quant = False
|
||||
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||
self.dispatcher.original_shape = (3, 128)
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
context_metadata = {
|
||||
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
|
||||
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
||||
}
|
||||
self.dispatcher.original_shape = (6, 128)
|
||||
final_hidden_states = self.dispatcher.token_combine(
|
||||
hidden_states, context_metadata)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify npu_moe_finalize_routing is called
|
||||
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
|
||||
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_dispatch_with_router_weight(self):
|
||||
@@ -438,21 +418,25 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
def test_token_combine(self):
|
||||
hidden_states = torch.randn(16, 16)
|
||||
context_metadata = {
|
||||
"input_splits": [4, 4],
|
||||
"output_splits": [4, 4],
|
||||
"topk_weights": torch.rand(8, 4),
|
||||
"reversed_local_input_permutation_mapping": torch.arange(8),
|
||||
"reversed_global_input_permutation_mapping": torch.arange(16),
|
||||
}
|
||||
self.dispatcher.hidden_shape = (8, 16)
|
||||
self.dispatcher.hidden_shape_before_permute = (8, 16)
|
||||
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
|
||||
8)
|
||||
self.dispatcher.topk_weights = torch.rand(8, 4)
|
||||
self.dispatcher.input_splits = [4, 4]
|
||||
self.dispatcher.output_splits = [4, 4]
|
||||
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
|
||||
16)
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
|
||||
[[2, 2], [2, 2]], dtype=torch.int64)
|
||||
|
||||
expert_output = torch.randn(16, 16)
|
||||
output = self.dispatcher.token_combine(expert_output)
|
||||
|
||||
output = self.dispatcher.token_combine(hidden_states, context_metadata)
|
||||
self.assertIsNotNone(output)
|
||||
self.assertEqual(output.shape, (8, 16))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user