#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from unittest.mock import MagicMock, Mock, patch

import numpy as np
import torch
from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, ForwardContext

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
                                                AscendMetadataForDecode)
from vllm_ascend.attention.context_parallel.attention_cp import \
    AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
                                          AscendMLAMetadata)
from vllm_ascend.compilation.acl_graph import (
    ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
    set_draft_graph_params, set_graph_params,
    update_draft_graph_params_workspaces)


class TestACLGraphEntry(TestBase):

    def test_aclgraph_entry_initialization(self):
        """Test ACLGraphEntry initialization with default values"""
        batch_descriptor = BatchDescriptor(
            num_tokens=30,
            uniform=False,
        )

        entry = ACLGraphEntry(batch_descriptor=batch_descriptor)

        self.assertEqual(entry.batch_descriptor, batch_descriptor)
        self.assertIsNone(entry.aclgraph)
        self.assertIsNone(entry.output)
        self.assertIsNone(entry.input_addresses)

    def test_aclgraph_entry_with_values(self):
        """Test ACLGraphEntry initialization with specified values"""
        batch_descriptor = BatchDescriptor(
            num_tokens=30,
            uniform=False,
        )

        mock_graph = MagicMock()
        mock_output = MagicMock()
        input_addresses = [12345, 67890]

        entry = ACLGraphEntry(batch_descriptor=batch_descriptor,
                              aclgraph=mock_graph,
                              output=mock_output,
                              input_addresses=input_addresses)

        self.assertEqual(entry.batch_descriptor, batch_descriptor)
        self.assertEqual(entry.aclgraph, mock_graph)
        self.assertEqual(entry.output, mock_output)
        self.assertEqual(entry.input_addresses, input_addresses)


class TestACLGraphWrapper(TestBase):

    def setUp(self):
        """Set up test fixtures"""
        super().setUp()

        # Mock VllmConfig
        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.compilation_config = MagicMock()

        # Mock runnable function
        self.mock_runnable = MagicMock(return_value="test_output")

        # Mock graph pool
        self.mock_graph_pool = MagicMock()

        # Mock CUDAGraphOptions
        self.mock_cudagraph_options = MagicMock(spec=CUDAGraphOptions)
        self.mock_cudagraph_options.debug_log_enable = False
        self.mock_cudagraph_options.gc_disable = False
        self.mock_cudagraph_options.weak_ref_output = False

        # Mock BatchDescriptor
        self.mock_batch_descriptor = BatchDescriptor(
            num_tokens=30,
            uniform=False,
        )

        # Mock ForwardContext
        self.mock_forward_context = MagicMock(spec=ForwardContext)
        self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_with_default_options(self, mock_envs,
                                                 mock_current_platform):
        """Test ACLGraphWrapper initialization with default CUDAGraphOptions"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool

        wrapper = ACLGraphWrapper(runnable=self.mock_runnable,
                                  vllm_config=self.mock_vllm_config,
                                  runtime_mode=CUDAGraphMode.FULL)

        self.assertEqual(wrapper.runnable, self.mock_runnable)
        self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
        self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
        self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
        self.assertFalse(wrapper.is_debugging_mode)
        self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions)
        self.assertEqual(wrapper.concrete_aclgraph_entries, {})

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_with_custom_options(self, mock_envs,
                                                mock_current_platform):
        """Test ACLGraphWrapper initialization with custom CUDAGraphOptions"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        self.assertEqual(wrapper.runnable, self.mock_runnable)
        self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
        self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
        self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
        self.assertTrue(wrapper.is_debugging_mode)
        self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options)
        self.assertEqual(wrapper.concrete_aclgraph_entries, {})

    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_initialization_assertion_error(self, mock_envs,
                                            mock_current_platform):
        """Test ACLGraphWrapper initialization raises AssertionError for NONE mode"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool

        with self.assertRaises(AssertionError):
            ACLGraphWrapper(runnable=self.mock_runnable,
                            vllm_config=self.mock_vllm_config,
                            runtime_mode=CUDAGraphMode.NONE)

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_none_runtime_mode(self, mock_envs,
                                         mock_current_platform,
                                         mock_get_forward_context,
                                         mock_get_forward_context_2):
        """Test __call__ method when runtime mode is NONE"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.NONE

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        result = wrapper("arg1", "arg2")

        # Should call the runnable directly without graph capture
        self.mock_runnable.assert_called_once_with("arg1", "arg2")
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_mismatched_runtime_mode(self, mock_envs,
                                               mock_current_platform,
                                               mock_get_forward_context,
                                               mock_get_forward_context_2):
        """Test __call__ method when runtime mode doesn't match wrapper mode"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE  # Different from FULL

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        result = wrapper("arg1", "arg2")

        # Should call the runnable directly without graph capture
        self.mock_runnable.assert_called_once_with("arg1", "arg2")
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_first_time(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
            mock_current_platform, mock_get_forward_context,
            mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph for the first time"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph

        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)

        # Mock weak_ref_tensors to return the same output
        mock_weak_ref_tensors.return_value = "weak_ref_output"

        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor

        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])

        # Call the wrapper
        result = wrapper(test_tensor, "arg2")

        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)
        self.mock_runnable.assert_called_once_with(test_tensor, "arg2")

        # Verify the entry was created and updated
        self.assertIn(self.mock_batch_descriptor,
                      wrapper.concrete_aclgraph_entries)
        entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor]
        self.assertEqual(entry.aclgraph, mock_npu_graph)
        self.assertEqual(entry.output, "weak_ref_output")

        # Verify compilation counter was incremented
        self.assertEqual(mock_compilation_counter.num_cudagraph_captured, 1)

        # Should return the original output (not weak ref)
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_replay_graph(self, mock_weak_ref_tensors,
                               mock_compilation_counter, mock_envs,
                               mock_current_platform, mock_get_forward_context,
                               mock_get_forward_context_2,
                               mock_validate_cudagraph_capturing_enabled,
                               mock_torch):
        """Test __call__ method replays graph when already captured"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        self.mock_forward_context.is_draft_model = False

        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph

        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)

        # Mock weak_ref_tensors to return the same output
        mock_weak_ref_tensors.return_value = "weak_ref_output"

        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor

        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])

        # First call to capture the graph
        first_result = wrapper(test_tensor, "arg2")

        # Verify graph capture happened during first call
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once()

        # Reset mock to track second call
        self.mock_runnable.reset_mock()
        mock_npu_graph.reset_mock()

        # Second call should replay the graph
        second_result = wrapper(test_tensor, "arg2")

        # Verify runnable was called only during capture (not during replay)
        self.mock_runnable.assert_not_called()

        # Verify graph replay happened
        mock_npu_graph.replay.assert_called_once()

        # The capture call returns the original output; the replay returns
        # the weak ref output
        self.assertEqual(first_result, "test_output")  # Original output
        self.assertEqual(second_result, "weak_ref_output")  # Weak ref output

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_check(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
            mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address checking"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        self.mock_forward_context.is_draft_model = False

        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph

        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)

        # Mock weak_ref_tensors
        mock_weak_ref_tensors.return_value = "weak_ref_output"

        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor

        # Create a mock tensor as the output of runnable
        mock_output_tensor = torch.tensor([4, 5, 6])
        self.mock_runnable.return_value = mock_output_tensor

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # First call to capture the graph
        tensor = torch.tensor([1, 2, 3])  # Create tensor once
        _ = wrapper(tensor, "arg2")

        # Second call with same tensor addresses should work
        _ = wrapper(tensor, "arg2")  # Use the same tensor object

        # Should not raise AssertionError
        self.assertTrue(True)

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_mismatch(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
            mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address mismatch raises AssertionError"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph

        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)

        # Mock weak_ref_tensors
        mock_weak_ref_tensors.return_value = "weak_ref_output"

        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor

        # Create a mock tensor as the output of runnable
        mock_output_tensor = torch.tensor([4, 5, 6])
        self.mock_runnable.return_value = mock_output_tensor

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # First call to capture the graph
        tensor1 = torch.tensor([1, 2, 3])
        _ = wrapper(tensor1, "arg2")

        # Second call with different tensor addresses should raise AssertionError
        tensor2 = torch.tensor([4, 5,
                                6])  # Different values, different address

        with self.assertRaises(AssertionError) as context:
            wrapper(tensor2, "arg2")

        self.assertIn("Input addresses for aclgraphs are different",
                      str(context.exception))

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    @patch('vllm_ascend.compilation.acl_graph.patch')
    def test_call_capture_graph_with_gc_disable(
            self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
            mock_envs, mock_current_platform, mock_get_forward_context,
            mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with gc_disable option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Enable gc_disable option
        self.mock_cudagraph_options.gc_disable = True
        # weak_ref_output is not enabled by default

        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph

        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)

        # Mock patch context manager
        mock_exit_stack = MagicMock()
        mock_patch.return_value = mock_exit_stack
        mock_exit_stack.enter_context = Mock()

        # Mock weak_ref_tensors to simulate the actual behavior:
        # 1. First call (inside the graph context) should return "inner_output"
        # 2. Second call (for entry.output) should return "weak_ref_output"
        mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"]

        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor

        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])

        # Call the wrapper
        result = wrapper(test_tensor, "arg2")

        # Verify patch was called to disable gc
        self.assertTrue(mock_patch.called)

        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)

        # Should return the original output (not weak ref) since weak_ref_output is not enabled
        self.assertEqual(result, "test_output")

    @patch('vllm_ascend.compilation.acl_graph.torch')
    @patch(
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_with_weak_ref_output(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
            mock_current_platform, mock_get_forward_context,
            mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with weak_ref_output option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Enable weak_ref_output option
        self.mock_cudagraph_options.weak_ref_output = True

        # Mock torch.npu.NPUGraph
        mock_npu_graph = MagicMock()
        mock_torch.npu.NPUGraph.return_value = mock_npu_graph

        # Mock torch.npu.graph context manager
        mock_graph_context = MagicMock()
        mock_torch.npu.graph.return_value = mock_graph_context
        mock_graph_context.__enter__ = Mock(return_value=None)
        mock_graph_context.__exit__ = Mock(return_value=None)

        # Mock weak_ref_tensors to simulate the actual behavior:
        # 1. First call (inside the graph context with weak_ref_output=True) should return "weak_ref_output"
        # 2. Second call (for entry.output) should return "weak_ref_output"
        mock_weak_ref_tensors.side_effect = [
            "weak_ref_output", "weak_ref_output"
        ]

        # Ensure torch.Tensor can be correctly identified by isinstance
        mock_torch.Tensor = torch.Tensor

        # Set up the compilation counter mock
        mock_compilation_counter.num_cudagraph_captured = 0

        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # Create a real torch tensor for the test, not a mock
        test_tensor = torch.tensor([1, 2, 3])

        # Call the wrapper
        result = wrapper(test_tensor, "arg2")

        # Verify weak_ref_tensors was called twice (once for inner output, once for final output)
        self.assertEqual(mock_weak_ref_tensors.call_count, 2)

        # Verify graph capture happened
        mock_validate_cudagraph_capturing_enabled.assert_called_once()
        mock_torch.npu.NPUGraph.assert_called_once()
        mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
                                                     pool=self.mock_graph_pool)

        # Should return the weak ref output when weak_ref_output option is enabled
        self.assertEqual(result, "weak_ref_output")

    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.logger')
    def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
                                               mock_current_platform,
                                               mock_get_forward_context,
                                               mock_get_forward_context_2):
        """Test __call__ method captures graph with debug logging enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Enable debug logging
        self.mock_cudagraph_options.debug_log_enable = True
        # weak_ref_output is not enabled by default

        # Mock torch
        with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch:
            # Mock torch.npu.NPUGraph
            mock_npu_graph = MagicMock()
            mock_torch.npu.NPUGraph.return_value = mock_npu_graph

            # Mock torch.npu.graph context manager
            mock_graph_context = MagicMock()
            mock_torch.npu.graph.return_value = mock_graph_context
            mock_graph_context.__enter__ = Mock(return_value=None)
            mock_graph_context.__exit__ = Mock(return_value=None)

            # Ensure torch.Tensor can be correctly identified by isinstance
            mock_torch.Tensor = torch.Tensor

            # Mock weak_ref_tensors
            with patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors'
                       ) as mock_weak_ref_tensors:
                # Mock weak_ref_tensors to simulate the actual behavior:
                # 1. First call (inside the graph context) should return "inner_output"
                # 2. Second call (for entry.output) should return "weak_ref_output"
                mock_weak_ref_tensors.side_effect = [
                    "inner_output", "weak_ref_output"
                ]

                # Mock validate_cudagraph_capturing_enabled
                with patch(
                        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
                ):
                    wrapper = ACLGraphWrapper(
                        runnable=self.mock_runnable,
                        vllm_config=self.mock_vllm_config,
                        runtime_mode=CUDAGraphMode.FULL,
                        cudagraph_options=self.mock_cudagraph_options)

                    # Create a real torch tensor for the test, not a mock
                    test_tensor = torch.tensor([1, 2, 3])

                    # Call the wrapper
                    _ = wrapper(test_tensor, "arg2")

                    # Verify debug log was called
                    mock_logger.debug.assert_called_once()

    def test_getattr_access_runnable_attributes(self):
        """Test __getattr__ method accesses runnable attributes"""
        mock_runnable = MagicMock()
        mock_runnable.test_attr = "test_value"

        wrapper = ACLGraphWrapper(
            runnable=mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # Should be able to access attributes of the runnable
        self.assertEqual(wrapper.test_attr, "test_value")

    def test_getattr_attribute_not_exists(self):
        """Test __getattr__ method raises AttributeError for non-existent attributes"""

        # Create a simple object without any attributes
        class EmptyRunnable:
            pass

        mock_runnable = EmptyRunnable()

        wrapper = ACLGraphWrapper(
            runnable=mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        # Should raise AttributeError for non-existent attributes
        with self.assertRaises(AttributeError) as context:
            _ = wrapper.non_existent_attr

        self.assertIn("Attribute non_existent_attr not found",
                      str(context.exception))

    def test_unwrap_method(self):
        """Test unwrap method returns the original runnable"""
        wrapper = ACLGraphWrapper(
            runnable=self.mock_runnable,
            vllm_config=self.mock_vllm_config,
            runtime_mode=CUDAGraphMode.FULL,
            cudagraph_options=self.mock_cudagraph_options)

        unwrapped = wrapper.unwrap()
        self.assertEqual(unwrapped, self.mock_runnable)


class TestDraftGraphParams(TestBase):

    def test_set_draft_graph_params(self):
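        """Test set_draft_graph_params initializes the module-level draft graph params"""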
        with patch('vllm_ascend.compilation.acl_graph._draft_graph_params',
                   new=None):
            set_draft_graph_params([4])
            from vllm_ascend.compilation.acl_graph import _draft_graph_params
            self.assertIsNotNone(_draft_graph_params)

    @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
    def test_update_draft_graph_params_workspaces(self,
                                                  draft_graph_params_mock):
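        """Test update_draft_graph_params_workspaces replaces the workspace for a runtime size"""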
        draft_graph_params_mock.workspaces = {4: 5}
        update_draft_graph_params_workspaces(4, 6)
        self.assertEqual(draft_graph_params_mock.workspaces[4], 6)

    @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
    def test_get_draft_graph_params(self, draft_graph_params_mock):
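        """Test get_draft_graph_params returns the module-level draft graph params"""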
        graph_params = get_draft_graph_params()
        self.assertIs(draft_graph_params_mock, graph_params)


class TestPCPDCPGraphParams(TestBase):

    def setUp(self):
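        """Set up shared graph params with a fake update stream and one captured event/handle"""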
        self.update_stream = MagicMock(name="FakeStream")
        graph_params = get_graph_params()
        if graph_params is None:
            set_graph_params(set([4]))
            self.graph_params = get_graph_params()
        else:
            self.graph_params = graph_params
        mock_event = torch.npu.ExternalEvent()
        mock_event.record = MagicMock()
        self.graph_params.events[4] = []
        self.graph_params.handles[4] = []
        self.graph_params.events[4].append(mock_event)
        self.graph_params.handles[4].append(MagicMock())

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch.npu.graph_task_update_end')
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end,
                                       mock_context):
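        """Test AscendMlaCPImpl.update_graph_params with captured MLA decode metadata"""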
        input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        block_table = torch.zeros(2, 5, dtype=torch.long)
        seq_lens = torch.tensor([4, 4])
        cp_seq_len = torch.tensor([2, 2])
        max_seq_lens = 4
        seq_lens_list = [4, 4]
        slot_mapping = torch.zeros(8, dtype=torch.long)
        query_start_loc = torch.tensor([0, 4])
        block_tables = torch.zeros(2, 5, dtype=torch.long)

        decode = AscendMLADecodeMetadata(input_positions,
                                         block_table,
                                         seq_lens,
                                         max_seq_lens,
                                         seq_lens_list,
                                         cp_seq_len=cp_seq_len)
        metadata = AscendMLAMetadata(8,
                                     8,
                                     slot_mapping,
                                     query_start_loc,
                                     seq_lens,
                                     seq_lens,
                                     block_tables,
                                     4,
                                     4,
                                     0,
                                     decode=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_draft_model = False
        mock_context.return_value = forward_context

        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_rope_head_dim = 32
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)

        q_nope = query[..., :qk_nope_head_dim]
        q_pe = query[..., qk_rope_head_dim:]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
        input_layout = "BNSD"
        actual_seq_lengths_kv = [1, 1]
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)
        self.graph_params.attn_params[4] = []
        self.graph_params.attn_params[4].append(
            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
             None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
             out, lse))

        with patch("torch_npu._C._npu_setStream", return_value=None):
            AscendMlaCPImpl.update_graph_params(
                self.update_stream, forward_context, 4
            )

        _mock_graph_task_end.assert_called_once()

    @patch('torch.npu.graph_task_update_end')
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
    def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end):
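        """Test AscendAttentionCPImpl.update_graph_params with captured attention decode metadata"""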
        block_table = torch.zeros(2, 5, dtype=torch.long)
        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)
        q_nope = query[..., :qk_nope_head_dim]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        actual_seq_lengths_kv = [1, 1]
        actual_seq_lengths_q = np.array([1, 1])
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)

        num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]],
                                                   [[1, 1], [1, 1]]])
        decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp)
        metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1],
                                  actual_seq_lengths_q=actual_seq_lengths_q,
                                  num_decode_tokens=1,
                                  decode_meta=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_draft_model = False

        self.graph_params.attn_params[4] = []
        self.graph_params.attn_params[4].append(
            (q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale,
             block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
             out, lse, 2, 0, 0))

        with patch("torch_npu._C._npu_setStream", return_value=None):
            AscendAttentionCPImpl.update_graph_params(
                self.update_stream, forward_context, 4, None
            )

        _mock_graph_task_end.assert_called_once()