Remove unused row_idx in token_dispatcher (#3442)

### What this PR does / why we need it?
The `row_idx` parameter is no longer used since
PR[#2689](https://github.com/vllm-project/vllm-ascend/pull/2689), so
remove it across multiple files to remove unnecessary calculations and
parameter passing.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
accuracy test passed for Qwen3 235B and DeepSeek V3 671B after this PR.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CaranLic <740821011@qq.com>
This commit is contained in:
CaranLic
2025-10-15 09:08:31 +08:00
committed by GitHub
parent 3642b64afc
commit 15b2e5c995
11 changed files with 37 additions and 88 deletions

View File

@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)
@@ -297,7 +283,7 @@ def test_select_experts(
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
@@ -318,7 +304,6 @@ def test_select_experts(
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)
gc.collect()
torch.npu.empty_cache()