[refactor] Remove unnecessary attributes from set_ascend_forward_context (#5204)

### What this PR does / why we need it?
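Remove two unnecessary attributes from `set_ascend_forward_context`:
1. `prefetch_stream`
2. `weight_prefetch_method`

As the updated tests in this diff show, `experts_selector` is now mocked through `get_weight_prefetch_method` rather than through a `weight_prefetch_method` attribute on the object returned by `get_forward_context`.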
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2025-12-23 08:49:52 +08:00
committed by GitHub
parent 95e8a52156
commit c3a8d13ca7
10 changed files with 55 additions and 83 deletions

View File

@@ -286,8 +286,8 @@ def test_select_experts(
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk, \
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=MagicMock(weight_prefetch_method=MagicMock())):
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=MagicMock()):
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
@@ -323,8 +323,8 @@ def test_select_experts(
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=MagicMock(weight_prefetch_method=MagicMock())), \
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=MagicMock()), \
pytest.raises(ValueError,
match="Unsupported scoring function: invalid"):
select_experts(hidden_states=torch.randn(1, 128, device=device),
@@ -336,17 +336,3 @@ def test_select_experts_invalid_scoring_func(device: str):
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
with pytest.raises(AssertionError):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 64, device=device),
top_k=2,
use_grouped_topk=True,
renormalize=False,
scoring_func="softmax")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()