[UT] Add PCP ACL graph UT (#4804)
### What this PR does / why we need it?
Add unit tests for the PCP/DCP (prefill/decode context parallel) ACL graph parameter-update helpers `update_attn_dcp_pcp_params` and `update_mla_attn_dcp_pcp_params`.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
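
For reviewers skimming the diff below: the new `TestPCPDCPGraphParams` class drives the update helpers through a process-global graph-params registry. Here is a minimal sketch of that fixture pattern; `get_graph_params`/`set_graph_params` and the container attributes (`events`, `handles`, `attn_params`) are taken directly from the diff, while the `MagicMock` stand-ins are illustrative:

```python
# Minimal sketch of the fixture pattern used by the new tests: a global
# graph-params registry keyed by captured batch size (4 here), whose
# per-size lists the update helpers later iterate over. MagicMock
# stand-ins replace real NPU objects so this runs without an Ascend device.
from unittest.mock import MagicMock

from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               set_graph_params)

if get_graph_params() is None:
    set_graph_params({4})            # register captured batch size 4
params = get_graph_params()
params.events[4] = [MagicMock()]     # stand-in for torch.npu.ExternalEvent
params.handles[4] = [MagicMock()]    # stand-in for a task-group handle
params.attn_params[4] = []           # captured attention kernel arguments
```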
```diff
@@ -15,15 +15,21 @@
 from unittest.mock import MagicMock, Mock, patch
 
+import numpy as np
 import torch
 from vllm.compilation.cuda_graph import CUDAGraphOptions
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor, ForwardContext
 
 from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_v1 import (AscendMetadata,
+                                                AscendMetadataForDecode)
+from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
+                                          AscendMLAMetadata)
 from vllm_ascend.compilation.acl_graph import (
-    ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params,
-    update_mtp_graph_params_workspaces)
+    ACLGraphEntry, ACLGraphWrapper, get_graph_params, get_mtp_graph_params,
+    set_graph_params, set_mtp_graph_params, update_attn_dcp_pcp_params,
+    update_mla_attn_dcp_pcp_params, update_mtp_graph_params_workspaces)
 
 
 class TestACLGraphEntry(TestBase):
```
```diff
@@ -726,3 +732,116 @@ class TestMTPGraphParams(TestBase):
     def test_get_mtp_graph_params(self, mtp_graph_params_mock):
         graph_params = get_mtp_graph_params()
         self.assertIs(mtp_graph_params_mock, graph_params)
+
+
+class TestPCPDCPGraphParams(TestBase):
+
+    def setUp(self):
+        self.update_stream = MagicMock(name="FakeStream")
+        graph_params = get_graph_params()
+        if graph_params is None:
+            set_graph_params(set([4]))
+            self.graph_params = get_graph_params()
+        else:
+            self.graph_params = graph_params
+        mock_event = torch.npu.ExternalEvent()
+        mock_event.record = MagicMock()
+        self.graph_params.events[4] = []
+        self.graph_params.handles[4] = []
+        self.graph_params.events[4].append(mock_event)
+        self.graph_params.handles[4].append(MagicMock())
+
+    @patch('torch.npu.graph_task_update_end', )
+    @patch('torch.npu.graph_task_update_begin', MagicMock())
+    @patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
+    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
+        input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
+        block_table = torch.zeros(2, 5, dtype=torch.long)
+        seq_lens = torch.tensor([4, 4])
+        cp_seq_len = torch.tensor([2, 2])
+        max_seq_lens = 4
+        seq_lens_list = [4, 4]
+        slot_mapping = torch.zeros(8, dtype=torch.long)
+        query_start_loc = torch.tensor([0, 4])
+        block_tables = torch.zeros(2, 5, dtype=torch.long)
+
+        decode = AscendMLADecodeMetadata(input_positions,
+                                         block_table,
+                                         seq_lens,
+                                         max_seq_lens,
+                                         seq_lens_list,
+                                         cp_seq_len=cp_seq_len)
+        metadata = AscendMLAMetadata(8,
+                                     8,
+                                     slot_mapping,
+                                     query_start_loc,
+                                     seq_lens,
+                                     block_tables,
+                                     4,
+                                     4,
+                                     0,
+                                     decode=decode)
+        forward_context = MagicMock()
+        forward_context.attn_metadata = {"attn_layer_0": metadata}
+        forward_context.is_mtp_model = False
+
+        num_heads = 256
+        scale = 0.1
+        num_kv_heads = 8
+        qk_head_dim = 96
+        qk_rope_head_dim = 32
+        qk_nope_head_dim = 64
+        query = torch.randn(4, num_heads, qk_head_dim)
+        q_pe = query[..., qk_nope_head_dim:]
+        q_nope = query[..., :qk_nope_head_dim]
+        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
+        k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
+        out = torch.randn(2, 16, 128)
+        lse = torch.randn(2, 16, 8)
+        self.graph_params.attn_params[4] = []
+        self.graph_params.attn_params[4].append(
+            (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
+             scale, num_kv_heads, out, lse))
+
+        update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
+
+        _mock_graph_task_end.assert_called_once()
+
+    @patch('torch.npu.graph_task_update_end', )
+    @patch('torch.npu.graph_task_update_begin', MagicMock())
+    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
+    def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end):
+        block_table = torch.zeros(2, 5, dtype=torch.long)
+        num_heads = 256
+        scale = 0.1
+        num_kv_heads = 8
+        qk_head_dim = 96
+        qk_nope_head_dim = 64
+        query = torch.randn(4, num_heads, qk_head_dim)
+        q_nope = query[..., :qk_nope_head_dim]
+        k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
+        actual_seq_lengths_kv = [1, 1]
+        actual_seq_lengths_q = np.array([1, 1])
+        out = torch.randn(2, 16, 128)
+        lse = torch.randn(2, 16, 8)
+
+        num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]],
+                                                   [[1, 1], [1, 1]]])
+        decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp)
+        metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1],
+                                  actual_seq_lengths_q=actual_seq_lengths_q,
+                                  num_decode_tokens=1,
+                                  decode_meta=decode)
+        forward_context = MagicMock()
+        forward_context.attn_metadata = {"attn_layer_0": metadata}
+        forward_context.is_mtp_model = False
+
+        self.graph_params.attn_params[4] = []
+        self.graph_params.attn_params[4].append(
+            (q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale,
+             block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
+             out, lse, 2, 0, 0))
+
+        update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
+
+        _mock_graph_task_end.assert_called_once()
```
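
Why `_mock_graph_task_end.assert_called_once()` is the right assertion: each test registers exactly one captured parameter tuple for batch size 4, and an update helper is expected to open and close one task-update window per tuple. Below is a simplified control-flow sketch of such a helper, reconstructed from the call sites the tests mock; this is an assumption about the implementation, not the actual `vllm_ascend` code:

```python
# Assumed control flow of a dcp/pcp update helper, inferred from the mocks
# in the tests above (graph_task_update_begin/end, the attention kernel,
# and the recorded events); the real vllm_ascend implementation may differ.
import torch  # torch.npu requires torch_npu to be installed


def sketch_update_params(update_stream, graph_params, runtime_shape,
                         launch_kernel):
    for params, handle, event in zip(
            graph_params.attn_params[runtime_shape],
            graph_params.handles[runtime_shape],
            graph_params.events[runtime_shape]):
        with torch.npu.stream(update_stream):
            # refresh the captured kernel's arguments in an update window
            torch.npu.graph_task_update_begin(update_stream, handle)
            launch_kernel(*params)  # e.g. the mocked ATB MLA kernel
            torch.npu.graph_task_update_end(update_stream)
        event.record(update_stream)  # the replaying graph waits on this
```

With one tuple registered, `graph_task_update_end` fires exactly once, which is what both tests assert.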