Files
xc-llm-ascend/tests/ut/attention/test_attention_cp.py
zengzengran 6029bea480 [UT]add pcp dcp ut (#4949)
### What this PR does / why we need it?
Adds unit tests (UT) for the DCP/PCP attention implementation.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zengran <zengran2@huawei.com>
2025-12-15 18:41:38 +08:00

322 lines
13 KiB
Python

from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
class TestAscendAttentionCPImpl(TestBase):
    """Unit tests for ``AscendAttentionCPImpl`` driven entirely by mocks.

    The PCP/DCP group accessors and the ``torch_npu`` / ``torch.ops.npu``
    entry points used by each method are patched, so the tests only check
    the sizes/ranks stored on the implementation and the first three
    dimensions of the tensors it returns.
    """

    @staticmethod
    def _fake_group(**attrs):
        """Return a ``GroupCoordinator``-spec'd mock with ``attrs`` assigned."""
        group = MagicMock(spec=GroupCoordinator)
        for attr_name, attr_value in attrs.items():
            setattr(group, attr_name, attr_value)
        return group

    @staticmethod
    def _duplicating_all_gather(tensor, dim):
        """Stand-in all-gather: concatenate ``tensor`` with itself on ``dim``."""
        return torch.cat([tensor, tensor], dim=dim)

    def _assert_leading_dims(self, tensor, expected_dims):
        """Assert ``tensor.shape[i] == expected_dims[i]`` axis by axis."""
        for axis, expected in enumerate(expected_dims):
            self.assertEqual(tensor.shape[axis], expected)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
              mock_get_pcp_group):
        """Build ``self.impl`` while PCP and DCP both report size 2, rank 0.

        The patches are only active for the duration of this method; tests
        that need the group accessors again re-patch them themselves.
        """
        # Module-level group handles.
        for group_state in (mock_dcp, mock_pcp):
            group_state.world_size = 2
            group_state.rank_in_group = 0

        # Objects handed out by get_dcp_group() / get_pcp_group().
        mock_get_dcp_group.return_value = self._fake_group(
            rank_in_group=0, world_size=2, device_group=MagicMock())
        mock_get_pcp_group.return_value = self._fake_group(
            rank_in_group=0, world_size=2, device_group=MagicMock())

        # Two layer doubles: one permissive, one restricted to three fields.
        layer_attrs = {
            "layer_name": "test_layer",
            "_k_scale_float": 1.0,
            "_v_scale_float": 1.0,
        }
        self.layer = MagicMock()
        self.layer_no_quant = MagicMock(spec=list(layer_attrs))
        for layer in (self.layer, self.layer_no_quant):
            for attr_name, attr_value in layer_attrs.items():
                setattr(layer, attr_name, attr_value)

        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"

        self.attn_metadata = MagicMock()
        self.attn_metadata.return_value = "1"

        self.impl = AscendAttentionCPImpl(
            num_heads=8,
            head_size=64,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=self.attention_type.DECODER,
            kv_sharing_target_layer_name=None)

    def test_init(self):
        """The constructor records the mocked PCP/DCP sizes and ranks."""
        expected_attrs = (
            ("pcp_size", 2),
            ("pcp_rank", 0),
            ("dcp_size", 2),
            ("dcp_rank", 0),
        )
        for attr_name, expected in expected_attrs:
            self.assertEqual(getattr(self.impl, attr_name), expected)

    def test_forward_prefill_cp(self):
        """``_forward_prefill_cp`` returns an output shaped like the query."""
        query = torch.randn(2, 4, 128)
        key = torch.randn(4, 1, 128)
        value = torch.randn(4, 1, 128)

        def mock_attention_with_nomask_and_mask(q, k_mask, **kwargs):
            # Random (output, lse) pair shaped after the inputs.
            return torch.randn_like(q), torch.randn_like(k_mask)

        self.impl._attention_with_nomask_and_mask = MagicMock(
            side_effect=mock_attention_with_nomask_and_mask)

        attn_metadata = MagicMock()
        attn_metadata.prefill = MagicMock()
        pcp_meta = attn_metadata.prefill.pcp_metadata
        pcp_meta.q_head_idx = torch.tensor([0])
        pcp_meta.q_tail_idx = torch.tensor([1])
        pcp_meta.q_full_idx = torch.tensor([0, 1])
        pcp_meta.kv_with_q_head_mask_idx = torch.tensor([0])
        pcp_meta.kv_with_q_tail_nomask_idx = torch.tensor([0])
        pcp_meta.kv_with_q_tail_mask_idx = torch.tensor([0])

        # The second element of the returned pair is not inspected.
        output, _ = self.impl._forward_prefill_cp(query, key, value,
                                                  attn_metadata)

        self._assert_leading_dims(output, (2, 4, 128))

    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP')
    @patch("torch_npu.npu_fused_infer_attention_score")
    @patch("torch.distributed.all_gather")
    @patch("torch.distributed.all_to_all_single")
    @patch('vllm_ascend.attention.attention_cp.get_forward_context')
    def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
                                    mock_all_to_all_single, mock_all_gather,
                                    mock_npu_fused_infer_attention_score,
                                    mock_dcp, mock_get_dcp_group):
        """``_forward_decode_pcp_dcp`` returns an output shaped like the query."""
        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = 0
        mock_get_dcp_group.return_value = self._fake_group(
            rank_in_group=0,
            world_size=2,
            device_group=MagicMock(),
            all_gather=self._duplicating_all_gather)

        query = torch.randn(2, 4, 128)
        self.impl.key_cache = torch.randn(100, 128, 1, 128)
        self.impl.value_cache = torch.randn(100, 128, 1, 128)

        def mock_npu_attention_update(attn_out_lse_list):
            # Same shape as the first entry, minus one slot on the last axis.
            first = attn_out_lse_list[0]
            return torch.randn(first.shape[0], first.shape[1],
                               first.shape[2] - 1)

        self.impl._npu_attention_update = MagicMock(
            side_effect=mock_npu_attention_update)

        mock_get_forward_context.return_value = MagicMock(capturing=False)

        def copy_input_into_output(output, input, *args, **kwargs):
            # Parameter names mirror the original stub so keyword calls work.
            return output.copy_(input)

        mock_all_to_all_single.side_effect = copy_input_into_output

        def fill_two_slots(tensor_list, tensor, group=None):
            tensor_list[0] = tensor
            tensor_list[1] = tensor.clone()

        mock_all_gather.side_effect = fill_two_slots

        def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
                                                      **common_kwargs):
            fake_output = torch.randn_like(query)
            fake_lse = torch.randn(query.shape[0], query.shape[1], 1)
            return fake_output, fake_lse

        mock_npu_fused_infer_attention_score.side_effect = (
            mock_npu_fused_infer_attention_score_func)

        attn_metadata = MagicMock()
        attn_metadata.decode_meta = MagicMock()
        attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
            [1, 0], dtype=torch.bool)

        output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)

        self._assert_leading_dims(output, (2, 4, 128))

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP')
    def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group,
                                      mock_pcp, mock_get_pcp_group):
        """With doubling all-gathers, a (2, 4, 128) query becomes (4, 8, 128)."""
        query = torch.randn(2, 4, 128)

        mock_get_dcp_group.return_value = self._fake_group(
            all_gather=self._duplicating_all_gather)
        mock_get_pcp_group.return_value = self._fake_group(
            all_gather=self._duplicating_all_gather)

        attn_metadata = MagicMock()
        attn_metadata.prefill = MagicMock()
        attn_metadata.prefill.chunked_context = MagicMock()
        chunk_ctx = attn_metadata.prefill.chunked_context
        chunk_ctx.cp_kv_recover_idx_for_chunk = torch.tensor([1, 2, 3, 0])

        output = self.impl._prefill_query_all_gather(attn_metadata, query)

        self._assert_leading_dims(output, (4, 8, 128))

    @patch('torch.ops.npu.npu_fused_infer_attention_score')
    def test_compute_prefill_context(self, mock_npu_attention):
        """``_compute_prefill_context`` returns per-rank-head output and LSE."""
        block_num = 100
        block_size = 128
        kv_num_heads = 1
        head_size = 128
        cache_shape = (block_num, block_size, kv_num_heads, head_size)
        kv_cache = (torch.randn(*cache_shape), torch.randn(*cache_shape))

        batch_size = 1024
        self.impl.head_size = head_size
        self.impl.num_heads = 4
        # The query carries the heads of every DCP rank.
        num_heads = self.impl.num_heads * self.impl.dcp_size
        query = torch.randn(batch_size, num_heads, head_size)

        attn_metadata = MagicMock()
        attn_metadata.prefill = MagicMock()
        attn_metadata.prefill.chunked_context = MagicMock()
        chunk_ctx = attn_metadata.prefill.chunked_context
        chunk_ctx.local_context_lens_allranks = torch.tensor(
            [[[256, 256], [256, 256]]])
        chunk_ctx.batch_chunk_seq_mask = torch.randint(0,
                                                       2, (batch_size, ),
                                                       dtype=torch.bool)

        def mock_load_kv_for_chunk(attn_metadata, kv_cache,
                                   local_chunked_kv_lens_rank, query,
                                   total_toks):
            fake_key = torch.randn(total_toks, kv_num_heads, head_size)
            fake_value = torch.randn(total_toks, kv_num_heads, head_size)
            return fake_key, fake_value

        self.impl._load_kv_for_chunk = MagicMock(
            side_effect=mock_load_kv_for_chunk)

        fake_attn_out = torch.randn(batch_size, num_heads, head_size)
        fake_attn_lse = torch.randn(batch_size, num_heads, 1)
        mock_npu_attention.return_value = (fake_attn_out, fake_attn_lse)

        result_output, result_lse = self.impl._compute_prefill_context(
            query, kv_cache, attn_metadata)

        self._assert_leading_dims(
            result_output, (batch_size, self.impl.num_heads, head_size))
        self._assert_leading_dims(result_lse,
                                  (batch_size, self.impl.num_heads, 1))

    @patch('torch_npu.atb.npu_paged_cache_load')
    def test_load_kv_for_chunk(self, mock_npu_paged_cache_load):
        """``_load_kv_for_chunk`` returns (total_toks, heads, head_size) K/V."""
        block_num = 100
        block_size = 128
        num_heads = 1
        head_size = 128
        cache_shape = (block_num, block_size, num_heads, head_size)
        kv_cache = (torch.randn(*cache_shape), torch.randn(*cache_shape))
        query = torch.randn(4, 8, 128)
        total_toks = 256
        # NOTE(review): float noise stands in for the per-chunk lengths;
        # presumably harmless because the paged-cache load is patched out
        # -- confirm against the implementation.
        local_chunked_kv_lens_rank = torch.randn(total_toks)
        attn_metadata = MagicMock()

        key, value = self.impl._load_kv_for_chunk(attn_metadata, kv_cache,
                                                  local_chunked_kv_lens_rank,
                                                  query, total_toks)

        expected_dims = (total_toks, num_heads, head_size)
        self._assert_leading_dims(key, expected_dims)
        self._assert_leading_dims(value, expected_dims)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('torch_npu._npu_reshape_and_cache')
    def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_pcp,
                               mock_get_pcp_group):
        """``reshape_and_cache`` returns K/V with ``pcp_size`` times the tokens."""
        num_tokens = 4
        block_num = 100
        block_size = 128
        num_heads = 1
        head_size = 128
        self.impl.head_size = head_size
        cache_shape = (block_num, block_size, num_heads, head_size)
        kv_cache = (torch.randn(*cache_shape), torch.randn(*cache_shape))

        attn_metadata = MagicMock()
        attn_metadata.num_decode_tokens = 1
        attn_metadata.num_decodes = 1
        attn_metadata.num_prefills = 1
        attn_metadata.slot_mapping = torch.randn(2)
        attn_metadata.num_actual_tokens_pcp_padded = (num_tokens *
                                                      self.impl.pcp_size)
        attn_metadata.prefill = MagicMock()
        attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor(
            [0, 3, 1, 2, 0, 0, 0, 0])

        key = torch.randn(num_tokens, num_heads, head_size)
        value = torch.randn(num_tokens, num_heads, head_size)

        mock_get_pcp_group.return_value = self._fake_group(
            all_gather=self._duplicating_all_gather)

        key, value = self.impl.reshape_and_cache(key, value, kv_cache,
                                                 attn_metadata)

        expected_dims = (num_tokens * self.impl.pcp_size, num_heads,
                         head_size)
        self._assert_leading_dims(key, expected_dims)
        self._assert_leading_dims(value, expected_dims)