#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
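"""Unit tests for vllm_ascend.compilation.acl_graph.

Covers ACLGraphEntry, the ACLGraphWrapper capture/replay path, the
draft-model (MTP) graph-parameter helpers, and graph-parameter updates for
the context-parallel attention implementations.
"""
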
from unittest.mock import MagicMock, Mock, patch

import numpy as np
import torch
from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, ForwardContext

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
                                                AscendMetadataForDecode)
from vllm_ascend.attention.context_parallel.attention_cp import \
    AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
                                          AscendMLAMetadata)
from vllm_ascend.compilation.acl_graph import (
    ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
    set_draft_graph_params, set_graph_params,
    update_draft_graph_params_workspaces)


class TestACLGraphEntry(TestBase):

    def test_aclgraph_entry_initialization(self):
        """Test ACLGraphEntry initialization with default values"""
        batch_descriptor = BatchDescriptor(
            num_tokens=30,
            uniform=False,
        )
        entry = ACLGraphEntry(batch_descriptor=batch_descriptor)
        self.assertEqual(entry.batch_descriptor, batch_descriptor)
        self.assertIsNone(entry.aclgraph)
        self.assertIsNone(entry.output)
        self.assertIsNone(entry.input_addresses)

    def test_aclgraph_entry_with_values(self):
        """Test ACLGraphEntry initialization with specified values"""
        batch_descriptor = BatchDescriptor(
            num_tokens=30,
            uniform=False,
        )
        mock_graph = MagicMock()
        mock_output = MagicMock()
        input_addresses = [12345, 67890]
        entry = ACLGraphEntry(batch_descriptor=batch_descriptor,
                              aclgraph=mock_graph,
                              output=mock_output,
                              input_addresses=input_addresses)
        self.assertEqual(entry.batch_descriptor, batch_descriptor)
        self.assertEqual(entry.aclgraph, mock_graph)
        self.assertEqual(entry.output, mock_output)
        self.assertEqual(entry.input_addresses, input_addresses)
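

# The tests below exercise ACLGraphWrapper's capture/replay contract: the
# first call with a given BatchDescriptor runs the wrapped callable under
# torch.npu.graph() to capture an NPUGraph, and later calls with the same
# descriptor replay that graph instead of invoking the callable again.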
class TestACLGraphWrapper(TestBase):

    def setUp(self):
        """Set up test fixtures"""
        super().setUp()
        # Mock VllmConfig
        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.compilation_config = MagicMock()
        # Mock runnable function
        self.mock_runnable = MagicMock(return_value="test_output")
        # Mock graph pool
        self.mock_graph_pool = MagicMock()
        # Mock CUDAGraphOptions
        self.mock_cudagraph_options = MagicMock(spec=CUDAGraphOptions)
        self.mock_cudagraph_options.debug_log_enable = False
        self.mock_cudagraph_options.gc_disable = False
        self.mock_cudagraph_options.weak_ref_output = False
        # Mock BatchDescriptor
        self.mock_batch_descriptor = BatchDescriptor(
            num_tokens=30,
            uniform=False,
        )
        # Mock ForwardContext
        self.mock_forward_context = MagicMock(spec=ForwardContext)
        self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_with_default_options(self, mock_envs,
                                                 mock_current_platform):
        """Test ACLGraphWrapper initialization with default CUDAGraphOptions"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        wrapper = ACLGraphWrapper(runnable=self.mock_runnable,
                                  vllm_config=self.mock_vllm_config,
                                  runtime_mode=CUDAGraphMode.FULL)
        self.assertEqual(wrapper.runnable, self.mock_runnable)
        self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
        self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
        self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
        self.assertFalse(wrapper.is_debugging_mode)
        self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions)
        self.assertEqual(wrapper.concrete_aclgraph_entries, {})

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_with_custom_options(self, mock_envs,
                                                mock_current_platform):
        """Test ACLGraphWrapper initialization with custom CUDAGraphOptions"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        self.assertEqual(wrapper.runnable, self.mock_runnable)
        self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
        self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
        self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
        self.assertTrue(wrapper.is_debugging_mode)
        self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options)
        self.assertEqual(wrapper.concrete_aclgraph_entries, {})

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_assertion_error(self, mock_envs,
                                            mock_current_platform):
        """Test ACLGraphWrapper initialization raises AssertionError for NONE mode"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        with self.assertRaises(AssertionError):
            ACLGraphWrapper(runnable=self.mock_runnable,
                            vllm_config=self.mock_vllm_config,
                            runtime_mode=CUDAGraphMode.NONE)

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_none_runtime_mode(self, mock_envs,
                                         mock_current_platform,
                                         mock_get_forward_context,
                                         mock_get_forward_context_2):
        """Test __call__ method when runtime mode is NONE"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.NONE
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        result = wrapper("arg1", "arg2")
        # Should call the runnable directly without graph capture
        self.mock_runnable.assert_called_once_with("arg1", "arg2")
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_mismatched_runtime_mode(self, mock_envs,
                                               mock_current_platform,
                                               mock_get_forward_context,
                                               mock_get_forward_context_2):
        """Test __call__ method when runtime mode doesn't match wrapper mode"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE  # Different from FULL
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        result = wrapper("arg1", "arg2")
        # Should call the runnable directly without graph capture
        self.mock_runnable.assert_called_once_with("arg1", "arg2")
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_first_time(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
            mock_current_platform, mock_get_forward_context,
            mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph for the first time"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph
        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)
        # Mock weak_ref_tensors to return the same output
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # Call the wrapper
        result = wrapper(test_tensor, "arg2")
        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)
        self.mock_runnable.assert_called_once_with(test_tensor, "arg2")
        # Verify the entry was created and updated
        self.assertIn(self.mock_batch_descriptor,
                      wrapper.concrete_aclgraph_entries)
        entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor]
        self.assertEqual(entry.aclgraph, mock_npu_graph)
        self.assertEqual(entry.output, "weak_ref_output")
        # Verify compilation counter was incremented
        self.assertEqual(mock_compilation_counter.num_cudagraph_captured, 1)
        # Should return the original output (not weak ref)
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_replay_graph(self, mock_weak_ref_tensors,
                               mock_compilation_counter, mock_envs,
                               mock_current_platform, mock_get_forward_context,
                               mock_get_forward_context_2,
                               mock_validate_cudagraph_capturing_enabled,
                               mock_torch):
        """Test __call__ method replays graph when already captured"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        self.mock_forward_context.is_draft_model = False
        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph
        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)
        # Mock weak_ref_tensors to return the same output
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # First call to capture the graph
        first_result = wrapper(test_tensor, "arg2")
        # Verify graph capture happened during first call
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once()
        # Reset mock to track second call
        self.mock_runnable.reset_mock()
        mock_npu_graph.reset_mock()
        # Second call should replay the graph
        second_result = wrapper(test_tensor, "arg2")
        # Verify runnable was called only during capture (not during replay)
        self.mock_runnable.assert_not_called()
        # Verify graph replay happened
        mock_npu_graph.replay.assert_called_once()
        # The capture call returns the original output; the replay returns
        # the weak ref recorded in the entry
        self.assertEqual(first_result, "test_output")  # Original output
        self.assertEqual(second_result, "weak_ref_output")  # Weak ref output
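
    # In debug mode (VLLM_LOGGING_LEVEL == "DEBUG") the wrapper records the
    # addresses of the tensor inputs at capture time and checks that replays
    # see tensors at the same device addresses, since a captured graph reads
    # its inputs from fixed memory.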
    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_check(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
            mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address checking"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        self.mock_forward_context.is_draft_model = False
        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph
        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)
        # Mock weak_ref_tensors
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor
        # Use a real tensor as the output of the runnable
        mock_output_tensor = torch.tensor([4, 5, 6])
        self.mock_runnable.return_value = mock_output_tensor
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # First call to capture the graph
        tensor = torch.tensor([1, 2, 3])  # Create tensor once
        _ = wrapper(tensor, "arg2")
        # Second call reuses the same tensor object, so the recorded input
        # addresses match; completing without an AssertionError is the
        # success condition
        _ = wrapper(tensor, "arg2")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_mismatch(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
            mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address mismatch raises AssertionError"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph
        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)
        # Mock weak_ref_tensors
        mock_weak_ref_tensors.return_value = "weak_ref_output"
        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor
        # Use a real tensor as the output of the runnable
        mock_output_tensor = torch.tensor([4, 5, 6])
        self.mock_runnable.return_value = mock_output_tensor
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # First call to capture the graph
        tensor1 = torch.tensor([1, 2, 3])
        _ = wrapper(tensor1, "arg2")
        # Second call with a different tensor object (different address)
        # should raise AssertionError
        tensor2 = torch.tensor([4, 5, 6])
        with self.assertRaises(AssertionError) as context:
            wrapper(tensor2, "arg2")
        self.assertIn("Input addresses for aclgraphs are different",
                      str(context.exception))
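
    # gc_disable presumably suppresses Python garbage collection during
    # capture; the test below only checks that the wrapper enters a patch()
    # context when the option is set.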
    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    @patch('vllm_ascend.compilation.acl_graph.patch')
    def test_call_capture_graph_with_gc_disable(
            self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
            mock_envs, mock_current_platform, mock_get_forward_context,
            mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with gc_disable option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Enable gc_disable option
        self.mock_cudagraph_options.gc_disable = True
        # weak_ref_output is not enabled by default
        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph
        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)
        # Mock patch context manager
        mock_exit_stack = MagicMock()
        mock_patch.return_value = mock_exit_stack
        mock_exit_stack.enter_context = Mock()
        # Mock weak_ref_tensors to simulate the actual behavior:
        # 1. First call (inside the graph context) should return "inner_output"
        # 2. Second call (for entry.output) should return "weak_ref_output"
        mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"]
        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # Call the wrapper
        result = wrapper(test_tensor, "arg2")
        # Verify patch was called to disable gc
        self.assertTrue(mock_patch.called)
        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)
        # Should return the original output (not weak ref) since
        # weak_ref_output is not enabled
        self.assertEqual(result, "test_output")
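
    # With weak_ref_output enabled, the wrapper returns a weak reference to
    # the captured output already on the first (capture) call, not just on
    # replays.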
    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_with_weak_ref_output(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
            mock_current_platform, mock_get_forward_context,
            mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with weak_ref_output option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Enable weak_ref_output option
        self.mock_cudagraph_options.weak_ref_output = True
        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph
        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)
        # Mock weak_ref_tensors to simulate the actual behavior:
        # 1. First call (inside the graph context with weak_ref_output=True)
        #    should return "weak_ref_output"
        # 2. Second call (for entry.output) should return "weak_ref_output"
        mock_weak_ref_tensors.side_effect = [
            "weak_ref_output", "weak_ref_output"
        ]
        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor
        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])
        # Call the wrapper
        result = wrapper(test_tensor, "arg2")
        # Verify weak_ref_tensors was called twice (once for the inner
        # output, once for the final output)
        self.assertEqual(mock_weak_ref_tensors.call_count, 2)
        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)
        # Should return the weak ref output when weak_ref_output is enabled
        self.assertEqual(result, "weak_ref_output")

    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.logger')
    def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
                                               mock_current_platform,
                                               mock_get_forward_context,
                                               mock_get_forward_context_2):
        """Test __call__ method captures graph with debug logging enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        # Enable debug logging
        self.mock_cudagraph_options.debug_log_enable = True
        # weak_ref_output is not enabled by default
        # Mock torch
        with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch:
            # Mock torch.npu.NPUGraph
            mock_npu_graph = MagicMock()
            mock_torch.npu.NPUGraph.return_value = mock_npu_graph
            # Mock torch.npu.graph context manager
            mock_graph_context = MagicMock()
            mock_torch.npu.graph.return_value = mock_graph_context
            mock_graph_context.__enter__ = Mock(return_value=None)
            mock_graph_context.__exit__ = Mock(return_value=None)
            # Ensure torch.Tensor can be correctly identified by isinstance
            mock_torch.Tensor = torch.Tensor
            # Mock weak_ref_tensors
            with patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors'
                       ) as mock_weak_ref_tensors:
                # Mock weak_ref_tensors to simulate the actual behavior:
                # 1. First call (inside the graph context) should return
                #    "inner_output"
                # 2. Second call (for entry.output) should return
                #    "weak_ref_output"
                mock_weak_ref_tensors.side_effect = [
                    "inner_output", "weak_ref_output"
                ]
                # Mock validate_cudagraph_capturing_enabled
                with patch(
                        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
                ):
                    wrapper = ACLGraphWrapper(
                        runnable=self.mock_runnable,
                        vllm_config=self.mock_vllm_config,
                        runtime_mode=CUDAGraphMode.FULL,
                        cudagraph_options=self.mock_cudagraph_options)
                    # Create a real torch tensor for the test, not a mock
                    test_tensor = torch.tensor([1, 2, 3])
                    # Call the wrapper
                    _ = wrapper(test_tensor, "arg2")
                    # Verify debug log was called
                    mock_logger.debug.assert_called_once()
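
    # __getattr__ falls through to the wrapped runnable, so the wrapper
    # stays a drop-in replacement for the model it wraps.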
    def test_getattr_access_runnable_attributes(self):
        """Test __getattr__ method accesses runnable attributes"""
        mock_runnable = MagicMock()
        mock_runnable.test_attr = "test_value"
        wrapper = ACLGraphWrapper(
            runnable=mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # Should be able to access attributes of the runnable
        self.assertEqual(wrapper.test_attr, "test_value")

    def test_getattr_attribute_not_exists(self):
        """Test __getattr__ method raises AttributeError for non-existent attributes"""

        # Create a simple object without any attributes
        class EmptyRunnable:
            pass

        mock_runnable = EmptyRunnable()
        wrapper = ACLGraphWrapper(
            runnable=mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)
        # Should raise AttributeError for non-existent attributes
        with self.assertRaises(AttributeError) as context:
            _ = wrapper.non_existent_attr
self.assertIn("Attribute non_existent_attr not found",
str(context.exception))
def test_unwrap_method(self):
"""Test unwrap method returns the original runnable"""
wrapper = ACLGraphWrapper(
runnable=self.mock_runnable,
vllm_config=self.mock_vllm_config,
runtime_mode=CUDAGraphMode.FULL,
cudagraph_options=self.mock_cudagraph_options)
unwrapped = wrapper.unwrap()
self.assertEqual(unwrapped, self.mock_runnable)
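

# The draft (MTP) model keeps its graph parameters in the module-level
# _draft_graph_params singleton, separate from the main model's graph
# params, so speculative-decoding capture state does not clobber the main
# model's (see the set/get/update helpers exercised below).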
class TestDraftGraphParams(TestBase):

    def test_set_draft_graph_params(self):
        with patch('vllm_ascend.compilation.acl_graph._draft_graph_params',
                   new=None):
            set_draft_graph_params([4])
            from vllm_ascend.compilation.acl_graph import _draft_graph_params
            self.assertIsNotNone(_draft_graph_params)

    @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
    def test_update_draft_graph_params_workspaces(self,
                                                  draft_graph_params_mock):
        draft_graph_params_mock.workspaces = {4: 5}
        update_draft_graph_params_workspaces(4, 6)
        self.assertEqual(draft_graph_params_mock.workspaces[4], 6)

    @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
    def test_get_draft_graph_params(self, draft_graph_params_mock):
        graph_params = get_draft_graph_params()
        self.assertIs(draft_graph_params_mock, graph_params)
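

# The tests below exercise update_graph_params() on the context-parallel
# attention impls: graph_params keeps per-shape lists of external events,
# task handles and captured attention kernel arguments (keyed here by
# num_tokens == 4), which are re-bound to fresh metadata via
# graph_task_update_begin/end before a captured graph is replayed.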
class TestPCPDCPGraphParams(TestBase):

    def setUp(self):
        self.update_stream = MagicMock(name="FakeStream")
        graph_params = get_graph_params()
        if graph_params is None:
            set_graph_params(set([4]))
            self.graph_params = get_graph_params()
        else:
            self.graph_params = graph_params
        mock_event = torch.npu.ExternalEvent()
        mock_event.record = MagicMock()
        self.graph_params.events[4] = []
        self.graph_params.handles[4] = []
        self.graph_params.events[4].append(mock_event)
        self.graph_params.handles[4].append(MagicMock())

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch.npu.graph_task_update_end')
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end,
                                       mock_context):
        input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        block_table = torch.zeros(2, 5, dtype=torch.long)
        seq_lens = torch.tensor([4, 4])
        cp_seq_len = torch.tensor([2, 2])
        max_seq_lens = 4
        seq_lens_list = [4, 4]
        slot_mapping = torch.zeros(8, dtype=torch.long)
        query_start_loc = torch.tensor([0, 4])
        block_tables = torch.zeros(2, 5, dtype=torch.long)
        decode = AscendMLADecodeMetadata(input_positions,
                                         block_table,
                                         seq_lens,
                                         max_seq_lens,
                                         seq_lens_list,
                                         cp_seq_len=cp_seq_len)
        metadata = AscendMLAMetadata(8,
                                     8,
                                     slot_mapping,
                                     query_start_loc,
                                     seq_lens,
                                     seq_lens,
                                     block_tables,
                                     4,
                                     4,
                                     0,
                                     decode=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_draft_model = False
        mock_context.return_value = forward_context
        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_rope_head_dim = 32
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)
        q_nope = query[..., :qk_nope_head_dim]
        # The rope part is the trailing qk_rope_head_dim dims of the query
        q_pe = query[..., qk_nope_head_dim:]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
        input_layout = "BNSD"
        actual_seq_lengths_kv = [1, 1]
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)
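        # The tuple below stands in for the positional kernel arguments that
        # were recorded at capture time; its field order is an assumption
        # taken from this test's own construction, not from the kernel API.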
        self.graph_params.attn_params[4] = []
        self.graph_params.attn_params[4].append(
            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads,
             input_layout, None, 0, scale, block_table, 128, None,
             actual_seq_lengths_kv, out, lse))
        with patch("torch_npu._C._npu_setStream", return_value=None):
            AscendMlaCPImpl.update_graph_params(self.update_stream,
                                                forward_context, 4)
        _mock_graph_task_end.assert_called_once()

    @patch('torch.npu.graph_task_update_end')
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
    def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end):
        block_table = torch.zeros(2, 5, dtype=torch.long)
        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)
        q_nope = query[..., :qk_nope_head_dim]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        actual_seq_lengths_kv = [1, 1]
        actual_seq_lengths_q = np.array([1, 1])
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)
        num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]],
                                                   [[1, 1], [1, 1]]])
        decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp)
        metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1],
                                  actual_seq_lengths_q=actual_seq_lengths_q,
                                  num_decode_tokens=1,
                                  decode_meta=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_draft_model = False
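        # As in the MLA test above, the stored tuple mimics the kernel
        # arguments recorded at capture time (field order assumed from this
        # test's construction).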
        self.graph_params.attn_params[4] = []
        self.graph_params.attn_params[4].append(
            (q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale,
             block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
             out, lse, 2, 0, 0))
        with patch("torch_npu._C._npu_setStream", return_value=None):
            AscendAttentionCPImpl.update_graph_params(self.update_stream,
                                                      forward_context, 4, None)
        _mock_graph_task_end.assert_called_once()