diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index c024fcea..a8142fc6 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -15,15 +15,21 @@ from unittest.mock import MagicMock, Mock, patch +import numpy as np import torch from vllm.compilation.cuda_graph import CUDAGraphOptions from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor, ForwardContext from tests.ut.base import TestBase +from vllm_ascend.attention.attention_v1 import (AscendMetadata, + AscendMetadataForDecode) +from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, + AscendMLAMetadata) from vllm_ascend.compilation.acl_graph import ( - ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params, - update_mtp_graph_params_workspaces) + ACLGraphEntry, ACLGraphWrapper, get_graph_params, get_mtp_graph_params, + set_graph_params, set_mtp_graph_params, update_attn_dcp_pcp_params, + update_mla_attn_dcp_pcp_params, update_mtp_graph_params_workspaces) class TestACLGraphEntry(TestBase): @@ -726,3 +732,116 @@ class TestMTPGraphParams(TestBase): def test_get_mtp_graph_params(self, mtp_graph_params_mock): graph_params = get_mtp_graph_params() self.assertIs(mtp_graph_params_mock, graph_params) + + +class TestPCPDCPGraphParams(TestBase): + + def setUp(self): + self.update_stream = MagicMock(name="FakeStream") + graph_params = get_graph_params() + if graph_params is None: + set_graph_params(set([4])) + self.graph_params = get_graph_params() + else: + self.graph_params = graph_params + mock_event = torch.npu.ExternalEvent() + mock_event.record = MagicMock() + self.graph_params.events[4] = [] + self.graph_params.handles[4] = [] + self.graph_params.events[4].append(mock_event) + self.graph_params.handles[4].append(MagicMock()) + + @patch('torch.npu.graph_task_update_end', ) + @patch('torch.npu.graph_task_update_begin', MagicMock()) + @patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock()) + def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end): + input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + block_table = torch.zeros(2, 5, dtype=torch.long) + seq_lens = torch.tensor([4, 4]) + cp_seq_len = torch.tensor([2, 2]) + max_seq_lens = 4 + seq_lens_list = [4, 4] + slot_mapping = torch.zeros(8, dtype=torch.long) + query_start_loc = torch.tensor([0, 4]) + block_tables = torch.zeros(2, 5, dtype=torch.long) + + decode = AscendMLADecodeMetadata(input_positions, + block_table, + seq_lens, + max_seq_lens, + seq_lens_list, + cp_seq_len=cp_seq_len) + metadata = AscendMLAMetadata(8, + 8, + slot_mapping, + query_start_loc, + seq_lens, + block_tables, + 4, + 4, + 0, + decode=decode) + forward_context = MagicMock() + forward_context.attn_metadata = {"attn_layer_0": metadata} + forward_context.is_mtp_model = False + + num_heads = 256 + scale = 0.1 + num_kv_heads = 8 + qk_head_dim = 96 + qk_rope_head_dim = 32 + qk_nope_head_dim = 64 + query = torch.randn(4, num_heads, qk_head_dim) + q_pe = query[..., qk_nope_head_dim:] + q_nope = query[..., :qk_nope_head_dim] + k_nope = torch.randn(4, num_heads, qk_nope_head_dim) + k_pe = torch.randn(4, num_heads, qk_rope_head_dim) + out = torch.randn(2, 16, 128) + lse = torch.randn(2, 16, 8) + self.graph_params.attn_params[4] = [] + self.graph_params.attn_params[4].append( + (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads, + scale, num_kv_heads, out, lse)) + + update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, 4) + + _mock_graph_task_end.assert_called_once() + + @patch('torch.npu.graph_task_update_end', ) + @patch('torch.npu.graph_task_update_begin', MagicMock()) + @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock()) + def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end): + block_table = torch.zeros(2, 5, dtype=torch.long) + num_heads = 256 + scale = 0.1 + num_kv_heads = 8 + qk_head_dim = 96 + qk_nope_head_dim = 64 + query = torch.randn(4, num_heads, qk_head_dim) + q_nope = query[..., :qk_nope_head_dim] + k_nope = torch.randn(4, num_heads, qk_nope_head_dim) + actual_seq_lengths_kv = [1, 1] + actual_seq_lengths_q = np.array([1, 1]) + out = torch.randn(2, 16, 128) + lse = torch.randn(2, 16, 8) + + num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]], + [[1, 1], [1, 1]]]) + decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp) + metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1], + actual_seq_lengths_q=actual_seq_lengths_q, + num_decode_tokens=1, + decode_meta=decode) + forward_context = MagicMock() + forward_context.attn_metadata = {"attn_layer_0": metadata} + forward_context.is_mtp_model = False + + self.graph_params.attn_params[4] = [] + self.graph_params.attn_params[4].append( + (q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale, + block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q, + out, lse, 2, 0, 0)) + + update_attn_dcp_pcp_params(self.update_stream, forward_context, 4) + + _mock_graph_task_end.assert_called_once()