[UT]add pcp aclgraph ut (#4804)

### What this PR does / why we need it?
Add unit tests (UT) for prefill context parallelism (PCP) support in the ACL graph module.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2025-12-09 17:27:40 +08:00
committed by GitHub
parent c68dfa70ac
commit 49e346c6a6

View File

@@ -15,15 +15,21 @@
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import numpy as np
import torch import torch
from vllm.compilation.cuda_graph import CUDAGraphOptions from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, ForwardContext from vllm.forward_context import BatchDescriptor, ForwardContext
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
AscendMetadataForDecode)
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAMetadata)
from vllm_ascend.compilation.acl_graph import ( from vllm_ascend.compilation.acl_graph import (
ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params, ACLGraphEntry, ACLGraphWrapper, get_graph_params, get_mtp_graph_params,
update_mtp_graph_params_workspaces) set_graph_params, set_mtp_graph_params, update_attn_dcp_pcp_params,
update_mla_attn_dcp_pcp_params, update_mtp_graph_params_workspaces)
class TestACLGraphEntry(TestBase): class TestACLGraphEntry(TestBase):
@@ -726,3 +732,116 @@ class TestMTPGraphParams(TestBase):
def test_get_mtp_graph_params(self, mtp_graph_params_mock): def test_get_mtp_graph_params(self, mtp_graph_params_mock):
graph_params = get_mtp_graph_params() graph_params = get_mtp_graph_params()
self.assertIs(mtp_graph_params_mock, graph_params) self.assertIs(mtp_graph_params_mock, graph_params)
class TestPCPDCPGraphParams(TestBase):
    """UTs for the DCP/PCP ACL-graph parameter-update helpers.

    Exercises ``update_mla_attn_dcp_pcp_params`` (MLA attention path) and
    ``update_attn_dcp_pcp_params`` (regular attention path) against a mocked
    NPU update stream, checking that each helper drives exactly one
    graph-task update cycle (``graph_task_update_begin``/``_end``).
    """

    def setUp(self):
        """Register a fake captured graph entry for batch size 4."""
        # Stand-in for the NPU side-stream the updaters replay work on.
        self.update_stream = MagicMock(name="FakeStream")
        # Graph params are process-global: reuse them if an earlier test
        # already installed them, otherwise register batch size 4 only.
        graph_params = get_graph_params()
        if graph_params is None:
            set_graph_params({4})  # set literal, not set([4])
            self.graph_params = get_graph_params()
        else:
            self.graph_params = graph_params
        # One recorded event and one opaque task handle for the size-4 entry.
        mock_event = torch.npu.ExternalEvent()
        mock_event.record = MagicMock()
        self.graph_params.events[4] = [mock_event]
        self.graph_params.handles[4] = [MagicMock()]

    @patch('torch.npu.graph_task_update_end')
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
        """MLA decode metadata drives exactly one graph-task update."""
        input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        block_table = torch.zeros(2, 5, dtype=torch.long)
        seq_lens = torch.tensor([4, 4])
        cp_seq_len = torch.tensor([2, 2])
        max_seq_lens = 4
        seq_lens_list = [4, 4]
        slot_mapping = torch.zeros(8, dtype=torch.long)
        query_start_loc = torch.tensor([0, 4])
        block_tables = torch.zeros(2, 5, dtype=torch.long)
        decode = AscendMLADecodeMetadata(input_positions,
                                         block_table,
                                         seq_lens,
                                         max_seq_lens,
                                         seq_lens_list,
                                         cp_seq_len=cp_seq_len)
        metadata = AscendMLAMetadata(8,
                                     8,
                                     slot_mapping,
                                     query_start_loc,
                                     seq_lens,
                                     block_tables,
                                     4,
                                     4,
                                     0,
                                     decode=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_mtp_model = False
        # Shapes mirror an MLA decode step: 4 tokens, nope/rope head split.
        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_rope_head_dim = 32
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)
        q_pe = query[..., qk_nope_head_dim:]
        q_nope = query[..., :qk_nope_head_dim]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)
        # The captured-attention tuple the updater rewrites for batch size 4.
        self.graph_params.attn_params[4] = [
            (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
             scale, num_kv_heads, out, lse)
        ]
        update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
        _mock_graph_task_end.assert_called_once()

    @patch('torch.npu.graph_task_update_end')
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
    def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end):
        """Non-MLA decode metadata drives exactly one graph-task update."""
        block_table = torch.zeros(2, 5, dtype=torch.long)
        num_heads = 256
        scale = 0.1
        num_kv_heads = 8
        qk_head_dim = 96
        qk_nope_head_dim = 64
        query = torch.randn(4, num_heads, qk_head_dim)
        q_nope = query[..., :qk_nope_head_dim]
        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
        actual_seq_lengths_kv = [1, 1]
        actual_seq_lengths_q = np.array([1, 1])
        out = torch.randn(2, 16, 128)
        lse = torch.randn(2, 16, 8)
        # [dcp_rank][pcp_rank] -> computed-token counts per request.
        num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]],
                                                   [[1, 1], [1, 1]]])
        decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp)
        metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1],
                                  actual_seq_lengths_q=actual_seq_lengths_q,
                                  num_decode_tokens=1,
                                  decode_meta=decode)
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_mtp_model = False
        # The captured-attention tuple the updater rewrites for batch size 4.
        self.graph_params.attn_params[4] = [
            (q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale,
             block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
             out, lse, 2, 0, 0)
        ]
        update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
        _mock_graph_task_end.assert_called_once()