[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to newest commit of vllm main branch.
- supply a unified interface of extra forward context for both model
runner v1 and model runner v2.
- implement graph mode for main model. 

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-03-13 09:11:46 +08:00
committed by GitHub
parent 1f71da80eb
commit c980e68d40
52 changed files with 840 additions and 309 deletions

View File

@@ -161,12 +161,13 @@ class TestACLGraphWrapper(TestBase):
vllm_config=self.mock_vllm_config,
runtime_mode=CUDAGraphMode.NONE)
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
def test_call_with_none_runtime_mode(self, mock_envs,
mock_current_platform,
mock_get_forward_context):
mock_get_forward_context, mock_get_forward_context_2):
"""Test __call__ method when runtime mode is NONE"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
@@ -185,16 +186,19 @@ class TestACLGraphWrapper(TestBase):
self.mock_runnable.assert_called_once_with("arg1", "arg2")
self.assertEqual(result, "test_output")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
def test_call_with_mismatched_runtime_mode(self, mock_envs,
mock_current_platform,
mock_get_forward_context):
mock_get_forward_context,
mock_get_forward_context_2):
"""Test __call__ method when runtime mode doesn't match wrapper mode"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL
wrapper = ACLGraphWrapper(
@@ -214,18 +218,20 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_capture_graph_first_time(
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
mock_current_platform, mock_get_forward_context,
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method captures graph for the first time"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Mock torch.npu.NPUGraph
@@ -284,6 +290,7 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@@ -291,12 +298,15 @@ class TestACLGraphWrapper(TestBase):
def test_call_replay_graph(self, mock_weak_ref_tensors,
mock_compilation_counter, mock_envs,
mock_current_platform, mock_get_forward_context,
mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled,
mock_torch):
"""Test __call__ method replays graph when already captured"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
self.mock_forward_context.is_draft_model = False
@@ -358,17 +368,19 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_with_debug_mode_input_address_check(
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
mock_get_forward_context,
mock_get_forward_context,mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method with debug mode input address checking"""
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
self.mock_forward_context.is_draft_model = False
@@ -413,17 +425,19 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_with_debug_mode_input_address_mismatch(
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
mock_get_forward_context,
mock_get_forward_context,mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method with debug mode input address mismatch raises AssertionError"""
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Mock torch.npu.NPUGraph
@@ -471,6 +485,7 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@@ -478,12 +493,13 @@ class TestACLGraphWrapper(TestBase):
@patch('vllm_ascend.compilation.acl_graph.patch')
def test_call_capture_graph_with_gc_disable(
self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
mock_envs, mock_current_platform, mock_get_forward_context,
mock_envs, mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method captures graph with gc_disable option enabled"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Enable gc_disable option
@@ -545,18 +561,20 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_capture_graph_with_weak_ref_output(
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
mock_current_platform, mock_get_forward_context,
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method captures graph with weak_ref_output option enabled"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Enable weak_ref_output option
@@ -608,18 +626,20 @@ class TestACLGraphWrapper(TestBase):
# Should return the weak ref output when weak_ref_output option is enabled
self.assertEqual(result, "weak_ref_output")
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.logger')
def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
mock_current_platform,
mock_get_forward_context):
mock_get_forward_context,mock_get_forward_context_2):
"""Test __call__ method captures graph with debug logging enabled"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Enable debug logging
@@ -757,10 +777,11 @@ class TestPCPDCPGraphParams(TestBase):
self.graph_params.events[4].append(mock_event)
self.graph_params.handles[4].append(MagicMock())
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('torch.npu.graph_task_update_end', )
@patch('torch.npu.graph_task_update_begin', MagicMock())
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context):
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
block_table = torch.zeros(2, 5, dtype=torch.long)
seq_lens = torch.tensor([4, 4])
@@ -790,6 +811,7 @@ class TestPCPDCPGraphParams(TestBase):
forward_context = MagicMock()
forward_context.attn_metadata = {"attn_layer_0": metadata}
forward_context.is_draft_model = False
mock_context.return_value = forward_context
num_heads = 256
scale = 0.1