[Test]Add unit test for compilation/acl_graph.py (#3039)
### What this PR does / why we need it?
According to issue [#1298
](https://github.com/vllm-project/vllm-ascend/issues/1298) ,this pull
request adds unit test code for compilation/acl_graph.py.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.10.2
- vLLM main:
f2718d2948
---------
Signed-off-by: zhanghaiwen <zhanghaiwen@cmss.chinamobile.com>
Co-authored-by: zhanghaiwen <zhanghaiwen@cmss.chinamobile.com>
This commit is contained in:
720
tests/ut/compilation/test_acl_graph.py
Normal file
720
tests/ut/compilation/test_acl_graph.py
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||||
|
from vllm.config import CUDAGraphMode, VllmConfig
|
||||||
|
from vllm.forward_context import BatchDescriptor, ForwardContext
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.compilation.acl_graph import ACLGraphEntry, ACLGraphWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class TestACLGraphEntry(TestBase):
|
||||||
|
|
||||||
|
def test_aclgraph_entry_initialization(self):
|
||||||
|
"""Test ACLGraphEntry initialization with default values"""
|
||||||
|
batch_descriptor = BatchDescriptor(
|
||||||
|
num_tokens=30,
|
||||||
|
uniform_decode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||||
|
|
||||||
|
self.assertEqual(entry.batch_descriptor, batch_descriptor)
|
||||||
|
self.assertIsNone(entry.aclgraph)
|
||||||
|
self.assertIsNone(entry.output)
|
||||||
|
self.assertIsNone(entry.input_addresses)
|
||||||
|
|
||||||
|
def test_aclgraph_entry_with_values(self):
|
||||||
|
"""Test ACLGraphEntry initialization with specified values"""
|
||||||
|
batch_descriptor = BatchDescriptor(
|
||||||
|
num_tokens=30,
|
||||||
|
uniform_decode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_output = MagicMock()
|
||||||
|
input_addresses = [12345, 67890]
|
||||||
|
|
||||||
|
entry = ACLGraphEntry(batch_descriptor=batch_descriptor,
|
||||||
|
aclgraph=mock_graph,
|
||||||
|
output=mock_output,
|
||||||
|
input_addresses=input_addresses)
|
||||||
|
|
||||||
|
self.assertEqual(entry.batch_descriptor, batch_descriptor)
|
||||||
|
self.assertEqual(entry.aclgraph, mock_graph)
|
||||||
|
self.assertEqual(entry.output, mock_output)
|
||||||
|
self.assertEqual(entry.input_addresses, input_addresses)
|
||||||
|
|
||||||
|
|
||||||
|
class TestACLGraphWrapper(TestBase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures"""
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
# Mock VllmConfig
|
||||||
|
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||||
|
self.mock_vllm_config.compilation_config = MagicMock()
|
||||||
|
|
||||||
|
# Mock runnable function
|
||||||
|
self.mock_runnable = MagicMock(return_value="test_output")
|
||||||
|
|
||||||
|
# Mock graph pool
|
||||||
|
self.mock_graph_pool = MagicMock()
|
||||||
|
|
||||||
|
# Mock CUDAGraphOptions
|
||||||
|
self.mock_cudagraph_options = MagicMock(spec=CUDAGraphOptions)
|
||||||
|
self.mock_cudagraph_options.debug_log_enable = False
|
||||||
|
self.mock_cudagraph_options.gc_disable = False
|
||||||
|
self.mock_cudagraph_options.weak_ref_output = False
|
||||||
|
|
||||||
|
# Mock BatchDescriptor
|
||||||
|
self.mock_batch_descriptor = BatchDescriptor(
|
||||||
|
num_tokens=30,
|
||||||
|
uniform_decode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock ForwardContext
|
||||||
|
self.mock_forward_context = MagicMock(spec=ForwardContext)
|
||||||
|
self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
def test_initialization_with_default_options(self, mock_envs,
|
||||||
|
mock_current_platform):
|
||||||
|
"""Test ACLGraphWrapper initialization with default CUDAGraphOptions"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool)
|
||||||
|
|
||||||
|
self.assertEqual(wrapper.runnable, self.mock_runnable)
|
||||||
|
self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
|
||||||
|
self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
|
||||||
|
self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
|
||||||
|
self.assertFalse(wrapper.is_debugging_mode)
|
||||||
|
self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions)
|
||||||
|
self.assertEqual(wrapper.concrete_aclgraph_entries, {})
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
def test_initialization_with_custom_options(self, mock_envs,
|
||||||
|
mock_current_platform):
|
||||||
|
"""Test ACLGraphWrapper initialization with custom CUDAGraphOptions"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
self.assertEqual(wrapper.runnable, self.mock_runnable)
|
||||||
|
self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
|
||||||
|
self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
|
||||||
|
self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
|
||||||
|
self.assertTrue(wrapper.is_debugging_mode)
|
||||||
|
self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options)
|
||||||
|
self.assertEqual(wrapper.concrete_aclgraph_entries, {})
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
def test_initialization_assertion_error(self, mock_envs,
|
||||||
|
mock_current_platform):
|
||||||
|
"""Test ACLGraphWrapper initialization raises AssertionError for NONE mode"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
ACLGraphWrapper(runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.NONE,
|
||||||
|
graph_pool=self.mock_graph_pool)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
def test_call_with_none_runtime_mode(self, mock_envs,
|
||||||
|
mock_current_platform,
|
||||||
|
mock_get_forward_context):
|
||||||
|
"""Test __call__ method when runtime mode is NONE"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
result = wrapper("arg1", "arg2")
|
||||||
|
|
||||||
|
# Should call the runnable directly without graph capture
|
||||||
|
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
||||||
|
self.assertEqual(result, "test_output")
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
def test_call_with_mismatched_runtime_mode(self, mock_envs,
|
||||||
|
mock_current_platform,
|
||||||
|
mock_get_forward_context):
|
||||||
|
"""Test __call__ method when runtime mode doesn't match wrapper mode"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
result = wrapper("arg1", "arg2")
|
||||||
|
|
||||||
|
# Should call the runnable directly without graph capture
|
||||||
|
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
||||||
|
self.assertEqual(result, "test_output")
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||||
|
@patch(
|
||||||
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||||
|
)
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||||
|
def test_call_capture_graph_first_time(
|
||||||
|
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||||
|
mock_current_platform, mock_get_forward_context,
|
||||||
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||||
|
"""Test __call__ method captures graph for the first time"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
# Mock torch.npu.NPUGraph
|
||||||
|
mock_npu_graph = MagicMock()
|
||||||
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||||
|
|
||||||
|
# Mock torch.npu.graph context manager
|
||||||
|
mock_graph_context = MagicMock()
|
||||||
|
mock_torch.npu.graph.return_value = mock_graph_context
|
||||||
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||||
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Mock weak_ref_tensors to return the same output
|
||||||
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||||
|
|
||||||
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||||
|
mock_torch.Tensor = torch.Tensor
|
||||||
|
|
||||||
|
# Set up the compilation counter mock
|
||||||
|
mock_compilation_counter.num_cudagraph_captured = 0
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# Create a real torch tensor for the test, not a mock
|
||||||
|
test_tensor = torch.tensor([1, 2, 3])
|
||||||
|
|
||||||
|
# Call the wrapper
|
||||||
|
result = wrapper(test_tensor, "arg2")
|
||||||
|
|
||||||
|
# Verify graph capture happened
|
||||||
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||||
|
mock_torch.npu.NPUGraph.assert_called_once()
|
||||||
|
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
||||||
|
pool=self.mock_graph_pool)
|
||||||
|
self.mock_runnable.assert_called_once_with(test_tensor, "arg2")
|
||||||
|
|
||||||
|
# Verify the entry was created and updated
|
||||||
|
self.assertIn(self.mock_batch_descriptor,
|
||||||
|
wrapper.concrete_aclgraph_entries)
|
||||||
|
entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor]
|
||||||
|
self.assertEqual(entry.aclgraph, mock_npu_graph)
|
||||||
|
self.assertEqual(entry.output, "weak_ref_output")
|
||||||
|
|
||||||
|
# Verify compilation counter was incremented
|
||||||
|
self.assertEqual(mock_compilation_counter.num_cudagraph_captured, 1)
|
||||||
|
|
||||||
|
# Should return the original output (not weak ref)
|
||||||
|
self.assertEqual(result, "test_output")
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||||
|
@patch(
|
||||||
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||||
|
)
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||||
|
def test_call_replay_graph(self, mock_weak_ref_tensors,
|
||||||
|
mock_compilation_counter, mock_envs,
|
||||||
|
mock_current_platform, mock_get_forward_context,
|
||||||
|
mock_validate_cudagraph_capturing_enabled,
|
||||||
|
mock_torch):
|
||||||
|
"""Test __call__ method replays graph when already captured"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
# Mock torch.npu.NPUGraph
|
||||||
|
mock_npu_graph = MagicMock()
|
||||||
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||||
|
|
||||||
|
# Mock torch.npu.graph context manager
|
||||||
|
mock_graph_context = MagicMock()
|
||||||
|
mock_torch.npu.graph.return_value = mock_graph_context
|
||||||
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||||
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Mock weak_ref_tensors to return the same output
|
||||||
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||||
|
|
||||||
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||||
|
mock_torch.Tensor = torch.Tensor
|
||||||
|
|
||||||
|
# Set up the compilation counter mock
|
||||||
|
mock_compilation_counter.num_cudagraph_captured = 0
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# Create a real torch tensor for the test, not a mock
|
||||||
|
test_tensor = torch.tensor([1, 2, 3])
|
||||||
|
|
||||||
|
# First call to capture the graph
|
||||||
|
first_result = wrapper(test_tensor, "arg2")
|
||||||
|
|
||||||
|
# Verify graph capture happened during first call
|
||||||
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||||
|
mock_torch.npu.NPUGraph.assert_called_once()
|
||||||
|
mock_torch.npu.graph.assert_called_once()
|
||||||
|
|
||||||
|
# Reset mock to track second call
|
||||||
|
self.mock_runnable.reset_mock()
|
||||||
|
mock_npu_graph.reset_mock()
|
||||||
|
|
||||||
|
# Second call should replay the graph
|
||||||
|
second_result = wrapper(test_tensor, "arg2")
|
||||||
|
|
||||||
|
# Verify runnable was called only during capture (not during replay)
|
||||||
|
self.mock_runnable.assert_not_called()
|
||||||
|
|
||||||
|
# Verify graph replay happened
|
||||||
|
mock_npu_graph.replay.assert_called_once()
|
||||||
|
|
||||||
|
# Both calls should return the weak ref output
|
||||||
|
self.assertEqual(first_result, "test_output") # Original output
|
||||||
|
self.assertEqual(second_result, "weak_ref_output") # Weak ref output
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||||
|
@patch(
|
||||||
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||||
|
)
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||||
|
def test_call_with_debug_mode_input_address_check(
|
||||||
|
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||||
|
mock_get_forward_context,
|
||||||
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||||
|
"""Test __call__ method with debug mode input address checking"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
# Mock torch.npu.NPUGraph
|
||||||
|
mock_npu_graph = MagicMock()
|
||||||
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||||
|
|
||||||
|
# Mock torch.npu.graph context manager
|
||||||
|
mock_graph_context = MagicMock()
|
||||||
|
mock_torch.npu.graph.return_value = mock_graph_context
|
||||||
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||||
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Mock weak_ref_tensors
|
||||||
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||||
|
|
||||||
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||||
|
mock_torch.Tensor = torch.Tensor
|
||||||
|
|
||||||
|
# Create a mock tensor as the output of runnable
|
||||||
|
mock_output_tensor = torch.tensor([4, 5, 6])
|
||||||
|
self.mock_runnable.return_value = mock_output_tensor
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# First call to capture the graph
|
||||||
|
tensor = torch.tensor([1, 2, 3]) # Create tensor once
|
||||||
|
_ = wrapper(tensor, "arg2")
|
||||||
|
|
||||||
|
# Second call with same tensor addresses should work
|
||||||
|
_ = wrapper(tensor, "arg2") # Use the same tensor object
|
||||||
|
|
||||||
|
# Should not raise AssertionError
|
||||||
|
self.assertTrue(True)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||||
|
@patch(
|
||||||
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||||
|
)
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||||
|
def test_call_with_debug_mode_input_address_mismatch(
|
||||||
|
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||||
|
mock_get_forward_context,
|
||||||
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||||
|
"""Test __call__ method with debug mode input address mismatch raises AssertionError"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
# Mock torch.npu.NPUGraph
|
||||||
|
mock_npu_graph = MagicMock()
|
||||||
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||||
|
|
||||||
|
# Mock torch.npu.graph context manager
|
||||||
|
mock_graph_context = MagicMock()
|
||||||
|
mock_torch.npu.graph.return_value = mock_graph_context
|
||||||
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||||
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Mock weak_ref_tensors
|
||||||
|
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||||
|
|
||||||
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||||
|
mock_torch.Tensor = torch.Tensor
|
||||||
|
|
||||||
|
# Create a mock tensor as the output of runnable
|
||||||
|
mock_output_tensor = torch.tensor([4, 5, 6])
|
||||||
|
self.mock_runnable.return_value = mock_output_tensor
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# First call to capture the graph
|
||||||
|
tensor1 = torch.tensor([1, 2, 3])
|
||||||
|
_ = wrapper(tensor1, "arg2")
|
||||||
|
|
||||||
|
# Second call with different tensor addresses should raise AssertionError
|
||||||
|
tensor2 = torch.tensor([4, 5,
|
||||||
|
6]) # Different values, different address
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError) as context:
|
||||||
|
wrapper(tensor2, "arg2")
|
||||||
|
|
||||||
|
self.assertIn("Input addresses for aclgraphs are different",
|
||||||
|
str(context.exception))
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||||
|
@patch(
|
||||||
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||||
|
)
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.patch')
|
||||||
|
def test_call_capture_graph_with_gc_disable(
|
||||||
|
self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
|
||||||
|
mock_envs, mock_current_platform, mock_get_forward_context,
|
||||||
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||||
|
"""Test __call__ method captures graph with gc_disable option enabled"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
# Enable gc_disable option
|
||||||
|
self.mock_cudagraph_options.gc_disable = True
|
||||||
|
# weak_ref_output is not enabled by default
|
||||||
|
|
||||||
|
# Mock torch.npu.NPUGraph
|
||||||
|
mock_npu_graph = MagicMock()
|
||||||
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||||
|
|
||||||
|
# Mock torch.npu.graph context manager
|
||||||
|
mock_graph_context = MagicMock()
|
||||||
|
mock_torch.npu.graph.return_value = mock_graph_context
|
||||||
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||||
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Mock patch context manager
|
||||||
|
mock_exit_stack = MagicMock()
|
||||||
|
mock_patch.return_value = mock_exit_stack
|
||||||
|
mock_exit_stack.enter_context = Mock()
|
||||||
|
|
||||||
|
# Mock weak_ref_tensors to simulate the actual behavior:
|
||||||
|
# 1. First call (inside the graph context) should return "inner_output"
|
||||||
|
# 2. Second call (for entry.output) should return "weak_ref_output"
|
||||||
|
mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"]
|
||||||
|
|
||||||
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||||
|
mock_torch.Tensor = torch.Tensor
|
||||||
|
|
||||||
|
# Set up the compilation counter mock
|
||||||
|
mock_compilation_counter.num_cudagraph_captured = 0
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# Create a real torch tensor for the test, not a mock
|
||||||
|
test_tensor = torch.tensor([1, 2, 3])
|
||||||
|
|
||||||
|
# Call the wrapper
|
||||||
|
result = wrapper(test_tensor, "arg2")
|
||||||
|
|
||||||
|
# Verify patch was called to disable gc
|
||||||
|
self.assertTrue(mock_patch.called)
|
||||||
|
|
||||||
|
# Verify graph capture happened
|
||||||
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||||
|
mock_torch.npu.NPUGraph.assert_called_once()
|
||||||
|
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
||||||
|
pool=self.mock_graph_pool)
|
||||||
|
|
||||||
|
# Should return the original output (not weak ref) since weak_ref_output is not enabled
|
||||||
|
self.assertEqual(result, "test_output")
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||||
|
@patch(
|
||||||
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||||
|
)
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||||
|
def test_call_capture_graph_with_weak_ref_output(
|
||||||
|
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||||
|
mock_current_platform, mock_get_forward_context,
|
||||||
|
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||||
|
"""Test __call__ method captures graph with weak_ref_output option enabled"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
# Enable weak_ref_output option
|
||||||
|
self.mock_cudagraph_options.weak_ref_output = True
|
||||||
|
|
||||||
|
# Mock torch.npu.NPUGraph
|
||||||
|
mock_npu_graph = MagicMock()
|
||||||
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||||
|
|
||||||
|
# Mock torch.npu.graph context manager
|
||||||
|
mock_graph_context = MagicMock()
|
||||||
|
mock_torch.npu.graph.return_value = mock_graph_context
|
||||||
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||||
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Mock weak_ref_tensors to simulate the actual behavior:
|
||||||
|
# 1. First call (inside the graph context with weak_ref_output=True) should return "weak_ref_output"
|
||||||
|
# 2. Second call (for entry.output) should return "weak_ref_output"
|
||||||
|
mock_weak_ref_tensors.side_effect = [
|
||||||
|
"weak_ref_output", "weak_ref_output"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||||
|
mock_torch.Tensor = torch.Tensor
|
||||||
|
|
||||||
|
# Set up the compilation counter mock
|
||||||
|
mock_compilation_counter.num_cudagraph_captured = 0
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# Create a real torch tensor for the test, not a mock
|
||||||
|
test_tensor = torch.tensor([1, 2, 3])
|
||||||
|
|
||||||
|
# Call the wrapper
|
||||||
|
result = wrapper(test_tensor, "arg2")
|
||||||
|
|
||||||
|
# Verify weak_ref_tensors was called twice (once for inner output, once for final output)
|
||||||
|
self.assertEqual(mock_weak_ref_tensors.call_count, 2)
|
||||||
|
|
||||||
|
# Verify graph capture happened
|
||||||
|
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||||
|
mock_torch.npu.NPUGraph.assert_called_once()
|
||||||
|
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
||||||
|
pool=self.mock_graph_pool)
|
||||||
|
|
||||||
|
# Should return the weak ref output when weak_ref_output option is enabled
|
||||||
|
self.assertEqual(result, "weak_ref_output")
|
||||||
|
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||||
|
@patch('vllm_ascend.compilation.acl_graph.logger')
|
||||||
|
def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
|
||||||
|
mock_current_platform,
|
||||||
|
mock_get_forward_context):
|
||||||
|
"""Test __call__ method captures graph with debug logging enabled"""
|
||||||
|
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||||
|
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||||
|
mock_get_forward_context.return_value = self.mock_forward_context
|
||||||
|
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
# Enable debug logging
|
||||||
|
self.mock_cudagraph_options.debug_log_enable = True
|
||||||
|
# weak_ref_output is not enabled by default
|
||||||
|
|
||||||
|
# Mock torch
|
||||||
|
with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch:
|
||||||
|
# Mock torch.npu.NPUGraph
|
||||||
|
mock_npu_graph = MagicMock()
|
||||||
|
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||||
|
|
||||||
|
# Mock torch.npu.graph context manager
|
||||||
|
mock_graph_context = MagicMock()
|
||||||
|
mock_torch.npu.graph.return_value = mock_graph_context
|
||||||
|
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||||
|
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||||
|
mock_torch.Tensor = torch.Tensor
|
||||||
|
|
||||||
|
# Mock weak_ref_tensors
|
||||||
|
with patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors'
|
||||||
|
) as mock_weak_ref_tensors:
|
||||||
|
# Mock weak_ref_tensors to simulate the actual behavior:
|
||||||
|
# 1. First call (inside the graph context) should return "inner_output"
|
||||||
|
# 2. Second call (for entry.output) should return "weak_ref_output"
|
||||||
|
mock_weak_ref_tensors.side_effect = [
|
||||||
|
"inner_output", "weak_ref_output"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock validate_cudagraph_capturing_enabled
|
||||||
|
with patch(
|
||||||
|
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||||
|
):
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# Create a real torch tensor for the test, not a mock
|
||||||
|
test_tensor = torch.tensor([1, 2, 3])
|
||||||
|
|
||||||
|
# Call the wrapper
|
||||||
|
_ = wrapper(test_tensor, "arg2")
|
||||||
|
|
||||||
|
# Verify debug log was called
|
||||||
|
mock_logger.debug.assert_called_once()
|
||||||
|
|
||||||
|
def test_getattr_access_runnable_attributes(self):
|
||||||
|
"""Test __getattr__ method accesses runnable attributes"""
|
||||||
|
mock_runnable = MagicMock()
|
||||||
|
mock_runnable.test_attr = "test_value"
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# Should be able to access attributes of the runnable
|
||||||
|
self.assertEqual(wrapper.test_attr, "test_value")
|
||||||
|
|
||||||
|
def test_getattr_attribute_not_exists(self):
|
||||||
|
"""Test __getattr__ method raises AttributeError for non-existent attributes"""
|
||||||
|
|
||||||
|
# Create a simple object without any attributes
|
||||||
|
class EmptyRunnable:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_runnable = EmptyRunnable()
|
||||||
|
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
# Should raise AttributeError for non-existent attributes
|
||||||
|
with self.assertRaises(AttributeError) as context:
|
||||||
|
_ = wrapper.non_existent_attr
|
||||||
|
|
||||||
|
self.assertIn("Attribute non_existent_attr not exists",
|
||||||
|
str(context.exception))
|
||||||
|
|
||||||
|
def test_unwrap_method(self):
|
||||||
|
"""Test unwrap method returns the original runnable"""
|
||||||
|
wrapper = ACLGraphWrapper(
|
||||||
|
runnable=self.mock_runnable,
|
||||||
|
vllm_config=self.mock_vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
graph_pool=self.mock_graph_pool,
|
||||||
|
cudagraph_options=self.mock_cudagraph_options)
|
||||||
|
|
||||||
|
unwrapped = wrapper.unwrap()
|
||||||
|
self.assertEqual(unwrapped, self.mock_runnable)
|
||||||
Reference in New Issue
Block a user