Files
xc-llm-ascend/tests/ut/compilation/test_acl_graph.py
anon189Ty 5c9f4a40c6 [Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager mode even when full graph mode
is enabled. This PR adapts MTP to full graph capture and execution. When the
graph mode is set to "FULL_DECODE_ONLY", MTP runs in full graph mode to
improve performance.

The changes, which apply both when disable_padded_drafter_batch is True and
when it is False, include:

1. Add _mtp_graph_params in acl_graph.py to isolate the data of the main
model from the data of MTP.
2. Pad some metadata in mla_v1.py when running in full graph mode.
3. Fix the addresses of the essential data used in model.forward.
4. Adapt MTP to the aclgraph capture framework:
    1). Rebuild the MTP model with ACLGraphWrapper.
    2). Add common attn metadata when capture starts in the MTP dummy_run.
    3). Add the common attn metadata update in MTP.
    4). Adapt the data update when num_speculative_tokens > 1.
5. Add an MTP patch to adapt to vllm v0.11.0.

Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the MTP data for the first-round requests is abnormal. The cause is still
under investigation.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens decreases.
For example, with num_speculative_tokens=3, the acceptance rate of the
first token is 90%, but the second is only 50% instead of 60%, and the
third is only 20% instead of 30%. The reason is that the data processed
after the model runs does not match, which is a problem from another PR.
It works fine in eager and PIECEWISE mode but fails in FullGraph
mode. Once we have a solution, we will submit a bugfix.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

---------

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00

744 lines
32 KiB
Python

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, Mock, patch
import torch
from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, ForwardContext
from tests.ut.base import TestBase
from vllm_ascend.compilation.acl_graph import (
ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params,
update_mtp_graph_params_workspaces)
class TestACLGraphEntry(TestBase):
    """Tests for the ACLGraphEntry record that stores one captured graph."""

    @staticmethod
    def _make_descriptor():
        """Build the BatchDescriptor shared by both entry tests."""
        return BatchDescriptor(num_tokens=30, uniform_decode=False)

    def test_aclgraph_entry_initialization(self):
        """Test ACLGraphEntry initialization with default values"""
        descriptor = self._make_descriptor()
        entry = ACLGraphEntry(batch_descriptor=descriptor)

        self.assertEqual(entry.batch_descriptor, descriptor)
        # Every optional field is expected to default to None.
        for field_name in ("aclgraph", "output", "input_addresses"):
            with self.subTest(field=field_name):
                self.assertIsNone(getattr(entry, field_name))

    def test_aclgraph_entry_with_values(self):
        """Test ACLGraphEntry initialization with specified values"""
        descriptor = self._make_descriptor()
        graph = MagicMock()
        output = MagicMock()
        addresses = [12345, 67890]

        entry = ACLGraphEntry(batch_descriptor=descriptor,
                              aclgraph=graph,
                              output=output,
                              input_addresses=addresses)

        # Each constructor argument must be stored unchanged on the entry.
        expected = {
            "batch_descriptor": descriptor,
            "aclgraph": graph,
            "output": output,
            "input_addresses": addresses,
        }
        for field_name, value in expected.items():
            with self.subTest(field=field_name):
                self.assertEqual(getattr(entry, field_name), value)
class TestACLGraphWrapper(TestBase):
    """Unit tests for ACLGraphWrapper construction, capture and replay.

    Module-level collaborators of ``vllm_ascend.compilation.acl_graph``
    (``torch``, ``envs``, ``current_platform``, ``get_forward_context``, ...)
    are patched per test, so no NPU device is needed.
    """

    def setUp(self):
        """Set up test fixtures"""
        super().setUp()
        # Mock VllmConfig
        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.compilation_config = MagicMock()
        # Mock runnable function
        self.mock_runnable = MagicMock(return_value="test_output")
        # Mock graph pool
        self.mock_graph_pool = MagicMock()
        # Mock CUDAGraphOptions
        self.mock_cudagraph_options = MagicMock(spec=CUDAGraphOptions)
        self.mock_cudagraph_options.debug_log_enable = False
        self.mock_cudagraph_options.gc_disable = False
        self.mock_cudagraph_options.weak_ref_output = False
        # Mock BatchDescriptor
        self.mock_batch_descriptor = BatchDescriptor(
            num_tokens=30,
            uniform_decode=False,
        )
        # Mock ForwardContext
        self.mock_forward_context = MagicMock(spec=ForwardContext)
        self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

    @staticmethod
    def _mock_npu_graph(mock_torch):
        """Configure a patched ``torch`` module for graph capture tests.

        ``torch.npu.NPUGraph()`` returns a fresh mock graph.
        ``torch.npu.graph(...)`` returns a context manager that does not
        suppress exceptions. ``torch.Tensor`` stays the real class, so
        ``isinstance`` checks in the code under test keep working.

        Args:
            mock_torch: The MagicMock patched in place of ``torch``.

        Returns:
            The mock graph object returned by ``torch.npu.NPUGraph()``.
        """
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        # A falsy __exit__ result lets exceptions raised inside propagate.
        mock_graph_context.__exit__ = Mock(return_value=None)
        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor
        return mock_npu_graph

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_with_default_options(self, mock_envs,
                                                 mock_current_platform):
        """Test ACLGraphWrapper initialization with default CUDAGraphOptions"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        wrapper = ACLGraphWrapper(runnable=self.mock_runnable,
                                  vllm_config=self.mock_vllm_config,
                                  runtime_mode=CUDAGraphMode.FULL,
                                  graph_pool=self.mock_graph_pool)
        self.assertEqual(wrapper.runnable, self.mock_runnable)
        self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
        self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
        self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
        self.assertFalse(wrapper.is_debugging_mode)
        self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions)
        self.assertEqual(wrapper.concrete_aclgraph_entries, {})

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_with_custom_options(self, mock_envs,
                                                mock_current_platform):
        """Test ACLGraphWrapper initialization with custom CUDAGraphOptions"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        self.assertEqual(wrapper.runnable, self.mock_runnable)
        self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
        self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
        self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
        self.assertTrue(wrapper.is_debugging_mode)
        self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options)
        self.assertEqual(wrapper.concrete_aclgraph_entries, {})

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_assertion_error(self, mock_envs,
                                            mock_current_platform):
        """Test ACLGraphWrapper initialization raises AssertionError for NONE mode"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        with self.assertRaises(AssertionError):
            ACLGraphWrapper(runnable=self.mock_runnable,
                            vllm_config=self.mock_vllm_config,
                            runtime_mode=CUDAGraphMode.NONE,
                            graph_pool=self.mock_graph_pool)

    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_none_runtime_mode(self, mock_envs,
                                         mock_current_platform,
                                         mock_get_forward_context):
        """Test __call__ method when runtime mode is NONE"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.NONE
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        result = wrapper("arg1", "arg2")
        # Should call the runnable directly without graph capture
        self.mock_runnable.assert_called_once_with("arg1", "arg2")
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_mismatched_runtime_mode(self, mock_envs,
                                               mock_current_platform,
                                               mock_get_forward_context):
        """Test __call__ method when runtime mode doesn't match wrapper mode"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE  # Different from FULL
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        result = wrapper("arg1", "arg2")
        # Should call the runnable directly without graph capture
        self.mock_runnable.assert_called_once_with("arg1", "arg2")
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_first_time(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
            mock_current_platform, mock_get_forward_context,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph for the first time"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_npu_graph = self._mock_npu_graph(mock_torch)
        # Mock weak_ref_tensors to return the same output
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # Call the wrapper
        result = wrapper(test_tensor, "arg2")
        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)
        self.mock_runnable.assert_called_once_with(test_tensor, "arg2")
        # Verify the entry was created and updated
        self.assertIn(self.mock_batch_descriptor,
                      wrapper.concrete_aclgraph_entries)
        entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor]
        self.assertEqual(entry.aclgraph, mock_npu_graph)
        self.assertEqual(entry.output, "weak_ref_output")
        # Verify compilation counter was incremented
        self.assertEqual(mock_compilation_counter.num_cudagraph_captured, 1)
        # Should return the original output (not weak ref)
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_replay_graph(self, mock_weak_ref_tensors,
                               mock_compilation_counter, mock_envs,
                               mock_current_platform, mock_get_forward_context,
                               mock_validate_cudagraph_capturing_enabled,
                               mock_torch):
        """Test __call__ method replays graph when already captured"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_npu_graph = self._mock_npu_graph(mock_torch)
        # Mock weak_ref_tensors to return the same output
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # First call to capture the graph
        first_result = wrapper(test_tensor, "arg2")
        # Verify graph capture happened during first call
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once()
        # Reset mock to track second call
        self.mock_runnable.reset_mock()
        mock_npu_graph.reset_mock()
        # Second call should replay the graph
        second_result = wrapper(test_tensor, "arg2")
        # Verify runnable was called only during capture (not during replay)
        self.mock_runnable.assert_not_called()
        # Verify graph replay happened
        mock_npu_graph.replay.assert_called_once()
        # The capture call returns the original output; replay returns the
        # weak ref output stored on the entry.
        self.assertEqual(first_result, "test_output")  # Original output
        self.assertEqual(second_result, "weak_ref_output")  # Weak ref output

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_check(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
            mock_get_forward_context,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address checking"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_npu_graph = self._mock_npu_graph(mock_torch)
        # Mock weak_ref_tensors
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Use a real tensor as the output of runnable
        mock_output_tensor = torch.tensor([4, 5, 6])
        self.mock_runnable.return_value = mock_output_tensor
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # First call to capture the graph
        tensor = torch.tensor([1, 2, 3])  # Create tensor once
        _ = wrapper(tensor, "arg2")
        # Clear call records so only the second call's replay is counted.
        mock_npu_graph.reset_mock()
        # Second call with the same tensor (same addresses) must pass the
        # debug-mode address check without raising AssertionError.
        _ = wrapper(tensor, "arg2")
        # The runnable ran only during capture; the second call replayed.
        self.mock_runnable.assert_called_once_with(tensor, "arg2")
        mock_npu_graph.replay.assert_called_once()

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_mismatch(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
            mock_get_forward_context,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address mismatch raises AssertionError"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        self._mock_npu_graph(mock_torch)
        # Mock weak_ref_tensors
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Use a real tensor as the output of runnable
        mock_output_tensor = torch.tensor([4, 5, 6])
        self.mock_runnable.return_value = mock_output_tensor
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # First call to capture the graph
        tensor1 = torch.tensor([1, 2, 3])
        _ = wrapper(tensor1, "arg2")
        # Second call with different tensor addresses should raise AssertionError
        tensor2 = torch.tensor([4, 5,
                                6])  # Different values, different address
        with self.assertRaises(AssertionError) as context:
            wrapper(tensor2, "arg2")
        self.assertIn("Input addresses for aclgraphs are different",
                      str(context.exception))

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    @patch('vllm_ascend.compilation.acl_graph.patch')
    def test_call_capture_graph_with_gc_disable(
            self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
            mock_envs, mock_current_platform, mock_get_forward_context,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with gc_disable option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Enable gc_disable option
        self.mock_cudagraph_options.gc_disable = True
        # weak_ref_output is not enabled by default
        mock_npu_graph = self._mock_npu_graph(mock_torch)
        # Mock patch context manager
        mock_exit_stack = MagicMock()
        mock_patch.return_value = mock_exit_stack
        mock_exit_stack.enter_context = Mock()
        # Mock weak_ref_tensors to simulate the actual behavior:
        # 1. First call (inside the graph context) should return "inner_output"
        # 2. Second call (for entry.output) should return "weak_ref_output"
        mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"]
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # Call the wrapper
        result = wrapper(test_tensor, "arg2")
        # Verify patch was called to disable gc
        self.assertTrue(mock_patch.called)
        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)
        # Should return the original output (not weak ref) since weak_ref_output is not enabled
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_with_weak_ref_output(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
            mock_current_platform, mock_get_forward_context,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with weak_ref_output option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Enable weak_ref_output option
        self.mock_cudagraph_options.weak_ref_output = True
        mock_npu_graph = self._mock_npu_graph(mock_torch)
        # Mock weak_ref_tensors to simulate the actual behavior:
        # 1. First call (inside the graph context with weak_ref_output=True) should return "weak_ref_output"
        # 2. Second call (for entry.output) should return "weak_ref_output"
        mock_weak_ref_tensors.side_effect = [
            "weak_ref_output", "weak_ref_output"
        ]
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # Call the wrapper
        result = wrapper(test_tensor, "arg2")
        # Verify weak_ref_tensors was called twice (once for inner output, once for final output)
        self.assertEqual(mock_weak_ref_tensors.call_count, 2)
        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)
        # Should return the weak ref output when weak_ref_output option is enabled
        self.assertEqual(result, "weak_ref_output")

    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.logger')
    def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
                                               mock_current_platform,
                                               mock_get_forward_context):
        """Test __call__ method captures graph with debug logging enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Enable debug logging
        self.mock_cudagraph_options.debug_log_enable = True
        # weak_ref_output is not enabled by default
        with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch, \
                patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors'
                      ) as mock_weak_ref_tensors, \
                patch('vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'):
            self._mock_npu_graph(mock_torch)
            # Mock weak_ref_tensors to simulate the actual behavior:
            # 1. First call (inside the graph context) should return "inner_output"
            # 2. Second call (for entry.output) should return "weak_ref_output"
            mock_weak_ref_tensors.side_effect = [
                "inner_output", "weak_ref_output"
            ]
            wrapper = ACLGraphWrapper(
                runnable=self.mock_runnable,
                vllm_config=self.mock_vllm_config,
                runtime_mode=CUDAGraphMode.FULL,
                graph_pool=self.mock_graph_pool,
                cudagraph_options=self.mock_cudagraph_options)
            # Create a real torch tensor for the test, not a mock
            test_tensor = torch.tensor([1, 2, 3])
            # Call the wrapper
            _ = wrapper(test_tensor, "arg2")
            # Verify debug log was called
            mock_logger.debug.assert_called_once()

    def test_getattr_access_runnable_attributes(self):
        """Test __getattr__ method accesses runnable attributes"""
        mock_runnable = MagicMock()
        mock_runnable.test_attr = "test_value"
        wrapper = ACLGraphWrapper(
            runnable=mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # Should be able to access attributes of the runnable
        self.assertEqual(wrapper.test_attr, "test_value")

    def test_getattr_attribute_not_exists(self):
        """Test __getattr__ method raises AttributeError for non-existent attributes"""

        # Create a simple object without any attributes
        class EmptyRunnable:
            pass

        mock_runnable = EmptyRunnable()
        wrapper = ACLGraphWrapper(
            runnable=mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        # Should raise AttributeError for non-existent attributes
        with self.assertRaises(AttributeError) as context:
            _ = wrapper.non_existent_attr
        self.assertIn("Attribute non_existent_attr not exists",
                      str(context.exception))

    def test_unwrap_method(self):
        """Test unwrap method returns the original runnable"""
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            graph_pool=self.mock_graph_pool,
            cudagraph_options=self.mock_cudagraph_options)
        unwrapped = wrapper.unwrap()
        self.assertEqual(unwrapped, self.mock_runnable)
class TestMTPGraphParams(TestBase):
    """Tests for the module-level MTP graph-parameter helpers in acl_graph."""

    def test_set_mtp_graph_params(self):
        """Setting params from an unset state leaves a non-None global."""
        # Start from an unset global; patch restores the original on exit.
        with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params',
                   new=None):
            set_mtp_graph_params([4])
            # The getter reads the module global, so it sees the new value.
            self.assertIsNotNone(get_mtp_graph_params())

    def test_update_mtp_graph_params_workspaces(self):
        """Updating a workspace replaces the value stored for that key."""
        with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params'
                   ) as params_stub:
            params_stub.workspaces = {4: 5}
            update_mtp_graph_params_workspaces(4, 6)
            self.assertEqual(params_stub.workspaces[4], 6)

    def test_get_mtp_graph_params(self):
        """The getter hands back the module-level params object itself."""
        with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params'
                   ) as params_stub:
            self.assertIs(params_stub, get_mtp_graph_params())