Remove unused row_idx in token_dispatcher (#3442)

### What this PR does / why we need it?
The `row_idx` parameter has been unused since
PR [#2689](https://github.com/vllm-project/vllm-ascend/pull/2689), so
this PR removes it across multiple files to eliminate unnecessary
calculations and parameter passing.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy tests passed for Qwen3 235B and DeepSeek V3 671B after this PR.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CaranLic <740821011@qq.com>
This commit is contained in:
CaranLic
2025-10-15 09:08:31 +08:00
committed by GitHub
parent 3642b64afc
commit 15b2e5c995
11 changed files with 37 additions and 88 deletions

View File

@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)
@@ -297,7 +283,7 @@ def test_select_experts(
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
@@ -318,7 +304,6 @@ def test_select_experts(
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)
gc.collect()
torch.npu.empty_cache()

View File

@@ -263,7 +263,7 @@ class TestExpertsSelector:
x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids, _ = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,

View File

@@ -204,7 +204,6 @@ class TestMoECommMethod(TestBase):
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
[0.6, 0.4]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
row_idx = torch.arange(4)
# Make sure tensors are contiguous and have correct strides
hidden_states = hidden_states.contiguous()
@@ -216,7 +215,6 @@ class TestMoECommMethod(TestBase):
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
activation="silu")
# Verify result shape

View File

@@ -58,7 +58,6 @@ class TestTokenDispatcherWithMC2(TestBase):
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.mc2_group_patch.stop()
@@ -96,7 +95,7 @@ class TestTokenDispatcherWithMC2(TestBase):
(None, None)) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, expert_map)
expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
0) # group_list_type == 0
@@ -117,7 +116,6 @@ class TestTokenDispatcherWithMC2(TestBase):
self.dispatcher.token_dispatch(self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
self.row_idx,
torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
@@ -181,7 +179,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
torch.tensor([0, 1, 0, 1, 0, 1]))
self.row_idx = torch.arange(10, dtype=torch.int32)
self.patcher_npu_moe_token_unpermute = patch(
'torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
@@ -198,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
topk_ids, None)
# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -213,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
topk_ids, None)
# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -237,7 +234,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, None)
None)
self.assertEqual(results["group_list_type"], 1)
@@ -258,7 +255,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights,
topk_ids,
self.row_idx,
None,
with_quant=True)
@@ -401,7 +397,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
num_experts=4,
num_local_experts=2,
with_quant=False)
self.row_idx = torch.arange(10, dtype=torch.int32)
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
@@ -416,7 +411,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -463,7 +457,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
@@ -492,7 +485,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
@@ -515,7 +507,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)

View File

@@ -777,12 +777,12 @@ class TestSelectExperts(TestBase):
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -790,12 +790,12 @@ class TestSelectExperts(TestBase):
def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -818,13 +818,13 @@ class TestSelectExperts(TestBase):
self.top_k,
dtype=torch.long))
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -838,7 +838,7 @@ class TestSelectExperts(TestBase):
self.num_experts)
e_score_correction_bias = torch.randn(self.num_experts)
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -861,7 +861,7 @@ class TestSelectExperts(TestBase):
self.top_k,
dtype=torch.int32))
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -888,7 +888,7 @@ class TestSelectExperts(TestBase):
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -914,7 +914,7 @@ class TestSelectExperts(TestBase):
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,