import os
from unittest.mock import MagicMock, patch

import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)

from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
                                          AscendMLADecodeMetadata,
                                          AscendMLAImpl, AscendMLAMetadata,
                                          AscendMLAMetadataBuilder,
                                          AscendMLAPrefillMetadata,
                                          ChunkedContextMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


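# Static registration API of the MLA backend: name, builder class, KV-cache
# shape and implementation class.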
class TestAscendMLABackend(TestBase):

    def setUp(self):
        self.mock_config = MagicMock()

        mock_parallel_config = MagicMock()
        mock_parallel_config.prefill_context_parallel_size = 1
        mock_parallel_config.decode_context_parallel_size = 1

        self.mock_config.parallel_config = mock_parallel_config

        self.utils_patcher = patch(
            'vllm_ascend.attention.utils.get_current_vllm_config',
            return_value=self.mock_config)
        self.utils_patcher.start()
        # Stop the patcher at the end of each test so the mocked config does
        # not leak into other test cases.
        self.addCleanup(self.utils_patcher.stop)

        # enable_cp() caches its result, so clear the cache to make sure the
        # patched config above is actually picked up.
        from vllm_ascend.attention.utils import enable_cp
        enable_cp.cache_clear()

    def test_get_name(self):
        self.assertEqual(AscendMLABackend.get_name(), "ASCEND_MLA")

    def test_get_builder_cls(self):
        self.assertEqual(AscendMLABackend.get_builder_cls(),
                         AscendMLAMetadataBuilder)

    def test_get_kv_cache_shape(self):
        result = AscendMLABackend.get_kv_cache_shape(2, 4, 8, 128)
        self.assertEqual(result, (2, 4, 8, 128))

    def test_get_impl_cls(self):
        result = AscendMLABackend.get_impl_cls()
        self.assertEqual(result, AscendMLAImpl)


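# These tests only verify that constructor arguments are stored unchanged on
# AscendMLAPrefillMetadata, with and without chunked-context metadata.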
class TestAscendMLAPrefillMetadata(TestBase):

    def test_ascend_mla_prefill_metadata_default(self):
        attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
        query_lens = [1, 2]
        seq_lens = [2, 2]
        context_lens = torch.tensor([1, 2])
        input_positions = torch.tensor([0, 1, 0, 1])
        query_start_loc = torch.tensor([0, 1, 3])
        block_table = torch.tensor([[0, 1], [2, 3]])
        max_query_len = 2
        max_seq_lens = 2

        metadata = AscendMLAPrefillMetadata(attn_mask=attn_mask,
                                            query_lens=query_lens,
                                            seq_lens=seq_lens,
                                            context_lens=context_lens,
                                            input_positions=input_positions,
                                            query_start_loc=query_start_loc,
                                            block_table=block_table,
                                            max_query_len=max_query_len,
                                            max_seq_lens=max_seq_lens)
        self.assertIs(metadata.attn_mask, attn_mask)
        self.assertEqual(metadata.query_lens, query_lens)
        self.assertEqual(metadata.seq_lens, seq_lens)
        self.assertIs(metadata.context_lens, context_lens)
        self.assertIs(metadata.input_positions, input_positions)
        self.assertIs(metadata.query_start_loc, query_start_loc)
        self.assertIs(metadata.block_table, block_table)
        self.assertEqual(metadata.max_query_len, max_query_len)
        self.assertEqual(metadata.max_seq_lens, max_seq_lens)
        self.assertIsNone(metadata.chunked_context)

    def test_ascend_mla_prefill_metadata_with_chunked_context(self):
        cu_seq_lens = torch.tensor([0, 2, 4])
        starts = torch.tensor([0, 2])
        seq_tot = [2, 2]
        max_seq_lens = [2, 2]
        workspace = torch.randn(2, 4)
        chunk_seq_lens = torch.tensor([2, 2])

        chunked_context = ChunkedContextMetadata(
            cu_seq_lens=cu_seq_lens,
            starts=starts,
            seq_tot=seq_tot,
            max_seq_lens=max_seq_lens,
            workspace=workspace,
            chunk_seq_lens=chunk_seq_lens,
            chunk_seq_lens_npu=chunk_seq_lens)

        metadata = AscendMLAPrefillMetadata(
            attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
            query_lens=[1, 2],
            seq_lens=[2, 2],
            context_lens=torch.tensor([1, 2]),
            input_positions=torch.tensor([0, 1, 0, 1]),
            query_start_loc=torch.tensor([0, 1, 3]),
            block_table=torch.tensor([[0, 1], [2, 3]]),
            max_query_len=2,
            max_seq_lens=2,
            chunked_context=chunked_context)

        self.assertIsNotNone(metadata.chunked_context)
        self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
        self.assertIs(metadata.chunked_context.starts, starts)
        self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
        self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
        self.assertIs(metadata.chunked_context.workspace, workspace)
        self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
                      chunk_seq_lens)


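# Decode-side metadata fields, including the context-parallel sequence
# lengths (cp_seq_len) and the per-batch sequence mask (batch_seq_mask).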
class TestAscendMLADecodeMetadata(TestBase):

    def test_ascend_mla_decode_metadata_default(self):
        input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
        block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
        seq_lens = torch.tensor([[2], [3]])
        max_seq_lens = 4
        seq_lens_list = [2, 3]
        attn_mask = None
        cp_seq_len = torch.tensor([2, 3])
        batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])

        metadata = AscendMLADecodeMetadata(input_positions=input_positions,
                                           block_table=block_table,
                                           seq_lens=seq_lens,
                                           max_seq_lens=max_seq_lens,
                                           seq_lens_list=seq_lens_list,
                                           attn_mask=attn_mask,
                                           cp_seq_len=cp_seq_len,
                                           batch_seq_mask=batch_seq_mask)

        self.assertIs(metadata.input_positions, input_positions)
        self.assertIs(metadata.block_table, block_table)
        self.assertIs(metadata.seq_lens, seq_lens)
        self.assertEqual(metadata.max_seq_lens, max_seq_lens)
        self.assertEqual(metadata.seq_lens_list, seq_lens_list)
        # Check the stored field rather than the local variable, which would
        # be trivially None.
        self.assertIsNone(metadata.attn_mask)
        self.assertIs(metadata.cp_seq_len, cp_seq_len)
        self.assertIs(metadata.batch_seq_mask, batch_seq_mask)


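# AscendMLAMetadata is constructed positionally here and every field is
# asserted, which effectively pins the current field order of the class.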
class TestAscendMLAMetadata(TestBase):

    def test_ascend_mla_metadata_default(self):
        num_actual_tokens_pcp_padded = 100
        num_actual_tokens = 100
        slot_mapping = torch.randn(100, 4, 1024)
        query_start_loc = torch.tensor([1, 2, 3, 4])
        seq_lens = [30, 50]
        block_tables = torch.randint(0, 100, (100, 4))

        num_decodes = 4
        num_decode_tokens = 8
        num_prefills = 8

        num_input_tokens = 2

        query_lens = None
        head_dim = None
        attn_mask = None
        attn_state = AscendAttentionState.ChunkedPrefill

        decode = None
        prefill = None

        metadata = AscendMLAMetadata(
            num_actual_tokens_pcp_padded, num_actual_tokens, slot_mapping,
            query_start_loc, seq_lens, block_tables, num_decodes,
            num_decode_tokens, num_prefills, num_input_tokens, query_lens,
            head_dim, attn_mask, attn_state, decode, prefill)

        self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
        self.assertIs(metadata.slot_mapping, slot_mapping)
        self.assertIs(metadata.query_start_loc, query_start_loc)
        self.assertEqual(metadata.seq_lens, seq_lens)
        self.assertIs(metadata.block_tables, block_tables)
        self.assertEqual(metadata.num_decodes, num_decodes)
        self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
        self.assertEqual(metadata.num_prefills, num_prefills)
        self.assertEqual(metadata.num_input_tokens, num_input_tokens)
        self.assertEqual(metadata.query_lens, query_lens)
        self.assertEqual(metadata.head_dim, head_dim)
        self.assertEqual(metadata.attn_mask, attn_mask)
        self.assertEqual(metadata.attn_state, attn_state)
        self.assertEqual(metadata.decode, decode)
        self.assertEqual(metadata.prefill, prefill)


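# The builder tests bypass MLACommonMetadataBuilder.__init__ (see
# mock_parent_init in setUp) and only stub the attributes that the Ascend
# subclass actually reads.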
class TestAscendMLAMetadataBuilder(TestBase):

    def setUp(self):
        # Mock the parent class __init__ to avoid its complex initialization,
        # but still set the essential attributes the child class needs.
        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
                             device, metadata_cls, supports_dcp_with_varlen):
            self.metadata_cls = metadata_cls
            self.kv_cache_spec = kv_cache_spec
            self.model_config = vllm_config.model_config
            self.vllm_config = vllm_config
            self.device = device
            self.chunked_prefill_workspace_size = 128 * 1024
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 vllm_config.model_config.get_head_size()),
                dtype=vllm_config.model_config.dtype,
                device=device,
            )

        self.parent_init_patcher = patch(
            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
            mock_parent_init)
        self.parent_init_patcher.start()

    def tearDown(self):
        self.parent_init_patcher.stop()

    def test_ascend_mla_metadata_builder_default(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'

        mock_vllm_config.speculative_config = None

        ascend_config = MagicMock()
        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=ascend_config):
            builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                               mock_device)

        self.assertEqual(builder.block_size,
                         mock_vllm_config.cache_config.block_size)
        self.assertEqual(
            builder.chunked_prefill_enabled,
            mock_vllm_config.scheduler_config.enable_chunked_prefill)

    def test_ascend_mla_metadata_builder_spec_decode(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'

        mock_spec_config = MagicMock()
        mock_spec_config.num_speculative_tokens = 3
        mock_vllm_config.speculative_config = mock_spec_config

        ascend_config = MagicMock()
        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=ascend_config):
            builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                               mock_device)

        self.assertEqual(builder.block_size,
                         mock_vllm_config.cache_config.block_size)
        self.assertEqual(
            builder.chunked_prefill_enabled,
            mock_vllm_config.scheduler_config.enable_chunked_prefill)

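    # In full-graph (ACL graph) mode the builder pads the decode metadata up
    # to graph_pad_size: here 4 real requests are padded to 8 slots, which
    # the assertions at the end of the test check.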
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
2026-01-07 17:09:52 +08:00
|
|
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
|
|
|
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
|
|
|
def test_ascend_mla_metadata_builder_build_full_graph(
|
2026-01-07 17:09:52 +08:00
|
|
|
self, mock_get_pcp_group, mock_get_pcp_group_mask,
|
|
|
|
|
mock_get_cos_and_sin_mla):
|
|
|
|
|
pcp_group = MagicMock()
|
|
|
|
|
pcp_group.world_size = 1
|
|
|
|
|
mock_get_pcp_group.return_value = pcp_group
|
|
|
|
|
mock_get_pcp_group_mask.return_value = pcp_group
|
[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
|
|
|
mock_vllm_config = MagicMock()
|
|
|
|
|
mock_vllm_config.model_config.max_model_len = 1024
|
|
|
|
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
|
|
|
|
mock_vllm_config.model_config.dtype = torch.float16
|
2026-01-21 10:45:45 +08:00
|
|
|
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
|
|
|
mock_vllm_config.cache_config.block_size = 16
|
|
|
|
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
|
|
|
|
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
|
|
|
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
2026-01-21 10:45:45 +08:00
|
|
|
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
|
[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
|
|
|
mock_device = 'cpu'
|
2025-12-06 17:15:57 +08:00
|
|
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
|
|
|
|
|
|
|
|
mock_spec_config = MagicMock()
|
|
|
|
|
mock_spec_config.num_speculative_tokens = 1
|
|
|
|
|
mock_spec_config.disable_padded_drafter_batch = True
|
|
|
|
|
mock_vllm_config.speculative_config = mock_spec_config
|
|
|
|
|
|
|
|
|
|
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
|
|
|
|
mock_device)
|
|
|
|
|
common_metadata = MagicMock()
|
|
|
|
|
common_metadata.graph_pad_size = 8
|
|
|
|
|
common_metadata.num_reqs = 4
|
|
|
|
|
common_metadata.num_actual_tokens = 5
|
|
|
|
|
common_metadata.max_query_len = 5
|
|
|
|
|
common_metadata.seq_lens_cpu = torch.Tensor([9, 10, 8, 8]).int()
|
|
|
|
|
common_metadata.query_start_loc = torch.Tensor([0, 1, 2, 4, 5]).int()
|
|
|
|
|
common_metadata.query_start_loc_cpu = torch.Tensor([0, 1, 2, 4,
|
|
|
|
|
5]).int()
|
|
|
|
|
common_metadata.positions = torch.Tensor([1, 2, 3, 4, 5, 6]).int()
|
|
|
|
|
block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
|
|
|
|
|
common_metadata.block_table_tensor = block_table
|
|
|
|
|
common_metadata.prefill_context_parallel_metadata = None
|
2025-12-28 10:35:07 +08:00
|
|
|
mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]),
|
|
|
|
|
torch.Tensor([6, 6]))
|
|
|
|
|
metadata = builder.build(0, common_metadata)
|
[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
|
|
|
|
|
|
|
|
self.assertEqual(metadata.decode.actual_seq_lengths_q,
|
|
|
|
|
[1, 2, 4, 5, 6, 6, 7, 8])
|
|
|
|
|
self.assertEqual(metadata.decode.block_table.shape[0], 8)
|
|
|
|
|
|
2026-01-05 09:05:45 +08:00
|
|
|
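    # With decode_threshold=1, requests 0 and 2 are decodes while 1 and 3 are
    # prefills, so reorder_batch should swap slots 1 and 2 to group all
    # decode requests at the front of the batch.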
    def test_reorder_batch(self):
        ascend_config = MagicMock()

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'

        mock_vllm_config.speculative_config = None

        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=ascend_config):
            builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                               mock_device)
            builder.decode_threshold = 1

            input_batch = MagicMock()
            input_batch.req_ids = [0, 1, 2, 3]

            scheduler_output = MagicMock()
            scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
            scheduler_output.scheduled_spec_decode_tokens = {
                0: [],
                1: [1],
                2: [],
                3: []
            }

            input_batch.swap_states = MagicMock()

            modified = builder.reorder_batch(input_batch, scheduler_output)

            self.assertTrue(modified)
            input_batch.swap_states.assert_called_once_with(1, 2)

    def test_pad_actual_seq_lens_q_mtp_disable_pad(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'
        mock_vllm_config.speculative_config = None

        builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                           mock_device)
        input_seq_lens = [1, 2, 4, 5]
        expect_output = [1, 2, 4, 5, 6, 6, 7, 8]
        num_reqs = 4
        num_reqs_pad_size = 4
        output_seq_lens = builder.pad_actual_seq_len_q_mtp_disable_pad(
            num_reqs_pad_size, num_reqs, input_seq_lens)
        self.assertEqual(output_seq_lens, expect_output)

    def test_pad_actual_seq_lens_q_mtp_enable_pad(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'
        mock_vllm_config.speculative_config = None

        common_metadata = MagicMock()
        common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]

        builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                           mock_device)
        input_seq_lens = [2, 4, 6]
        expect_output = [2, 4, 6, 8]
        num_reqs = 3
        num_reqs_pad_size = 1
        output_seq_lens = builder.pad_actual_seq_len_q_mtp_enable_pad(
            num_reqs_pad_size, num_reqs, input_seq_lens, common_metadata)
        self.assertEqual(output_seq_lens, expect_output)


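# End-to-end build() tests: a real ModelConfig is loaded from the local
# fake_weight directory, and NPU-only tensor paths (Tensor.npu(),
# pin_memory, torch.npu.is_available) are patched so the builder runs on CPU.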
class TestAscendMLAMetadataBuilderBuild(TestBase):

    def setUp(self):
        # Mock the parent class __init__ to avoid its complex initialization,
        # but still set the essential attributes the child class needs.
        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
                             device, metadata_cls, supports_dcp_with_varlen):
            self.metadata_cls = metadata_cls
            self.kv_cache_spec = kv_cache_spec
            self.model_config = vllm_config.model_config
            self.vllm_config = vllm_config
            self.device = device
            self.chunked_prefill_workspace_size = 128 * 1024
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 vllm_config.model_config.get_head_size()),
                dtype=vllm_config.model_config.dtype,
                device=device,
            )

        self.parent_init_patcher = patch(
            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
            mock_parent_init)
        self.parent_init_patcher.start()

        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
        mock_scheduler_config = MagicMock(spec=SchedulerConfig)
        mock_scheduler_config.max_num_seqs = 8
        mock_scheduler_config.chunked_prefill_enabled = True
        mock_scheduler_config.enable_chunked_prefill = True
        self.mock_vllm_config.scheduler_config = mock_scheduler_config
        self.mock_vllm_config.speculative_config = None
        self.mock_device = torch.device("cpu")
        fake_weight_path = os.path.join(os.path.dirname(__file__), "..",
                                        "fake_weight")
        model_config = ModelConfig(
            model=fake_weight_path,
            skip_tokenizer_init=True,
        )
        model_config.hf_text_config.head_dim = 128
        model_config.hf_text_config.qk_rope_head_dim = 32
        self.mock_vllm_config.model_config = model_config
        self.kv_cache_spec = MagicMock()
        self.kv_cache_spec.num_layers = 32
        self.kv_cache_spec.head_size = 64
        self.kv_cache_spec.num_heads = 32

    def tearDown(self):
        self.parent_init_patcher.stop()

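    # The torch.zeros wrapper below strips the pin_memory kwarg that mla_v1
    # passes, so the same allocation path also works on a CPU-only test host.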
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
2026-01-07 17:09:52 +08:00
|
|
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
|
|
|
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
2025-12-02 22:10:52 +08:00
|
|
|
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
|
|
|
|
@patch("torch.Tensor.npu", new=lambda self: self)
|
|
|
|
|
@patch("torch.npu.is_available")
|
|
|
|
|
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
2026-01-07 17:09:52 +08:00
|
|
|
mock_zeros, mock_get_pcp_group,
|
|
|
|
|
mock_get_pcp_group_mask,
|
2025-12-28 10:35:07 +08:00
|
|
|
mock_get_cos_and_sin_mla):
|
2025-12-02 22:10:52 +08:00
|
|
|
mock_npu_available.return_value = False
|
2025-12-06 17:15:57 +08:00
|
|
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
2026-01-07 17:09:52 +08:00
|
|
|
pcp_group = MagicMock()
|
|
|
|
|
pcp_group.world_size = 1
|
|
|
|
|
mock_get_pcp_group.return_value = pcp_group
|
|
|
|
|
mock_get_pcp_group_mask.return_value = pcp_group
|
2025-11-20 20:29:09 +08:00
|
|
|
|
2025-12-02 22:10:52 +08:00
|
|
|
def zeros_override(*args, **kwargs):
|
|
|
|
|
kwargs.pop('pin_memory', None)
|
|
|
|
|
return mock_zeros._mock_wraps(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
mock_zeros.side_effect = zeros_override
|
2025-11-20 20:29:09 +08:00
|
|
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
|
|
|
query_start_loc=torch.tensor([0, 3, 7]),
|
|
|
|
|
query_start_loc_cpu=torch.tensor([0, 3, 7]),
|
|
|
|
|
seq_lens_cpu=torch.tensor([5, 6]),
|
|
|
|
|
num_reqs=2,
|
|
|
|
|
num_actual_tokens=10,
|
|
|
|
|
max_query_len=5,
|
|
|
|
|
decode_token_per_req=torch.tensor([1, 1]),
|
|
|
|
|
block_table_tensor=torch.zeros((10, 10)),
|
|
|
|
|
slot_mapping=torch.tensor(range(20)),
|
|
|
|
|
actual_seq_lengths_q=torch.tensor([0, 1]),
|
|
|
|
|
positions=torch.tensor([10, 10]),
|
|
|
|
|
attn_state=AscendAttentionState.PrefillNoCache,
|
|
|
|
|
num_computed_tokens_cpu=None,
|
2025-12-23 00:10:52 +08:00
|
|
|
seq_lens=None,
|
|
|
|
|
max_seq_len=6)
|
2025-11-20 20:29:09 +08:00
|
|
|
|
|
|
|
|
base_inputs = {
|
|
|
|
|
"num_actual_tokens": 10,
|
|
|
|
|
"slot_mapping": torch.tensor(range(10)),
|
|
|
|
|
"query_start_loc": torch.tensor([0, 3, 7]),
|
|
|
|
|
"seq_lens": torch.tensor([5, 6]),
|
|
|
|
|
"block_tables": torch.zeros((10, 10)),
|
|
|
|
|
"num_prefills": 2,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
|
|
|
|
layer_names=["layer_0", "layer_1"],
|
|
|
|
|
vllm_config=self.mock_vllm_config,
|
|
|
|
|
device=self.mock_device)
|
2025-12-28 10:35:07 +08:00
|
|
|
mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
|
|
|
|
|
torch.Tensor(10))
|
|
|
|
|
metadata = builder.build(1, common_attn_metadata)
|
2025-11-20 20:29:09 +08:00
|
|
|
|
|
|
|
|
self.assertIsInstance(metadata, AscendMLAMetadata)
|
|
|
|
|
self.assertEqual(metadata.num_actual_tokens,
|
|
|
|
|
base_inputs["num_actual_tokens"])
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
|
|
|
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
|
|
|
|
|
2025-12-28 10:35:07 +08:00
|
|
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
2026-01-07 17:09:52 +08:00
|
|
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
|
|
|
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
2025-12-02 22:10:52 +08:00
|
|
|
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
|
|
|
|
@patch("torch.Tensor.npu", new=lambda self: self)
|
|
|
|
|
@patch("torch.npu.is_available")
|
|
|
|
|
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
2026-01-07 17:09:52 +08:00
|
|
|
mock_zeros, mock_get_pcp_group,
|
|
|
|
|
mock_get_pcp_group_mask,
|
2025-12-28 10:35:07 +08:00
|
|
|
mock_get_cos_and_sin_mla):
|
2025-12-02 22:10:52 +08:00
|
|
|
mock_npu_available.return_value = False
|
2025-12-06 17:15:57 +08:00
|
|
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
2026-01-07 17:09:52 +08:00
|
|
|
pcp_group = MagicMock()
|
|
|
|
|
pcp_group.world_size = 1
|
|
|
|
|
mock_get_pcp_group.return_value = pcp_group
|
|
|
|
|
mock_get_pcp_group_mask.return_value = pcp_group
|
2025-11-20 20:29:09 +08:00
|
|
|
|
2025-12-02 22:10:52 +08:00
|
|
|
def zeros_override(*args, **kwargs):
|
|
|
|
|
kwargs.pop('pin_memory', None)
|
|
|
|
|
return mock_zeros._mock_wraps(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
mock_zeros.side_effect = zeros_override
|
|
|
|
|
|
2025-11-20 20:29:09 +08:00
|
|
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
|
|
|
query_start_loc=torch.tensor([0, 2, 5, 9]),
|
|
|
|
|
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
|
|
|
|
|
seq_lens_cpu=torch.tensor([4, 5, 6]),
|
|
|
|
|
num_reqs=3,
|
|
|
|
|
num_actual_tokens=15,
|
|
|
|
|
max_query_len=6,
|
|
|
|
|
decode_token_per_req=torch.tensor([1, 1, 1]),
|
|
|
|
|
block_table_tensor=torch.zeros((10, 10)),
|
|
|
|
|
slot_mapping=torch.tensor(range(20)),
|
|
|
|
|
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
|
|
|
|
positions=torch.tensor([10, 10]),
|
|
|
|
|
attn_state=AscendAttentionState.ChunkedPrefill,
|
|
|
|
|
num_computed_tokens_cpu=None,
|
2025-12-23 00:10:52 +08:00
|
|
|
seq_lens=None,
|
|
|
|
|
max_seq_len=6)
|
2025-11-20 20:29:09 +08:00
|
|
|
|
|
|
|
|
base_inputs = {
|
|
|
|
|
"num_actual_tokens": 15,
|
|
|
|
|
"slot_mapping": torch.tensor(range(15)),
|
|
|
|
|
"query_start_loc": torch.tensor([0, 2, 5, 9]),
|
|
|
|
|
"seq_lens": torch.tensor([4, 5, 6]),
|
|
|
|
|
"block_tables": torch.zeros((10, 10)),
|
|
|
|
|
"num_prefills": 3,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
|
|
|
|
layer_names=["layer_0", "layer_1"],
|
|
|
|
|
vllm_config=self.mock_vllm_config,
|
|
|
|
|
device=self.mock_device)
|
2025-12-28 10:35:07 +08:00
|
|
|
mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
|
|
|
|
|
torch.Tensor(10))
|
|
|
|
|
metadata = builder.build(1, common_attn_metadata)
|
2025-11-20 20:29:09 +08:00
|
|
|
|
|
|
|
|
self.assertIsInstance(metadata, AscendMLAMetadata)
|
|
|
|
|
self.assertEqual(metadata.num_actual_tokens,
|
|
|
|
|
base_inputs["num_actual_tokens"])
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
|
|
|
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
|
|
|
|
|
2025-12-28 10:35:07 +08:00
|
|
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
2026-01-07 17:09:52 +08:00
|
|
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
|
|
|
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
|
|
|
|
def test_build_decode_only_metadata(self, mock_get_pcp_group,
|
|
|
|
|
mock_get_pcp_group_mask,
|
|
|
|
|
mock_get_cos_and_sin_mla):
|
2025-12-06 17:15:57 +08:00
|
|
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
2026-01-07 17:09:52 +08:00
|
|
|
pcp_group = MagicMock()
|
|
|
|
|
pcp_group.world_size = 1
|
|
|
|
|
mock_get_pcp_group.return_value = pcp_group
|
|
|
|
|
mock_get_pcp_group_mask.return_value = pcp_group
|
2025-12-06 17:15:57 +08:00
|
|
|
|
2025-11-20 20:29:09 +08:00
|
|
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
|
|
|
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
|
|
|
|
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
|
|
|
|
|
seq_lens_cpu=torch.tensor([4, 5, 6]),
|
|
|
|
|
num_reqs=3,
|
|
|
|
|
num_actual_tokens=3,
|
|
|
|
|
max_query_len=1,
|
|
|
|
|
block_table_tensor=torch.zeros((10, 10)),
|
|
|
|
|
slot_mapping=torch.tensor(range(3)),
|
|
|
|
|
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
|
|
|
|
decode_token_per_req=torch.tensor([1, 1, 1]),
|
|
|
|
|
positions=torch.tensor([10, 10]),
|
|
|
|
|
attn_state=AscendAttentionState.DecodeOnly,
|
|
|
|
|
num_computed_tokens_cpu=None,
|
2025-12-23 00:10:52 +08:00
|
|
|
seq_lens=None,
|
|
|
|
|
max_seq_len=6)
|
2025-11-20 20:29:09 +08:00
|
|
|
|
|
|
|
|
base_inputs = {
|
|
|
|
|
"num_actual_tokens": 3,
|
|
|
|
|
"slot_mapping": torch.tensor(range(3)),
|
|
|
|
|
"query_start_loc": torch.tensor([0, 1, 2, 3]),
|
|
|
|
|
"seq_lens": torch.tensor([4, 5, 6]),
|
|
|
|
|
"num_decodes": 3,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
|
|
|
|
layer_names=["layer_0", "layer_1"],
|
|
|
|
|
vllm_config=self.mock_vllm_config,
|
|
|
|
|
device=self.mock_device)
|
2025-12-28 10:35:07 +08:00
|
|
|
mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
|
|
|
|
|
torch.Tensor([10, 10]))
|
|
|
|
|
metadata = builder.build(1, common_attn_metadata)
|
2025-11-20 20:29:09 +08:00
|
|
|
|
|
|
|
|
self.assertIsInstance(metadata, AscendMLAMetadata)
|
|
|
|
|
self.assertEqual(metadata.num_actual_tokens,
|
|
|
|
|
base_inputs["num_actual_tokens"])
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
|
|
|
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
|
|
|
|
|
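    # Same decode-only inputs, but routed through build_for_graph_capture(),
    # which builds dummy metadata for graph capture; the basic fields should
    # still be populated.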
    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm_ascend.attention.attention_mask.get_pcp_group')
    @patch('vllm.distributed.parallel_state.get_pcp_group')
    def test_build_for_graph_capture_decode_only(self, mock_get_pcp_group,
                                                 mock_get_pcp_group_mask,
                                                 mock_get_cos_and_sin_mla):
        torch.Tensor.pin_memory = lambda x: x  # noqa
        pcp_group = MagicMock()
        pcp_group.world_size = 1
        mock_get_pcp_group.return_value = pcp_group
        mock_get_pcp_group_mask.return_value = pcp_group

        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 1, 2, 3]),
            query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
            seq_lens_cpu=torch.tensor([4, 5, 6]),
            num_reqs=3,
            num_actual_tokens=3,
            max_query_len=1,
            block_table_tensor=torch.zeros((10, 10)),
            slot_mapping=torch.tensor(range(3)),
            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
            decode_token_per_req=torch.tensor([1, 1, 1]),
            positions=torch.tensor([10, 10]),
            attn_state=AscendAttentionState.DecodeOnly,
            num_computed_tokens_cpu=None,
            seq_lens=None,
            max_seq_len=6)

        base_inputs = {
            "num_actual_tokens": 3,
            "slot_mapping": torch.tensor(range(3)),
            "query_start_loc": torch.tensor([0, 1, 2, 3]),
            "seq_lens": torch.tensor([4, 5, 6]),
            "num_decodes": 3,
        }

        builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
                                           layer_names=["layer_0", "layer_1"],
                                           vllm_config=self.mock_vllm_config,
                                           device=self.mock_device)
        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
                                                 torch.Tensor([10, 10]))
        metadata = builder.build_for_graph_capture(
            common_attn_metadata, AscendAttentionState.DecodeOnly)

        self.assertIsInstance(metadata, AscendMLAMetadata)
        self.assertEqual(metadata.num_actual_tokens,
                         base_inputs["num_actual_tokens"])
        self.assertTrue(
            torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
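    # build_for_graph_capture() only supports the DecodeOnly and SpecDecoding
    # states, so a PrefillNoCache request must raise NotImplementedError.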
    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    def test_build_for_graph_capture_prefill(self, mock_get_cos_and_sin_mla):
        torch.Tensor.pin_memory = lambda x: x  # noqa
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 3, 7]),
            query_start_loc_cpu=torch.tensor([0, 3, 7]),
            seq_lens_cpu=torch.tensor([5, 6]),
            num_reqs=2,
            num_actual_tokens=10,
            max_query_len=5,
            decode_token_per_req=torch.tensor([1, 1]),
            block_table_tensor=torch.zeros((10, 10)),
            slot_mapping=torch.tensor(range(20)),
            actual_seq_lengths_q=torch.tensor([0, 1]),
            positions=torch.tensor([10, 10]),
            attn_state=AscendAttentionState.PrefillNoCache,
            num_computed_tokens_cpu=None,
            seq_lens=None,
            max_seq_len=6)

        builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
                                           layer_names=["layer_0", "layer_1"],
                                           vllm_config=self.mock_vllm_config,
                                           device=self.mock_device)
        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
                                                 torch.Tensor(10))
        with self.assertRaises(NotImplementedError) as ctx:
            builder.build_for_graph_capture(
                common_attn_metadata, AscendAttentionState.PrefillNoCache)
        self.assertIn(
            "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
            str(ctx.exception))


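# Tests for AscendMLAImpl. setUp patches the tensor-parallel group and the
# current vllm config so the implementation can be built with MagicMock
# projection layers instead of real NPU weights.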
class TestAscendMLAImpl(TestBase):

    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
    def setUp(self, get_current_vllm_config, mock_tp):
        mock_tp.world_size = 2
        mock_tp.rank_in_group = MagicMock()
        mock_tp.device_group = MagicMock()
        vllm_config = MagicMock()
        speculative_config = MagicMock()
        model_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        model_config.dtype = torch.float16
        vllm_config.model_config = model_config
        get_current_vllm_config.return_value = vllm_config
        vllm_config.additional_config = {"refresh": True}
        init_ascend_config(vllm_config)

        num_heads = 256
        head_size = 1024
        scale = 0.1
        num_kv_heads = 8
        kv_cache_dtype = "auto"

        kv_a_layernorm = MagicMock()
        kv_a_layernorm.weight = torch.randn(96)
        kv_a_layernorm.variance_epsilon = 1e-6
        kwargs = {
            "kv_lora_rank": 32,
            "qk_nope_head_dim": 64,
            "qk_rope_head_dim": 32,
            "qk_head_dim": 96,
            "v_head_dim": 128,
            "q_lora_rank": 64,
            "q_proj": MagicMock(),
            "q_b_proj": MagicMock(),
            "kv_b_proj": MagicMock(),
            "o_proj": MagicMock(),
            "kv_a_proj_with_mqa": MagicMock(),
            "fused_qkv_a_proj": MagicMock(),
            "kv_a_layernorm": kv_a_layernorm,
            "rotary_emb": MagicMock(),
        }

        self.impl = AscendMLAImpl(num_heads=num_heads,
                                  head_size=head_size,
                                  scale=scale,
                                  num_kv_heads=num_kv_heads,
                                  alibi_slopes=None,
                                  sliding_window=None,
                                  kv_cache_dtype=kv_cache_dtype,
                                  blocksparse_params=None,
                                  logits_soft_cap=None,
                                  attn_type=None,
                                  kv_sharing_target_layer_name=None,
                                  **kwargs)
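    # Constructor arguments from setUp should be stored verbatim on the
    # implementation object.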
    def test_init(self):
        self.assertEqual(self.impl.num_heads, 256)
        self.assertEqual(self.impl.head_size, 1024)
        self.assertEqual(self.impl.scale, 0.1)
        self.assertEqual(self.impl.num_kv_heads, 8)
        self.assertEqual(self.impl.kv_cache_dtype, "auto")
        self.assertEqual(self.impl.kv_lora_rank, 32)
        self.assertEqual(self.impl.qk_nope_head_dim, 64)
        self.assertEqual(self.impl.qk_rope_head_dim, 32)
        self.assertEqual(self.impl.qk_head_dim, 96)
        self.assertEqual(self.impl.v_head_dim, 128)
        self.assertIsNotNone(self.impl.q_proj)
        self.assertIsNotNone(self.impl.kv_b_proj)
        self.assertIsNotNone(self.impl.o_proj)
        self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
        self.assertIsNotNone(self.impl.kv_a_layernorm)
        self.assertEqual(self.impl.num_queries_per_kv, 32)
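    # _q_proj_and_k_up_proj consumes the mocked q_proj output and a random
    # W_UK_T (filled in here if absent) and should return a latent query of
    # width kv_lora_rank plus a rotary part of width qk_rope_head_dim.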
    def test_q_proj_and_k_up_proj(self):
        batch_size = 4
        x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
        q_proj_output = torch.randn(batch_size, self.impl.num_heads,
                                    self.impl.qk_head_dim)
        self.impl.q_proj.return_value = (q_proj_output, )
        if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
            self.impl.W_UK_T = torch.randn(self.impl.num_heads,
                                           self.impl.qk_nope_head_dim,
                                           self.impl.kv_lora_rank)
        result = self.impl._q_proj_and_k_up_proj(x)
        ql_nope, q_pe = result
        self.assertEqual(ql_nope.shape[0], batch_size)
        self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
        self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
        self.assertEqual(q_pe.shape[0], batch_size)
        self.assertEqual(q_pe.shape[1], self.impl.num_heads)
        self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
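    # process_weights_after_loading derives W_UK_T and W_UV from the
    # kv_b_proj weight; torch_npu.npu_format_cast is mocked so no NPU is
    # required, and only the resulting weight shapes are asserted.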
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading(self, mock_format_cast):
        layer = MagicMock(spec=LinearBase)
        layer.input_size_per_partition = 10
        quant_method = MagicMock(spec=UnquantizedLinearMethod)
        layer.quant_method = quant_method
        shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
                                         self.impl.v_head_dim)
        shape_1 = self.impl.kv_lora_rank
        layer.weight = torch.randn(shape_0, shape_1)
        self.impl.kv_b_proj = layer
        mock_format_cast.return_value = layer.weight
        self.impl.process_weights_after_loading(torch.bfloat16)

        self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
        self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
        self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)

        self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
        self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
        self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)
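    # With metadata.prefill set to None there is no chunked context to merge,
    # so _compute_prefill_context should return the prefix output and LSE
    # unchanged.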
    def test_compute_prefill_context_none(self):
        batch_size = 4
        kv_cache = torch.randn(10, 1, 1, 192)
        query = torch.randn(batch_size, self.impl.num_heads,
                            self.impl.qk_head_dim)
        metadata = MagicMock()
        metadata.prefill = None
        prefix_out = torch.randn(2, 16, 128)
        prefix_lse = torch.randn(2, 16, 8)
        q_pe = query[..., self.impl.qk_nope_head_dim:]
        q_nope = query[..., :self.impl.qk_nope_head_dim]

        out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
                                                      32, metadata, prefix_out,
                                                      prefix_lse)

        self.assertTrue(torch.equal(prefix_out, out))
        self.assertTrue(torch.equal(prefix_lse, lse))
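    # Chunked-prefill path: npu_paged_cache_load and npu_ring_mla are mocked;
    # each kernel must be invoked exactly once and the merged output/LSE must
    # keep the prefix shapes.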
@patch("torch_npu.atb.npu_paged_cache_load")
|
|
|
|
|
@patch("torch_npu.atb.npu_ring_mla")
|
|
|
|
|
def test_compute_prefill_context(self, mock_ring, mock_load):
|
|
|
|
|
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
|
|
|
|
|
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
|
|
|
|
|
latent_kv_dim = self.impl.kv_lora_rank
|
|
|
|
|
num_blocks, block_size = 100, 20
|
|
|
|
|
query = torch.randn(S, N, D)
|
2025-08-28 10:35:57 +08:00
|
|
|
q_nope = query[..., :self.impl.qk_nope_head_dim]
|
|
|
|
|
q_pe = query[..., self.impl.qk_nope_head_dim:]
|
2025-07-28 15:54:40 +08:00
|
|
|
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
|
|
|
|
|
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
|
|
|
|
|
kv_cache = [kv_cache_0, kv_cache_1]
|
|
|
|
|
prefix_out = torch.randn(S, N, 128)
|
|
|
|
|
prefix_lse = torch.randn(S, N)
|
|
|
|
|
|
|
|
|
|
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
|
|
|
|
|
|
|
|
|
|
chunk_ctx = MagicMock()
|
|
|
|
|
chunk_ctx.seq_tot = [8]
|
|
|
|
|
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
2025-11-08 18:45:31 +08:00
|
|
|
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
|
2025-07-28 15:54:40 +08:00
|
|
|
chunk_ctx.starts = [torch.tensor([0])]
|
|
|
|
|
|
|
|
|
|
prefill_meta = MagicMock()
|
|
|
|
|
prefill_meta.chunked_context = chunk_ctx
|
|
|
|
|
prefill_meta.query_lens = [8]
|
|
|
|
|
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
|
|
|
|
|
|
|
|
|
|
meta = MagicMock()
|
|
|
|
|
meta.prefill = prefill_meta
|
2025-08-28 10:35:57 +08:00
|
|
|
self.impl.prefill_mask = torch.triu(
|
|
|
|
|
torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1)
|
2025-07-28 15:54:40 +08:00
|
|
|
|
2025-08-28 10:35:57 +08:00
|
|
|
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
|
|
|
|
|
32, meta, prefix_out,
|
2025-07-28 15:54:40 +08:00
|
|
|
prefix_lse)
|
|
|
|
|
|
|
|
|
|
mock_load.assert_called_once()
|
|
|
|
|
mock_ring.assert_called_once()
|
|
|
|
|
|
|
|
|
|
self.assertEqual(out.shape, prefix_out.shape)
|
|
|
|
|
self.assertEqual(lse.shape, prefix_lse.shape)
|
|
|
|
|
|
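    # Decode path outside graph capture: the fused attention kernel and
    # _v_up_proj are mocked; the result must have shape
    # (num_tokens, num_heads, v_head_dim).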
    @patch('vllm_ascend.attention.mla_v1.get_forward_context')
    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
    @patch("torch_npu.npu_fused_infer_attention_score")
    def test_forward_decode_without_graph(self,
                                          mock_npu_fused_infer_attention_score,
                                          mock_up_proj,
                                          mock_get_forward_context):
        num_tokens = 100
        block_size = 4
        q_nope = torch.randn(num_tokens, self.impl.num_heads,
                             self.impl.qk_nope_head_dim)
        q_pe = torch.randn(num_tokens, self.impl.num_heads,
                           self.impl.qk_rope_head_dim)
        k_nope = torch.randn(num_tokens, self.impl.num_heads,
                             self.impl.qk_nope_head_dim)
        k_pe = torch.randn(num_tokens, self.impl.num_heads,
                           self.impl.qk_rope_head_dim)
        metadata = MagicMock()
        metadata.decode = MagicMock()
        metadata.decode.block_table = MagicMock()
        metadata.decode.seq_lens = 10
        mock_npu_fused_infer_attention_score.return_value = [
            torch.randn(num_tokens, self.impl.num_heads,
                        self.impl.kv_lora_rank), None
        ]
        mock_up_proj.return_value = torch.randn(num_tokens,
                                                self.impl.num_heads,
                                                self.impl.v_head_dim)
        mock_get_forward_context.return_value = MagicMock(capturing=False)
        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
                                           block_size, metadata)
        self.assertEqual(result.shape[0], num_tokens)
        self.assertEqual(result.shape[1], self.impl.num_heads)
        self.assertEqual(result.shape[2], self.impl.v_head_dim)
        mock_up_proj.assert_called_once()
        mock_npu_fused_infer_attention_score.assert_called_once()
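    # _mla_preprocess splits a mixed batch (2 decode + 2 prefill tokens)
    # across heavily mocked projection layers; the test only checks that both
    # the decode and prefill preprocessing results are produced.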
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
    def test_mla_preprocess(self, magic_npu_fetch,
                            mock_maybe_all_gather_and_maybe_unpad):
        magic_npu_fetch.return_value = MagicMock()
        mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
        batch_size = 4
        seq_len = 8
        hidden_size = 1024
        hidden_states = torch.randn(batch_size * seq_len, hidden_size)

        kv_cache = MagicMock()

        attn_metadata = MagicMock()
        attn_metadata.num_decodes = 2
        attn_metadata.num_prefills = 2
        attn_metadata.num_decode_tokens = 2
        attn_metadata.num_actual_tokens = 4
        num_prefill_tokens = 2
        attn_metadata.slot_mapping = torch.arange(4)
        attn_metadata.decode.cos = torch.randn(2, 64)
        attn_metadata.decode.sin = torch.randn(2, 64)
        attn_metadata.prefill.cos = torch.randn(2, 64)
        attn_metadata.prefill.sin = torch.randn(2, 64)

        self.impl.q_a_layernorm = MagicMock()
        self.impl.q_a_layernorm.return_value = torch.randn(
            attn_metadata.num_actual_tokens, self.impl.num_heads,
            self.impl.qk_rope_head_dim)
        self.impl.kv_a_proj_with_mqa = MagicMock()
        self.impl.kv_a_proj_with_mqa.return_value = [
            torch.randn(num_prefill_tokens, self.impl.num_heads,
                        self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
        ]
        self.impl.fused_qkv_a_proj = MagicMock()
        self.impl.fused_qkv_a_proj.return_value = [
            torch.randn(
                num_prefill_tokens, self.impl.num_heads,
                self.impl.qk_rope_head_dim + self.impl.kv_lora_rank +
                self.impl.q_lora_rank)
        ]
        self.impl.q_proj = MagicMock()
        self.impl.q_proj.return_value = [
            torch.randn(num_prefill_tokens, self.impl.num_heads,
                        self.impl.qk_head_dim)
        ]
        self.impl.kv_b_proj = MagicMock()
        self.impl.kv_b_proj.return_value = [
            torch.randn(num_prefill_tokens, self.impl.num_heads,
                        self.impl.v_head_dim + self.impl.qk_nope_head_dim)
        ]
        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
        self.impl.exec_kv_decode = MagicMock()
        self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
        self.impl.exec_kv_prefill = MagicMock()
        self.impl.exec_kv_prefill.return_value = [
            torch.randn(num_prefill_tokens, self.impl.num_heads,
                        self.impl.qk_rope_head_dim),
            torch.randn(num_prefill_tokens, self.impl.num_heads,
                        self.impl.kv_lora_rank)
        ]
        self.impl._q_proj_and_k_up_proj = MagicMock()
        self.impl._q_proj_and_k_up_proj.return_value = [
            MagicMock(), MagicMock()
        ]
        self.impl.num_kv_heads = self.impl.num_heads
        self.impl.is_kv_producer = False

        decode_res, prefill_res = self.impl._mla_preprocess(
            "mock_layer",
            hidden_states,
            kv_cache,
            attn_metadata,
            need_gather_q_kv=False)

        self.assertIsNotNone(decode_res)
        self.assertIsNotNone(prefill_res)
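    # exec_kv_prefill wraps torch_npu.npu_kv_rmsnorm_rope_cache; the kernel
    # is mocked to return (None, None, k_pe, k_nope) and only the last-dim
    # sizes of the returned tensors are checked.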
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
|
|
|
|
|
def test_exec_kv_prefill(self, mock_kv_rmsnorm_rope_cache):
|
|
|
|
|
B = 2
|
|
|
|
|
N = self.impl.num_kv_heads
|
|
|
|
|
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
|
|
|
|
|
kv_no_split = torch.randn(B, N, D)
|
|
|
|
|
self.impl.enable_kv_nz = None
|
|
|
|
|
self.impl.kv_a_layernorm.weight = MagicMock()
|
|
|
|
|
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
|
|
|
|
|
cos = MagicMock()
|
|
|
|
|
sin = MagicMock()
|
|
|
|
|
slots = MagicMock()
|
|
|
|
|
kv_cache = [MagicMock(), MagicMock()]
|
|
|
|
|
|
|
|
|
|
mock_kv_rmsnorm_rope_cache.return_value = [
|
|
|
|
|
None, None,
|
|
|
|
|
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
|
|
|
|
|
torch.randn(B, N, 1, self.impl.kv_lora_rank)
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
|
|
|
|
|
kv_cache, slots)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
|
|
|
|
|
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
|
|
|
|
|
|
|
|
|
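    # Decode variant of the same wrapper: the mocked kernel returns
    # (k_pe, k_nope, None, None) and the same last-dim checks apply.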
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
|
|
|
|
|
def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
|
|
|
|
|
B = 2
|
|
|
|
|
N = self.impl.num_kv_heads
|
|
|
|
|
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
|
|
|
|
|
kv_no_split = torch.randn(B, N, D)
|
|
|
|
|
self.impl.enable_kv_nz = None
|
|
|
|
|
self.impl.kv_a_layernorm.weight = MagicMock()
|
|
|
|
|
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
|
|
|
|
|
cos = MagicMock()
|
|
|
|
|
sin = MagicMock()
|
|
|
|
|
slots = MagicMock()
|
|
|
|
|
kv_cache = [MagicMock(), MagicMock()]
|
|
|
|
|
|
|
|
|
|
mock_kv_rmsnorm_rope_cache.return_value = [
|
|
|
|
|
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
|
|
|
|
|
torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
|
|
|
|
|
kv_cache, slots)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
|
|
|
|
|
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
|
|
|
|
|
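    # Speculative-decoding decode path with enable_kv_nz=True; the fused
    # attention kernel and _v_up_proj are mocked and the output shape
    # (B, N, v_head_dim) is verified.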
    @patch('vllm_ascend.attention.mla_v1.get_forward_context')
    @patch("torch_npu.npu_fused_infer_attention_score")
    def test_forward_decode(self, mock_npu_fused_infer_attention_score,
                            mock_get_forward_context):
        B = 2
        N = self.impl.num_kv_heads
        BS = 100
        HD = self.impl.v_head_dim
        self.impl.kv_lora_rank = 256
        self.impl.spec_token_num = 1
        self.impl._v_up_proj = MagicMock()
        self.impl._v_up_proj.return_value = torch.randn(B, N, HD)
        q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
        q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
        k_nope = torch.randn(BS, N, self.impl.kv_lora_rank)
        k_pe = torch.randn(BS, N, self.impl.qk_rope_head_dim)
        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.SpecDecoding
        attn_metadata.decode = MagicMock()
        attn_metadata.decode.actual_seq_lengths_q = MagicMock()
        attn_metadata.decode.seq_lens_list = MagicMock()
        self.impl.enable_kv_nz = True

        mock_npu_fused_infer_attention_score.return_value = [
            torch.randn(B, N, self.impl.kv_lora_rank), None
        ]
        mock_get_forward_context.return_value = MagicMock(capturing=False)
        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
                                           attn_metadata)

        self.assertEqual(result.shape[0], B)
        self.assertEqual(result.shape[1], N)
        self.assertEqual(result.shape[2], HD)