[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR adds aclgraph support for model runner v2; see RFC #5208 for
background. The PR contains the following modifications:
- Adapt to the newest commit of the vLLM main branch.
- Provide a unified interface for the extra forward context that is shared
by model runner v1 and model runner v2.
- Implement graph mode for the main model.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -74,7 +74,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
@patch("torch_npu._npu_flash_attention")
|
||||
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
||||
def test_forward_prefill_310(
|
||||
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
|
||||
):
|
||||
@@ -105,7 +105,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
@patch("torch_npu._npu_paged_attention_splitfuse")
|
||||
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
||||
def test_forward_chunked_prefill_310(
|
||||
self,
|
||||
mock_get_forward_context,
|
||||
@@ -140,7 +140,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
@patch("torch_npu._npu_paged_attention_splitfuse")
|
||||
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
||||
def test_forward_prefill_cache_hit_310(
|
||||
self,
|
||||
mock_get_forward_context,
|
||||
@@ -175,7 +175,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
@patch("vllm_ascend.attention.attention_v1.using_paged_attention")
|
||||
@patch("torch_npu._npu_paged_attention")
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
||||
def test_forward_paged_attention_310(
|
||||
self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention
|
||||
):
|
||||
|
||||
@@ -95,7 +95,7 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
@patch('torch_npu.npu_attention_update')
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
@patch(
|
||||
'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
|
||||
'vllm_ascend.ascend_forward_context.get_forward_context'
|
||||
)
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2)
|
||||
def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
|
||||
|
||||
@@ -212,7 +212,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
def test_forward_fused_infer_attention(
|
||||
self, mock_get_forward_context,
|
||||
mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
|
||||
@@ -248,7 +248,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
@patch('vllm_ascend.attention.attention_v1.using_paged_attention')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
def test_forward_paged_attention(self, mock_get_forward_context,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_paged_attention,
|
||||
@@ -279,7 +279,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (4, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
|
||||
@@ -311,7 +311,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_fused_infer_attention_score.assert_called_once()
|
||||
assert output.shape == (10, 8, 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
|
||||
@@ -449,7 +449,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result.shape[1], N)
|
||||
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
|
||||
|
||||
@patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
@patch('torch_npu.npu_attention_update')
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
|
||||
|
||||
@@ -929,7 +929,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode_without_graph(self,
|
||||
@@ -1095,7 +1095,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
|
||||
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
||||
|
||||
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
|
||||
mock_get_forward_context):
|
||||
|
||||
@@ -161,12 +161,13 @@ class TestACLGraphWrapper(TestBase):
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.NONE)
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_call_with_none_runtime_mode(self, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context, mock_get_forward_context_2):
|
||||
"""Test __call__ method when runtime mode is NONE"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
@@ -185,16 +186,19 @@ class TestACLGraphWrapper(TestBase):
|
||||
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
||||
self.assertEqual(result, "test_output")
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_call_with_mismatched_runtime_mode(self, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_forward_context_2):
|
||||
"""Test __call__ method when runtime mode doesn't match wrapper mode"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
@@ -214,18 +218,20 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_capture_graph_first_time(
|
||||
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph for the first time"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
@@ -284,6 +290,7 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@@ -291,12 +298,15 @@ class TestACLGraphWrapper(TestBase):
|
||||
def test_call_replay_graph(self, mock_weak_ref_tensors,
|
||||
mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled,
|
||||
mock_torch):
|
||||
"""Test __call__ method replays graph when already captured"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
self.mock_forward_context.is_draft_model = False
|
||||
|
||||
@@ -358,17 +368,19 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_with_debug_mode_input_address_check(
|
||||
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||
mock_get_forward_context,
|
||||
mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method with debug mode input address checking"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
self.mock_forward_context.is_draft_model = False
|
||||
|
||||
@@ -413,17 +425,19 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_with_debug_mode_input_address_mismatch(
|
||||
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||
mock_get_forward_context,
|
||||
mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method with debug mode input address mismatch raises AssertionError"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
@@ -471,6 +485,7 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@@ -478,12 +493,13 @@ class TestACLGraphWrapper(TestBase):
|
||||
@patch('vllm_ascend.compilation.acl_graph.patch')
|
||||
def test_call_capture_graph_with_gc_disable(
|
||||
self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
|
||||
mock_envs, mock_current_platform, mock_get_forward_context,
|
||||
mock_envs, mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph with gc_disable option enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable gc_disable option
|
||||
@@ -545,18 +561,20 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_capture_graph_with_weak_ref_output(
|
||||
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph with weak_ref_output option enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable weak_ref_output option
|
||||
@@ -608,18 +626,20 @@ class TestACLGraphWrapper(TestBase):
|
||||
|
||||
# Should return the weak ref output when weak_ref_output option is enabled
|
||||
self.assertEqual(result, "weak_ref_output")
|
||||
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.logger')
|
||||
def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,mock_get_forward_context_2):
|
||||
"""Test __call__ method captures graph with debug logging enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable debug logging
|
||||
@@ -757,10 +777,11 @@ class TestPCPDCPGraphParams(TestBase):
|
||||
self.graph_params.events[4].append(mock_event)
|
||||
self.graph_params.handles[4].append(MagicMock())
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('torch.npu.graph_task_update_end', )
|
||||
@patch('torch.npu.graph_task_update_begin', MagicMock())
|
||||
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
|
||||
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
|
||||
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context):
|
||||
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
block_table = torch.zeros(2, 5, dtype=torch.long)
|
||||
seq_lens = torch.tensor([4, 4])
|
||||
@@ -790,6 +811,7 @@ class TestPCPDCPGraphParams(TestBase):
|
||||
forward_context = MagicMock()
|
||||
forward_context.attn_metadata = {"attn_layer_0": metadata}
|
||||
forward_context.is_draft_model = False
|
||||
mock_context.return_value = forward_context
|
||||
|
||||
num_heads = 256
|
||||
scale = 0.1
|
||||
|
||||
@@ -119,11 +119,9 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
|
||||
patch('vllm_ascend.ascend_forward_context.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3), \
|
||||
patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
@@ -298,7 +296,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
|
||||
return_value=MagicMock())
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
||||
return_value=AscendDeviceType.A3)
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@@ -407,7 +405,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
|
||||
return_value=MagicMock())
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -513,7 +511,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method",
|
||||
return_value=MagicMock())
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
|
||||
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
|
||||
|
||||
@@ -121,9 +121,10 @@ class TestAscendMultiHeadLatentAttention(TestBase):
|
||||
@patch("vllm_ascend.ops.mla.get_ascend_config")
|
||||
@patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size")
|
||||
@patch("vllm_ascend.ops.mla.get_forward_context")
|
||||
def test_forward(self, mock_get_forward_context, mock_tp_size,
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
def test_forward(self, mock_get_forward_context_2, mock_get_forward_context, mock_tp_size,
|
||||
mock_ascend_config, mock_get_vllm_config,
|
||||
mock_mla_forward):
|
||||
mock_mla_forward,):
|
||||
mock_tp_size.return_value = 1
|
||||
mock_ascend_config.return_value.enable_shared_expert_dp = False
|
||||
mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||
@@ -159,6 +160,7 @@ class TestAscendMultiHeadLatentAttention(TestBase):
|
||||
mock_forward_context = MagicMock(spec=ForwardContext)
|
||||
mock_forward_context.flash_comm_v1_enabled = False
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
mock_get_forward_context_2.return_value = mock_forward_context
|
||||
|
||||
mock_mla_forward.return_value = (3, self.hidden_size)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestMoECommMethod(TestBase):
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
self.moe_config.global_redundant_expert_num = 0
|
||||
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
@@ -73,7 +73,7 @@ class TestMoECommMethod(TestBase):
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
@@ -116,7 +116,7 @@ class TestMoECommMethod(TestBase):
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
|
||||
)
|
||||
@@ -155,7 +155,7 @@ class TestMoECommMethod(TestBase):
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, QuantType.NONE)
|
||||
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
mock_context = MagicMock()
|
||||
@@ -65,7 +65,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_mc2_tp_split_allgather(self, mock_all_gather,
|
||||
mock_get_forward_context, mock_tp_rank,
|
||||
@@ -169,7 +169,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
|
||||
return_value=False)
|
||||
def test_allgather_prepare_finalize(self, mock_enable_sp,
|
||||
|
||||
@@ -386,10 +386,11 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
set_current_vllm_config(None)
|
||||
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
|
||||
**{"return_value.flash_comm_v1_enabled": False})
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_basic(self, mock_context, mock_get_context):
|
||||
def test_dummy_run_basic(self, mock_context, mock_get_context, mock_get_context_2):
|
||||
num_tokens = 32
|
||||
with_prefill = False
|
||||
|
||||
@@ -402,10 +403,11 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.assertTrue(self.proposer._runnable.call_count == 1)
|
||||
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
|
||||
**{"return_value.flash_comm_v1_enabled": False})
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
|
||||
def test_dummy_run_with_prefill(self, mock_context, mock_get_context, mock_get_context_2):
|
||||
mock_context.return_value.__enter__.return_value = None
|
||||
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
@@ -413,11 +415,12 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
|
||||
self.assertTrue(self.proposer._runnable.call_count == 1)
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
|
||||
mock_update_full_graph_params):
|
||||
mock_update_full_graph_params, mock_get_context_2):
|
||||
last_use_cuda_graph = self.proposer.use_cuda_graph
|
||||
mock_return_context = MagicMock()
|
||||
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
@@ -425,6 +428,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
mock_return_context.flash_comm_v1_enabled = False
|
||||
mock_get_context.return_value = mock_return_context
|
||||
mock_get_context_2.return_value = mock_return_context
|
||||
self.proposer.use_cuda_graph = True
|
||||
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
@@ -435,12 +439,13 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.assertTrue(self.proposer._runnable.call_count == 1)
|
||||
mock_update_full_graph_params.assert_not_called()
|
||||
self.proposer.use_cuda_graph = last_use_cuda_graph
|
||||
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
|
||||
mock_update_full_graph_params):
|
||||
mock_update_full_graph_params, mock_get_context_2):
|
||||
last_use_cuda_graph = self.proposer.use_cuda_graph
|
||||
mock_return_context = MagicMock()
|
||||
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
@@ -448,6 +453,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
mock_return_context.flash_comm_v1_enabled = False
|
||||
mock_get_context.return_value = mock_return_context
|
||||
mock_get_context_2.return_value = mock_return_context
|
||||
self.proposer.use_cuda_graph = True
|
||||
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
|
||||
Reference in New Issue
Block a user