Remove unused row_idx in token_dispatcher (#3442)
### What this PR does / why we need it? The `row_idx` parameter is no longer used since PR[#2689](https://github.com/vllm-project/vllm-ascend/pull/2689), so remove it across multiple files to remove unnecessary calculations and parameter passing. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? accuracy test passed for Qwen3 235B and DeepSeek V3 671B after this PR. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: CaranLic <740821011@qq.com>
This commit is contained in:
@@ -58,7 +58,6 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
|
||||
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
|
||||
def tearDown(self):
|
||||
self.mc2_group_patch.stop()
|
||||
@@ -96,7 +95,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
(None, None)) as mock_dispatch:
|
||||
output = self.dispatcher.token_dispatch(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, expert_map)
|
||||
expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output["group_list_type"],
|
||||
0) # group_list_type == 0
|
||||
@@ -117,7 +116,6 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.dispatcher.token_dispatch(self.hidden_states,
|
||||
self.topk_weights,
|
||||
torch.randint(0, 8, (10, 1)),
|
||||
self.row_idx,
|
||||
torch.tensor(
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
shared_experts=self.shared_experts)
|
||||
@@ -181,7 +179,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]))
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
self.patcher_npu_moe_token_unpermute = patch(
|
||||
'torch_npu.npu_moe_token_unpermute')
|
||||
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
|
||||
@@ -198,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||
topk_ids, self.row_idx, None)
|
||||
topk_ids, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
@@ -213,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||
topk_ids, self.row_idx, None)
|
||||
topk_ids, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
@@ -237,7 +234,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
|
||||
results = self.dispatcher_quant.token_dispatch(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, None)
|
||||
None)
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
@@ -258,7 +255,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
results = self.dispatcher_quant.token_dispatch(hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
self.row_idx,
|
||||
None,
|
||||
with_quant=True)
|
||||
|
||||
@@ -401,7 +397,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
with_quant=False)
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
|
||||
def test_token_dispatch(self):
|
||||
hidden_states = torch.randn(8, 16)
|
||||
@@ -416,7 +411,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
@@ -463,7 +457,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map,
|
||||
with_quant=True)
|
||||
|
||||
@@ -492,7 +485,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map,
|
||||
with_quant=True)
|
||||
|
||||
@@ -515,7 +507,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user