[refactor] Remove unnecessary attributes from set_ascend_forward_context (#5204)

### What this PR does / why we need it?
Remove unnecessary attributes from set_ascend_forward_context
1.prefetch_stream
2.weight_prefetch_method
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2025-12-23 08:49:52 +08:00
committed by GitHub
parent 95e8a52156
commit c3a8d13ca7
10 changed files with 55 additions and 83 deletions

View File

@@ -92,15 +92,14 @@ def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method.finalize.side_effect = mock_finalize
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
mock_weight_prefetch_method = MagicMock()
mock_forward_context_obj = MagicMock(
moe_comm_method=mock_moe_comm_method,
moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10,
dp_metadata=dp_metadata,
mc2_mask=torch.zeros(16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False,
weight_prefetch_method=mock_weight_prefetch_method)
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10,
dp_metadata=dp_metadata,
mc2_mask=torch.zeros(
16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
@@ -133,8 +132,8 @@ def mock_dist_env(mocker: MockerFixture):
return_value=None), \
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=mock_forward_context_obj):
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=mock_weight_prefetch_method):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
@@ -590,4 +589,4 @@ class TestUnifiedApplyMLP(TestBase):
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(result.shape, hidden_states_shape)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.dtype, torch.bfloat16)