[UT]add pcp aclgraph ut (#4804)
### What this PR does / why we need it?
add pcp aclgraph ut
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -15,15 +15,21 @@
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor, ForwardContext
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
|
||||
AscendMetadataForDecode)
|
||||
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||
AscendMLAMetadata)
|
||||
from vllm_ascend.compilation.acl_graph import (
|
||||
ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params,
|
||||
update_mtp_graph_params_workspaces)
|
||||
ACLGraphEntry, ACLGraphWrapper, get_graph_params, get_mtp_graph_params,
|
||||
set_graph_params, set_mtp_graph_params, update_attn_dcp_pcp_params,
|
||||
update_mla_attn_dcp_pcp_params, update_mtp_graph_params_workspaces)
|
||||
|
||||
|
||||
class TestACLGraphEntry(TestBase):
|
||||
@@ -726,3 +732,116 @@ class TestMTPGraphParams(TestBase):
|
||||
def test_get_mtp_graph_params(self, mtp_graph_params_mock):
|
||||
graph_params = get_mtp_graph_params()
|
||||
self.assertIs(mtp_graph_params_mock, graph_params)
|
||||
|
||||
|
||||
class TestPCPDCPGraphParams(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.update_stream = MagicMock(name="FakeStream")
|
||||
graph_params = get_graph_params()
|
||||
if graph_params is None:
|
||||
set_graph_params(set([4]))
|
||||
self.graph_params = get_graph_params()
|
||||
else:
|
||||
self.graph_params = graph_params
|
||||
mock_event = torch.npu.ExternalEvent()
|
||||
mock_event.record = MagicMock()
|
||||
self.graph_params.events[4] = []
|
||||
self.graph_params.handles[4] = []
|
||||
self.graph_params.events[4].append(mock_event)
|
||||
self.graph_params.handles[4].append(MagicMock())
|
||||
|
||||
@patch('torch.npu.graph_task_update_end', )
|
||||
@patch('torch.npu.graph_task_update_begin', MagicMock())
|
||||
@patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
|
||||
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
|
||||
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
block_table = torch.zeros(2, 5, dtype=torch.long)
|
||||
seq_lens = torch.tensor([4, 4])
|
||||
cp_seq_len = torch.tensor([2, 2])
|
||||
max_seq_lens = 4
|
||||
seq_lens_list = [4, 4]
|
||||
slot_mapping = torch.zeros(8, dtype=torch.long)
|
||||
query_start_loc = torch.tensor([0, 4])
|
||||
block_tables = torch.zeros(2, 5, dtype=torch.long)
|
||||
|
||||
decode = AscendMLADecodeMetadata(input_positions,
|
||||
block_table,
|
||||
seq_lens,
|
||||
max_seq_lens,
|
||||
seq_lens_list,
|
||||
cp_seq_len=cp_seq_len)
|
||||
metadata = AscendMLAMetadata(8,
|
||||
8,
|
||||
slot_mapping,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
block_tables,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
decode=decode)
|
||||
forward_context = MagicMock()
|
||||
forward_context.attn_metadata = {"attn_layer_0": metadata}
|
||||
forward_context.is_mtp_model = False
|
||||
|
||||
num_heads = 256
|
||||
scale = 0.1
|
||||
num_kv_heads = 8
|
||||
qk_head_dim = 96
|
||||
qk_rope_head_dim = 32
|
||||
qk_nope_head_dim = 64
|
||||
query = torch.randn(4, num_heads, qk_head_dim)
|
||||
q_pe = query[..., qk_nope_head_dim:]
|
||||
q_nope = query[..., :qk_nope_head_dim]
|
||||
k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
|
||||
k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
|
||||
out = torch.randn(2, 16, 128)
|
||||
lse = torch.randn(2, 16, 8)
|
||||
self.graph_params.attn_params[4] = []
|
||||
self.graph_params.attn_params[4].append(
|
||||
(q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
|
||||
scale, num_kv_heads, out, lse))
|
||||
|
||||
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
|
||||
|
||||
_mock_graph_task_end.assert_called_once()
|
||||
|
||||
@patch('torch.npu.graph_task_update_end', )
|
||||
@patch('torch.npu.graph_task_update_begin', MagicMock())
|
||||
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
|
||||
def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end):
|
||||
block_table = torch.zeros(2, 5, dtype=torch.long)
|
||||
num_heads = 256
|
||||
scale = 0.1
|
||||
num_kv_heads = 8
|
||||
qk_head_dim = 96
|
||||
qk_nope_head_dim = 64
|
||||
query = torch.randn(4, num_heads, qk_head_dim)
|
||||
q_nope = query[..., :qk_nope_head_dim]
|
||||
k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
|
||||
actual_seq_lengths_kv = [1, 1]
|
||||
actual_seq_lengths_q = np.array([1, 1])
|
||||
out = torch.randn(2, 16, 128)
|
||||
lse = torch.randn(2, 16, 8)
|
||||
|
||||
num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]],
|
||||
[[1, 1], [1, 1]]])
|
||||
decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp)
|
||||
metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1],
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
num_decode_tokens=1,
|
||||
decode_meta=decode)
|
||||
forward_context = MagicMock()
|
||||
forward_context.attn_metadata = {"attn_layer_0": metadata}
|
||||
forward_context.is_mtp_model = False
|
||||
|
||||
self.graph_params.attn_params[4] = []
|
||||
self.graph_params.attn_params[4].append(
|
||||
(q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale,
|
||||
block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
|
||||
out, lse, 2, 0, 0))
|
||||
|
||||
update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
|
||||
|
||||
_mock_graph_task_end.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user