Remove unused row_idx in token_dispatcher (#3442)
### What this PR does / why we need it? The `row_idx` parameter is no longer used since PR[#2689](https://github.com/vllm-project/vllm-ascend/pull/2689), so remove it across multiple files to remove unnecessary calculations and parameter passing. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? accuracy test passed for Qwen3 235B and DeepSeek V3 671B after this PR. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: CaranLic <740821011@qq.com>
This commit is contained in:
@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
row_idx = (torch.arange(
|
||||
0,
|
||||
m * topk,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
).view(topk, -1).permute(1, 0).contiguous())
|
||||
|
||||
dispatcher_kwargs = {
|
||||
"num_experts": e,
|
||||
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
row_idx = (torch.arange(
|
||||
0,
|
||||
m * topk,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
).view(topk, -1).permute(1, 0).contiguous())
|
||||
|
||||
dispatcher_kwargs = {
|
||||
"num_experts": e,
|
||||
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=True)
|
||||
@@ -297,7 +283,7 @@ def test_select_experts(
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||
x)
|
||||
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=topk,
|
||||
@@ -318,7 +304,6 @@ def test_select_experts(
|
||||
assert topk_weights.shape == (m, topk)
|
||||
assert topk_ids.shape == (m, topk)
|
||||
assert topk_ids.dtype == torch.int32
|
||||
assert row_idx.shape == (m, topk)
|
||||
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user