[refactor] Remove unnecessary attributes from set_ascend_forward_context (#5204)
### What this PR does / why we need it?
Remove unnecessary attributes from set_ascend_forward_context
1.prefetch_stream
2.weight_prefetch_method
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -92,15 +92,14 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
||||
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
||||
mock_weight_prefetch_method = MagicMock()
|
||||
mock_forward_context_obj = MagicMock(
|
||||
moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_type=MoECommType.MC2,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=dp_metadata,
|
||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False,
|
||||
weight_prefetch_method=mock_weight_prefetch_method)
|
||||
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_type=MoECommType.MC2,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=dp_metadata,
|
||||
mc2_mask=torch.zeros(
|
||||
16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False)
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
@@ -133,8 +132,8 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
||||
return_value=mock_weight_prefetch_method):
|
||||
|
||||
yield {
|
||||
'mock_forward_context_obj': mock_forward_context_obj,
|
||||
@@ -590,4 +589,4 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(result.shape, hidden_states_shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
Reference in New Issue
Block a user