### What this PR does / why we need it?
Follow https://github.com/vllm-project/vllm/pull/37425,
https://github.com/vllm-project/vllm-omni/pull/1982
Copied from them:
Notice that `hasattr(self.model, "flush_pending_metadata")` cost 6ms per
decode step when profiling Qwen3 Omni.
The original `CUDAGraphWrapper.__getattr__` raises:
```python
raise AttributeError(f"... cudagraph wrapper: {self.runnable}")
```
When hasattr() is called for a non-existent attribute, Python internally
calls __getattr__ which constructs this AttributeError. The
{self.runnable} triggers `__repr__()` on the underlying model (e.g.,
`Qwen3OmniMoeForConditionalGeneration`), which recursively traverses the
entire nn.Module tree to generate an 18,000+ character string. This
takes ~6-7ms per call.
Since `hasattr(self.model, "flush_pending_metadata")` is called every
decode step in the Talker forward path, this adds ~6ms overhead per
step, severely impacting audio inter-chunk latency (ICL).
```Python
hasattr(self.model, "flush_pending_metadata")
→ getattr(self.model, "flush_pending_metadata")
→ not found in CUDAGraphWrapper.__dict__
→ not found in the CUDAGraphWrapper class hierarchy
→ triggers CUDAGraphWrapper.__getattr__("flush_pending_metadata")
→ hasattr(self.runnable, "flush_pending_metadata") # runnable also doesn't have it
→ executes raise AttributeError(f"... {self.runnable}")
→ Python needs to construct the exception object
→ the f-string triggers self.runnable.__repr__()
→ Qwen3OmniMoeForConditionalGeneration.__repr__()
→ recursively traverses the entire nn.Module tree
→ generates a 18,000+ character string
→ takes ~6 ms
→ AttributeError object is created
→ hasattr catches the AttributeError and returns False
→ the 18,000-character string is immediately discarded (no one ever sees it)
```
### Does this PR introduce _any_ user-facing change?
NO.
### How was this patch tested?
See https://github.com/vllm-project/vllm-omni/pull/1982
- vLLM version: v0.17.0
- vLLM main:
4497431df6
---------
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
886 lines
39 KiB
Python
886 lines
39 KiB
Python
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
import numpy as np
|
|
import torch
|
|
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
|
from vllm.config import CUDAGraphMode, VllmConfig
|
|
from vllm.forward_context import BatchDescriptor, ForwardContext
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
|
|
AscendMetadataForDecode)
|
|
from vllm_ascend.attention.context_parallel.attention_cp import \
|
|
AscendAttentionCPImpl
|
|
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
|
|
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
|
AscendMLAMetadata)
|
|
from vllm_ascend.compilation.acl_graph import (
|
|
ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
|
|
set_draft_graph_params, set_graph_params,
|
|
update_draft_graph_params_workspaces)
|
|
|
|
|
|
class TestACLGraphEntry(TestBase):
    """Unit tests for construction of the ACLGraphEntry dataclass."""

    def _descriptor(self):
        """Return the BatchDescriptor shared by both construction tests."""
        return BatchDescriptor(num_tokens=30, uniform=False)

    def test_aclgraph_entry_initialization(self):
        """Test ACLGraphEntry initialization with default values"""
        descriptor = self._descriptor()

        entry = ACLGraphEntry(batch_descriptor=descriptor)

        # Only the descriptor is set; every optional field defaults to None.
        self.assertEqual(entry.batch_descriptor, descriptor)
        self.assertIsNone(entry.aclgraph)
        self.assertIsNone(entry.output)
        self.assertIsNone(entry.input_addresses)

    def test_aclgraph_entry_with_values(self):
        """Test ACLGraphEntry initialization with specified values"""
        descriptor = self._descriptor()
        graph_stub = MagicMock()
        output_stub = MagicMock()
        addresses = [12345, 67890]

        entry = ACLGraphEntry(batch_descriptor=descriptor,
                              aclgraph=graph_stub,
                              output=output_stub,
                              input_addresses=addresses)

        # Every explicitly supplied field must be stored unchanged.
        self.assertEqual(entry.batch_descriptor, descriptor)
        self.assertEqual(entry.aclgraph, graph_stub)
        self.assertEqual(entry.output, output_stub)
        self.assertEqual(entry.input_addresses, addresses)
|
|
|
|
|
|
class TestACLGraphWrapper(TestBase):
|
|
|
|
def setUp(self):
|
|
"""Set up test fixtures"""
|
|
super().setUp()
|
|
|
|
# Mock VllmConfig
|
|
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
|
self.mock_vllm_config.compilation_config = MagicMock()
|
|
|
|
# Mock runnable function
|
|
self.mock_runnable = MagicMock(return_value="test_output")
|
|
|
|
# Mock graph pool
|
|
self.mock_graph_pool = MagicMock()
|
|
|
|
# Mock CUDAGraphOptions
|
|
self.mock_cudagraph_options = MagicMock(spec=CUDAGraphOptions)
|
|
self.mock_cudagraph_options.debug_log_enable = False
|
|
self.mock_cudagraph_options.gc_disable = False
|
|
self.mock_cudagraph_options.weak_ref_output = False
|
|
|
|
# Mock BatchDescriptor
|
|
self.mock_batch_descriptor = BatchDescriptor(
|
|
num_tokens=30,
|
|
uniform=False,
|
|
)
|
|
|
|
# Mock ForwardContext
|
|
self.mock_forward_context = MagicMock(spec=ForwardContext)
|
|
self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
def test_initialization_with_default_options(self, mock_envs,
|
|
mock_current_platform):
|
|
"""Test ACLGraphWrapper initialization with default CUDAGraphOptions"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
|
|
wrapper = ACLGraphWrapper(runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL)
|
|
|
|
self.assertEqual(wrapper.runnable, self.mock_runnable)
|
|
self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
|
|
self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
|
|
self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
|
|
self.assertFalse(wrapper.is_debugging_mode)
|
|
self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions)
|
|
self.assertEqual(wrapper.concrete_aclgraph_entries, {})
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
def test_initialization_with_custom_options(self, mock_envs,
|
|
mock_current_platform):
|
|
"""Test ACLGraphWrapper initialization with custom CUDAGraphOptions"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
self.assertEqual(wrapper.runnable, self.mock_runnable)
|
|
self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
|
|
self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
|
|
self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
|
|
self.assertTrue(wrapper.is_debugging_mode)
|
|
self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options)
|
|
self.assertEqual(wrapper.concrete_aclgraph_entries, {})
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
def test_initialization_assertion_error(self, mock_envs,
|
|
mock_current_platform):
|
|
"""Test ACLGraphWrapper initialization raises AssertionError for NONE mode"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
|
|
with self.assertRaises(AssertionError):
|
|
ACLGraphWrapper(runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.NONE)
|
|
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
def test_call_with_none_runtime_mode(self, mock_envs,
|
|
mock_current_platform,
|
|
mock_get_forward_context, mock_get_forward_context_2):
|
|
"""Test __call__ method when runtime mode is NONE"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.NONE
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
result = wrapper("arg1", "arg2")
|
|
|
|
# Should call the runnable directly without graph capture
|
|
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
|
self.assertEqual(result, "test_output")
|
|
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
def test_call_with_mismatched_runtime_mode(self, mock_envs,
|
|
mock_current_platform,
|
|
mock_get_forward_context,
|
|
mock_get_forward_context_2):
|
|
"""Test __call__ method when runtime mode doesn't match wrapper mode"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
result = wrapper("arg1", "arg2")
|
|
|
|
# Should call the runnable directly without graph capture
|
|
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
|
self.assertEqual(result, "test_output")
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
|
@patch(
|
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
|
)
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
|
def test_call_capture_graph_first_time(
|
|
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
|
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
|
"""Test __call__ method captures graph for the first time"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
|
|
# Mock torch.npu.NPUGraph
|
|
mock_npu_graph = MagicMock()
|
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
|
|
|
# Mock torch.npu.graph context manager
|
|
mock_graph_context = MagicMock()
|
|
mock_torch.npu.graph.return_value = mock_graph_context
|
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
|
|
|
# Mock weak_ref_tensors to return the same output
|
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
|
|
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
|
mock_torch.Tensor = torch.Tensor
|
|
|
|
# Set up the compilation counter mock
|
|
mock_compilation_counter.num_cudagraph_captured = 0
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# Create a real torch tensor for the test, not a mock
|
|
test_tensor = torch.tensor([1, 2, 3])
|
|
|
|
# Call the wrapper
|
|
result = wrapper(test_tensor, "arg2")
|
|
|
|
# Verify graph capture happened
|
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
|
mock_torch.npu.NPUGraph.assert_called_once()
|
|
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
|
pool=self.mock_graph_pool)
|
|
self.mock_runnable.assert_called_once_with(test_tensor, "arg2")
|
|
|
|
# Verify the entry was created and updated
|
|
self.assertIn(self.mock_batch_descriptor,
|
|
wrapper.concrete_aclgraph_entries)
|
|
entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor]
|
|
self.assertEqual(entry.aclgraph, mock_npu_graph)
|
|
self.assertEqual(entry.output, "weak_ref_output")
|
|
|
|
# Verify compilation counter was incremented
|
|
self.assertEqual(mock_compilation_counter.num_cudagraph_captured, 1)
|
|
|
|
# Should return the original output (not weak ref)
|
|
self.assertEqual(result, "test_output")
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
|
@patch(
|
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
|
)
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
|
def test_call_replay_graph(self, mock_weak_ref_tensors,
|
|
mock_compilation_counter, mock_envs,
|
|
mock_current_platform, mock_get_forward_context,
|
|
mock_get_forward_context_2,
|
|
mock_validate_cudagraph_capturing_enabled,
|
|
mock_torch):
|
|
"""Test __call__ method replays graph when already captured"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
self.mock_forward_context.is_draft_model = False
|
|
|
|
# Mock torch.npu.NPUGraph
|
|
mock_npu_graph = MagicMock()
|
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
|
|
|
# Mock torch.npu.graph context manager
|
|
mock_graph_context = MagicMock()
|
|
mock_torch.npu.graph.return_value = mock_graph_context
|
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
|
|
|
# Mock weak_ref_tensors to return the same output
|
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
|
|
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
|
mock_torch.Tensor = torch.Tensor
|
|
|
|
# Set up the compilation counter mock
|
|
mock_compilation_counter.num_cudagraph_captured = 0
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# Create a real torch tensor for the test, not a mock
|
|
test_tensor = torch.tensor([1, 2, 3])
|
|
|
|
# First call to capture the graph
|
|
first_result = wrapper(test_tensor, "arg2")
|
|
|
|
# Verify graph capture happened during first call
|
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
|
mock_torch.npu.NPUGraph.assert_called_once()
|
|
mock_torch.npu.graph.assert_called_once()
|
|
|
|
# Reset mock to track second call
|
|
self.mock_runnable.reset_mock()
|
|
mock_npu_graph.reset_mock()
|
|
|
|
# Second call should replay the graph
|
|
second_result = wrapper(test_tensor, "arg2")
|
|
|
|
# Verify runnable was called only during capture (not during replay)
|
|
self.mock_runnable.assert_not_called()
|
|
|
|
# Verify graph replay happened
|
|
mock_npu_graph.replay.assert_called_once()
|
|
|
|
# Both calls should return the weak ref output
|
|
self.assertEqual(first_result, "test_output") # Original output
|
|
self.assertEqual(second_result, "weak_ref_output") # Weak ref output
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
|
@patch(
|
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
|
)
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
|
def test_call_with_debug_mode_input_address_check(
|
|
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
|
mock_get_forward_context,mock_get_forward_context_2,
|
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
|
"""Test __call__ method with debug mode input address checking"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
self.mock_forward_context.is_draft_model = False
|
|
|
|
# Mock torch.npu.NPUGraph
|
|
mock_npu_graph = MagicMock()
|
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
|
|
|
# Mock torch.npu.graph context manager
|
|
mock_graph_context = MagicMock()
|
|
mock_torch.npu.graph.return_value = mock_graph_context
|
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
|
|
|
# Mock weak_ref_tensors
|
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
|
|
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
|
mock_torch.Tensor = torch.Tensor
|
|
|
|
# Create a mock tensor as the output of runnable
|
|
mock_output_tensor = torch.tensor([4, 5, 6])
|
|
self.mock_runnable.return_value = mock_output_tensor
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# First call to capture the graph
|
|
tensor = torch.tensor([1, 2, 3]) # Create tensor once
|
|
_ = wrapper(tensor, "arg2")
|
|
|
|
# Second call with same tensor addresses should work
|
|
_ = wrapper(tensor, "arg2") # Use the same tensor object
|
|
|
|
# Should not raise AssertionError
|
|
self.assertTrue(True)
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
|
@patch(
|
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
|
)
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
|
def test_call_with_debug_mode_input_address_mismatch(
|
|
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
|
mock_get_forward_context,mock_get_forward_context_2,
|
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
|
"""Test __call__ method with debug mode input address mismatch raises AssertionError"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
|
|
# Mock torch.npu.NPUGraph
|
|
mock_npu_graph = MagicMock()
|
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
|
|
|
# Mock torch.npu.graph context manager
|
|
mock_graph_context = MagicMock()
|
|
mock_torch.npu.graph.return_value = mock_graph_context
|
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
|
|
|
# Mock weak_ref_tensors
|
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
|
|
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
|
mock_torch.Tensor = torch.Tensor
|
|
|
|
# Create a mock tensor as the output of runnable
|
|
mock_output_tensor = torch.tensor([4, 5, 6])
|
|
self.mock_runnable.return_value = mock_output_tensor
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# First call to capture the graph
|
|
tensor1 = torch.tensor([1, 2, 3])
|
|
_ = wrapper(tensor1, "arg2")
|
|
|
|
# Second call with different tensor addresses should raise AssertionError
|
|
tensor2 = torch.tensor([4, 5,
|
|
6]) # Different values, different address
|
|
|
|
with self.assertRaises(AssertionError) as context:
|
|
wrapper(tensor2, "arg2")
|
|
|
|
self.assertIn("Input addresses for aclgraphs are different",
|
|
str(context.exception))
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
|
@patch(
|
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
|
)
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
|
@patch('vllm_ascend.compilation.acl_graph.patch')
|
|
def test_call_capture_graph_with_gc_disable(
|
|
self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
|
|
mock_envs, mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
|
"""Test __call__ method captures graph with gc_disable option enabled"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
|
|
# Enable gc_disable option
|
|
self.mock_cudagraph_options.gc_disable = True
|
|
# weak_ref_output is not enabled by default
|
|
|
|
# Mock torch.npu.NPUGraph
|
|
mock_npu_graph = MagicMock()
|
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
|
|
|
# Mock torch.npu.graph context manager
|
|
mock_graph_context = MagicMock()
|
|
mock_torch.npu.graph.return_value = mock_graph_context
|
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
|
|
|
# Mock patch context manager
|
|
mock_exit_stack = MagicMock()
|
|
mock_patch.return_value = mock_exit_stack
|
|
mock_exit_stack.enter_context = Mock()
|
|
|
|
# Mock weak_ref_tensors to simulate the actual behavior:
|
|
# 1. First call (inside the graph context) should return "inner_output"
|
|
# 2. Second call (for entry.output) should return "weak_ref_output"
|
|
mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"]
|
|
|
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
|
mock_torch.Tensor = torch.Tensor
|
|
|
|
# Set up the compilation counter mock
|
|
mock_compilation_counter.num_cudagraph_captured = 0
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# Create a real torch tensor for the test, not a mock
|
|
test_tensor = torch.tensor([1, 2, 3])
|
|
|
|
# Call the wrapper
|
|
result = wrapper(test_tensor, "arg2")
|
|
|
|
# Verify patch was called to disable gc
|
|
self.assertTrue(mock_patch.called)
|
|
|
|
# Verify graph capture happened
|
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
|
mock_torch.npu.NPUGraph.assert_called_once()
|
|
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
|
pool=self.mock_graph_pool)
|
|
|
|
# Should return the original output (not weak ref) since weak_ref_output is not enabled
|
|
self.assertEqual(result, "test_output")
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
|
@patch(
|
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
|
)
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
|
def test_call_capture_graph_with_weak_ref_output(
|
|
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
|
mock_current_platform, mock_get_forward_context,mock_get_forward_context_2,
|
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
|
"""Test __call__ method captures graph with weak_ref_output option enabled"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
|
|
# Enable weak_ref_output option
|
|
self.mock_cudagraph_options.weak_ref_output = True
|
|
|
|
# Mock torch.npu.NPUGraph
|
|
mock_npu_graph = MagicMock()
|
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
|
|
|
# Mock torch.npu.graph context manager
|
|
mock_graph_context = MagicMock()
|
|
mock_torch.npu.graph.return_value = mock_graph_context
|
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
|
|
|
# Mock weak_ref_tensors to simulate the actual behavior:
|
|
# 1. First call (inside the graph context with weak_ref_output=True) should return "weak_ref_output"
|
|
# 2. Second call (for entry.output) should return "weak_ref_output"
|
|
mock_weak_ref_tensors.side_effect = [
|
|
"weak_ref_output", "weak_ref_output"
|
|
]
|
|
|
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
|
mock_torch.Tensor = torch.Tensor
|
|
|
|
# Set up the compilation counter mock
|
|
mock_compilation_counter.num_cudagraph_captured = 0
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# Create a real torch tensor for the test, not a mock
|
|
test_tensor = torch.tensor([1, 2, 3])
|
|
|
|
# Call the wrapper
|
|
result = wrapper(test_tensor, "arg2")
|
|
|
|
# Verify weak_ref_tensors was called twice (once for inner output, once for final output)
|
|
self.assertEqual(mock_weak_ref_tensors.call_count, 2)
|
|
|
|
# Verify graph capture happened
|
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
|
mock_torch.npu.NPUGraph.assert_called_once()
|
|
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
|
pool=self.mock_graph_pool)
|
|
|
|
# Should return the weak ref output when weak_ref_output option is enabled
|
|
self.assertEqual(result, "weak_ref_output")
|
|
|
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
|
@patch('vllm_ascend.compilation.acl_graph.logger')
|
|
def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
|
|
mock_current_platform,
|
|
mock_get_forward_context,mock_get_forward_context_2):
|
|
"""Test __call__ method captures graph with debug logging enabled"""
|
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
|
mock_get_forward_context.return_value = self.mock_forward_context
|
|
mock_get_forward_context_2.return_value = self.mock_forward_context
|
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
|
|
|
# Enable debug logging
|
|
self.mock_cudagraph_options.debug_log_enable = True
|
|
# weak_ref_output is not enabled by default
|
|
|
|
# Mock torch
|
|
with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch:
|
|
# Mock torch.npu.NPUGraph
|
|
mock_npu_graph = MagicMock()
|
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
|
|
|
# Mock torch.npu.graph context manager
|
|
mock_graph_context = MagicMock()
|
|
mock_torch.npu.graph.return_value = mock_graph_context
|
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
|
|
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
|
mock_torch.Tensor = torch.Tensor
|
|
|
|
# Mock weak_ref_tensors
|
|
with patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors'
|
|
) as mock_weak_ref_tensors:
|
|
# Mock weak_ref_tensors to simulate the actual behavior:
|
|
# 1. First call (inside the graph context) should return "inner_output"
|
|
# 2. Second call (for entry.output) should return "weak_ref_output"
|
|
mock_weak_ref_tensors.side_effect = [
|
|
"inner_output", "weak_ref_output"
|
|
]
|
|
|
|
# Mock validate_cudagraph_capturing_enabled
|
|
with patch(
|
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
|
):
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=self.mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# Create a real torch tensor for the test, not a mock
|
|
test_tensor = torch.tensor([1, 2, 3])
|
|
|
|
# Call the wrapper
|
|
_ = wrapper(test_tensor, "arg2")
|
|
|
|
# Verify debug log was called
|
|
mock_logger.debug.assert_called_once()
|
|
|
|
def test_getattr_access_runnable_attributes(self):
|
|
"""Test __getattr__ method accesses runnable attributes"""
|
|
mock_runnable = MagicMock()
|
|
mock_runnable.test_attr = "test_value"
|
|
|
|
wrapper = ACLGraphWrapper(
|
|
runnable=mock_runnable,
|
|
vllm_config=self.mock_vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL,
|
|
cudagraph_options=self.mock_cudagraph_options)
|
|
|
|
# Should be able to access attributes of the runnable
|
|
self.assertEqual(wrapper.test_attr, "test_value")
|
|
|
|
def test_getattr_attribute_not_exists(self):
    """__getattr__ raises AttributeError for attributes missing on both sides."""

    class BareRunnable:
        """Deliberately attribute-free stand-in for a wrapped model."""

    wrapper = ACLGraphWrapper(
        runnable=BareRunnable(),
        vllm_config=self.mock_vllm_config,
        runtime_mode=CUDAGraphMode.FULL,
        cudagraph_options=self.mock_cudagraph_options)

    # Neither the wrapper nor the runnable defines this attribute, so
    # the wrapper's __getattr__ must raise with a descriptive message.
    # NOTE: the message intentionally names only the attribute (not the
    # runnable's repr) so that hasattr() probes stay cheap.
    with self.assertRaises(AttributeError) as ctx:
        _ = wrapper.non_existent_attr

    self.assertIn("Attribute non_existent_attr not found",
                  str(ctx.exception))
|
|
|
|
def test_unwrap_method(self):
    """unwrap() hands back the exact runnable the wrapper was built around."""
    wrapper = ACLGraphWrapper(
        runnable=self.mock_runnable,
        vllm_config=self.mock_vllm_config,
        runtime_mode=CUDAGraphMode.FULL,
        cudagraph_options=self.mock_cudagraph_options)

    self.assertEqual(self.mock_runnable, wrapper.unwrap())
|
|
|
|
|
|
class TestDraftGraphParams(TestBase):
    """Tests for the module-level draft-graph-params accessor helpers."""

    def test_set_draft_graph_params(self):
        """set_draft_graph_params populates the module-level singleton."""
        with patch('vllm_ascend.compilation.acl_graph._draft_graph_params',
                   new=None):
            set_draft_graph_params([4])
            # Re-import to observe the value the setter installed.
            from vllm_ascend.compilation.acl_graph import _draft_graph_params
            self.assertIsNotNone(_draft_graph_params)

    @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
    def test_update_draft_graph_params_workspaces(self, params_mock):
        """The workspace updater overwrites an existing batch-size entry."""
        params_mock.workspaces = {4: 5}
        update_draft_graph_params_workspaces(4, 6)
        self.assertEqual(6, params_mock.workspaces[4])

    @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
    def test_get_draft_graph_params(self, params_mock):
        """The getter returns the module-level object untouched."""
        self.assertIs(params_mock, get_draft_graph_params())
|
|
|
|
|
|
class TestPCPDCPGraphParams(TestBase):
    """Tests for DCP/PCP attention graph-parameter updates on captured graphs."""

    def setUp(self):
        # Stand-in for the NPU side stream used while replaying graph updates.
        self.update_stream = MagicMock(name="FakeStream")
        # Reuse module-level graph params if an earlier test created them;
        # otherwise register batch size 4 ourselves.
        graph_params = get_graph_params()
        if graph_params is None:
            set_graph_params(set([4]))
            self.graph_params = get_graph_params()
        else:
            self.graph_params = graph_params
        # Pre-record one event/handle pair for batch size 4 so that
        # update_graph_params has a captured task to update.
        mock_event = torch.npu.ExternalEvent()
        mock_event.record = MagicMock()
        self.graph_params.events[4] = []
        self.graph_params.handles[4] = []
        self.graph_params.events[4].append(mock_event)
        self.graph_params.handles[4].append(MagicMock())

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch.npu.graph_task_update_end', )
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context):
        """MLA path: update_graph_params replays the recorded attention call
        with the new metadata and closes the task update exactly once."""
        # Decode-step metadata inputs: 2 sequences of length 4 (8 tokens).
        input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        block_table = torch.zeros(2, 5, dtype=torch.long)
        seq_lens = torch.tensor([4, 4])
        cp_seq_len = torch.tensor([2, 2])
        max_seq_lens = 4
        seq_lens_list = [4, 4]
        slot_mapping = torch.zeros(8, dtype=torch.long)
        query_start_loc = torch.tensor([0, 4])
        block_tables = torch.zeros(2, 5, dtype=torch.long)

        # Positional arguments must match AscendMLADecodeMetadata's field
        # order — do not reorder.
        decode = AscendMLADecodeMetadata(input_positions,
                                         block_table,
                                         seq_lens,
                                         max_seq_lens,
                                         seq_lens_list,
                                         cp_seq_len=cp_seq_len)
        metadata = AscendMLAMetadata(8,
                                     8,
                                     slot_mapping,
                                     query_start_loc,
                                     seq_lens,
                                     block_tables,
                                     4,
                                     4,
                                     0,
                                     decode=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_draft_model = False
        mock_context.return_value = forward_context

        # Attention shapes for the recorded fused-infer-attention call.
        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_rope_head_dim = 32
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)

        # Split the query into its "nope" and rotary ("pe") halves.
        q_nope = query[..., :qk_nope_head_dim]
        q_pe = query[..., qk_rope_head_dim:]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
        input_layout = "BNSD"
        actual_seq_lengths_kv = [1, 1]
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)
        # The tuple layout mirrors what AscendMlaCPImpl records at capture
        # time — positional, so order matters.
        self.graph_params.attn_params[4] = []
        self.graph_params.attn_params[4].append(
            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
             None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
             out, lse))

        # _npu_setStream is stubbed out since no real NPU stream exists here.
        with patch("torch_npu._C._npu_setStream", return_value=None):
            AscendMlaCPImpl.update_graph_params(
                self.update_stream, forward_context, 4
            )

        # The captured task update must be finalized exactly once.
        _mock_graph_task_end.assert_called_once()

    @patch('torch.npu.graph_task_update_end', )
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
    def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end):
        """Non-MLA attention path: update_graph_params replays the recorded
        call for the DCP/PCP layout and closes the task update once."""
        block_table = torch.zeros(2, 5, dtype=torch.long)
        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)
        q_nope = query[..., :qk_nope_head_dim]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        actual_seq_lengths_kv = [1, 1]
        actual_seq_lengths_q = np.array([1, 1])
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)

        # Per-(pcp, dcp) rank token counts; shape presumably
        # (num_seqs, pcp_size, dcp_size) — TODO confirm against impl.
        num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]],
                                                   [[1, 1], [1, 1]]])
        decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp)
        metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1],
                                  actual_seq_lengths_q=actual_seq_lengths_q,
                                  num_decode_tokens=1,
                                  decode_meta=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_draft_model = False

        # Tuple layout mirrors AscendAttentionCPImpl's capture-time record —
        # positional, so order matters.
        self.graph_params.attn_params[4] = []
        self.graph_params.attn_params[4].append(
            (q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale,
             block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
             out, lse, 2, 0, 0))

        # _npu_setStream is stubbed out since no real NPU stream exists here.
        with patch("torch_npu._C._npu_setStream", return_value=None):
            AscendAttentionCPImpl.update_graph_params(
                self.update_stream, forward_context, 4, None
            )

        # The captured task update must be finalized exactly once.
        _mock_graph_task_end.assert_called_once()
|