[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR adds aclgraph (ACL graph) support for model runner v2; see RFC
#5208 for background. The PR contains the following modifications:
- Adapt to the latest commit of the vLLM main branch.
- Provide a unified interface to the extra forward context for both model
runner v1 and model runner v2.
- Implement graph mode for the main model.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -161,12 +161,13 @@ class TestACLGraphWrapper(TestBase):
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.NONE)
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_call_with_none_runtime_mode(self, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context, mock_get_forward_context_2):
|
||||
"""Test __call__ method when runtime mode is NONE"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
@@ -185,16 +186,19 @@ class TestACLGraphWrapper(TestBase):
|
||||
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
||||
self.assertEqual(result, "test_output")
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_call_with_mismatched_runtime_mode(self, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_forward_context_2):
|
||||
"""Test __call__ method when runtime mode doesn't match wrapper mode"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
@@ -214,18 +218,20 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_capture_graph_first_time(
|
||||
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph for the first time"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
@@ -284,6 +290,7 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@@ -291,12 +298,15 @@ class TestACLGraphWrapper(TestBase):
|
||||
def test_call_replay_graph(self, mock_weak_ref_tensors,
|
||||
mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled,
|
||||
mock_torch):
|
||||
"""Test __call__ method replays graph when already captured"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
self.mock_forward_context.is_draft_model = False
|
||||
|
||||
@@ -358,17 +368,19 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_with_debug_mode_input_address_check(
|
||||
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||
mock_get_forward_context,
|
||||
mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method with debug mode input address checking"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
self.mock_forward_context.is_draft_model = False
|
||||
|
||||
@@ -413,17 +425,19 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_with_debug_mode_input_address_mismatch(
|
||||
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||
mock_get_forward_context,
|
||||
mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method with debug mode input address mismatch raises AssertionError"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
@@ -471,6 +485,7 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@@ -478,12 +493,13 @@ class TestACLGraphWrapper(TestBase):
|
||||
@patch('vllm_ascend.compilation.acl_graph.patch')
|
||||
def test_call_capture_graph_with_gc_disable(
|
||||
self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
|
||||
mock_envs, mock_current_platform, mock_get_forward_context,
|
||||
mock_envs, mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph with gc_disable option enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable gc_disable option
|
||||
@@ -545,18 +561,20 @@ class TestACLGraphWrapper(TestBase):
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_capture_graph_with_weak_ref_output(
|
||||
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph with weak_ref_output option enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable weak_ref_output option
|
||||
@@ -608,18 +626,20 @@ class TestACLGraphWrapper(TestBase):
|
||||
|
||||
# Should return the weak ref output when weak_ref_output option is enabled
|
||||
self.assertEqual(result, "weak_ref_output")
|
||||
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.logger')
|
||||
def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,mock_get_forward_context_2):
|
||||
"""Test __call__ method captures graph with debug logging enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
mock_get_forward_context_2.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable debug logging
|
||||
@@ -757,10 +777,11 @@ class TestPCPDCPGraphParams(TestBase):
|
||||
self.graph_params.events[4].append(mock_event)
|
||||
self.graph_params.handles[4].append(MagicMock())
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch('torch.npu.graph_task_update_end', )
|
||||
@patch('torch.npu.graph_task_update_begin', MagicMock())
|
||||
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
|
||||
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
|
||||
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context):
|
||||
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
block_table = torch.zeros(2, 5, dtype=torch.long)
|
||||
seq_lens = torch.tensor([4, 4])
|
||||
@@ -790,6 +811,7 @@ class TestPCPDCPGraphParams(TestBase):
|
||||
forward_context = MagicMock()
|
||||
forward_context.attn_metadata = {"attn_layer_0": metadata}
|
||||
forward_context.is_draft_model = False
|
||||
mock_context.return_value = forward_context
|
||||
|
||||
num_heads = 256
|
||||
scale = 0.1
|
||||
|
||||
Reference in New Issue
Block a user