[UT]add pcp dcp ut (#4949)

### What this PR does / why we need it?
Adds unit tests covering the DCP (decode context parallel) and PCP (prefill context parallel) attention implementations.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zengran <zengran2@huawei.com>
This commit is contained in:
zengzengran
2025-12-15 18:41:38 +08:00
committed by GitHub
parent 5fae65f3a8
commit 6029bea480
3 changed files with 768 additions and 8 deletions

View File

@@ -0,0 +1,321 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
class TestAscendAttentionCPImpl(TestBase):
    """Unit tests for AscendAttentionCPImpl with mocked 2-way PCP/DCP groups."""

    # mock.patch decorators apply bottom-up, so the injected mocks arrive in
    # reverse decorator order (innermost/bottom patch is the first argument).
    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
              mock_get_pcp_group):
        """Build an AscendAttentionCPImpl whose DCP/PCP groups are mocked as
        2-way groups with this rank at position 0."""
        # Decode-context-parallel group: world size 2, local rank 0.
        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = 0
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 2
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group
        # Prefill-context-parallel group: same 2-way layout.
        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = 0
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 2
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group
        # Attention-layer stub carrying the scale attributes the impl reads.
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
        self.layer._v_scale_float = 1.0
        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"
        self.attn_metadata = MagicMock()
        self.attn_metadata.return_value = "1"
        # spec= restricts this mock to exactly these attributes, so any access
        # to a quantization field raises AttributeError.
        self.layer_no_quant = MagicMock(
            spec=['layer_name', '_k_scale_float', '_v_scale_float'])
        self.layer_no_quant.layer_name = "test_layer"
        self.layer_no_quant._k_scale_float = 1.0
        self.layer_no_quant._v_scale_float = 1.0
        # System under test: decoder attention, fp16 KV cache, 8 heads of 64.
        self.impl = AscendAttentionCPImpl(
            num_heads=8,
            head_size=64,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=self.attention_type.DECODER,
            kv_sharing_target_layer_name=None)
def test_init(self):
    """The impl exposes the mocked 2-way PCP/DCP sizes with rank 0."""
    expected = {"pcp_size": 2, "pcp_rank": 0, "dcp_size": 2, "dcp_rank": 0}
    for attr, value in expected.items():
        self.assertEqual(getattr(self.impl, attr), value)
def test_forward_prefill_cp(self):
    """Shape-level test of _forward_prefill_cp with the underlying
    attention kernel mocked out (values are random, only shapes matter)."""
    query = torch.randn(2, 4, 128)
    key = torch.randn(4, 1, 128)
    value = torch.randn(4, 1, 128)
    # Stand-in for the masked/unmasked attention helper: returns random
    # tensors shaped like its inputs.
    def mock_attention_with_nomask_and_mask(q, k_mask, **kwargs):
        mock_output = torch.randn_like(q)
        mock_lse = torch.randn_like(k_mask)
        return mock_output, mock_lse
    self.impl._attention_with_nomask_and_mask = MagicMock()
    self.impl._attention_with_nomask_and_mask.side_effect = mock_attention_with_nomask_and_mask
    attn_metadata = MagicMock()
    attn_metadata.prefill = MagicMock()
    # Index tensors the CP prefill path uses to scatter/gather query and
    # KV rows (head/tail split of each sequence across CP ranks).
    attn_metadata.prefill.pcp_metadata.q_head_idx = torch.tensor([0])
    attn_metadata.prefill.pcp_metadata.q_tail_idx = torch.tensor([1])
    attn_metadata.prefill.pcp_metadata.q_full_idx = torch.tensor([0, 1])
    attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx = torch.tensor(
        [0])
    attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx = torch.tensor(
        [0])
    attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = torch.tensor(
        [0])
    output, attn_lse = self.impl._forward_prefill_cp(
        query, key, value, attn_metadata)
    # Output must keep the query's (tokens, heads, head_size) shape.
    self.assertEqual(output.shape[0], 2)
    self.assertEqual(output.shape[1], 4)
    self.assertEqual(output.shape[2], 128)
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP')
@patch("torch_npu.npu_fused_infer_attention_score")
@patch("torch.distributed.all_gather")
@patch("torch.distributed.all_to_all_single")
@patch('vllm_ascend.attention.attention_cp.get_forward_context')
def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
                                mock_all_to_all_single, mock_all_gather,
                                mock_npu_fused_infer_attention_score,
                                mock_dcp, mock_get_dcp_group):
    """Shape-level test of _forward_decode_pcp_dcp with all collectives
    and NPU kernels mocked; checks the output keeps the query shape."""
    # DCP all_gather is simulated by doubling the tensor along `dim`
    # (world size 2).
    def mock_dcp_all_gather_func(tensor, dim):
        return torch.cat([tensor, tensor], dim=dim)
    mock_dcp.world_size = 2
    mock_dcp.rank_in_group = 0
    dcp_group = MagicMock(spec=GroupCoordinator)
    dcp_group.rank_in_group = 0
    dcp_group.world_size = 2
    dcp_group.device_group = MagicMock()
    dcp_group.all_gather = mock_dcp_all_gather_func
    mock_get_dcp_group.return_value = dcp_group
    query = torch.randn(2, 4, 128)
    # Random KV caches: (num_blocks, block_size, kv_heads, head_size).
    self.impl.key_cache = torch.randn(100, 128, 1, 128)
    self.impl.value_cache = torch.randn(100, 128, 1, 128)
    # The update step drops the trailing LSE column (last dim - 1).
    def mock_npu_attention_update(attn_out_lse_list):
        mock_output = torch.randn(attn_out_lse_list[0].shape[0],
                                  attn_out_lse_list[0].shape[1],
                                  attn_out_lse_list[0].shape[2] - 1)
        return mock_output
    self.impl._npu_attention_update = MagicMock()
    self.impl._npu_attention_update.side_effect = mock_npu_attention_update
    # capturing=False: not inside an ACL-graph capture.
    mock_get_forward_context.return_value = MagicMock(capturing=False)
    # all_to_all_single is simulated as an identity copy into `output`.
    mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
        input)
    # torch.distributed.all_gather fills the preallocated list in place.
    def mock_all_gather_func(tensor_list, tensor, group=None):
        tensor_list[0] = tensor
        tensor_list[1] = tensor.clone()
    mock_all_gather.side_effect = mock_all_gather_func
    # Fused attention kernel stub: random output plus per-row LSE.
    def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
                                                  **common_kwargs):
        mock_output = torch.randn_like(query)
        mock_lse = torch.randn(query.shape[0], query.shape[1], 1)
        return mock_output, mock_lse
    mock_npu_fused_infer_attention_score.side_effect = mock_npu_fused_infer_attention_score_func
    attn_metadata = MagicMock()
    attn_metadata.decode_meta = MagicMock()
    # One valid and one masked-out request in the batch.
    attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
        [1, 0], dtype=torch.bool)
    output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)
    self.assertEqual(output.shape[0], 2)
    self.assertEqual(output.shape[1], 4)
    self.assertEqual(output.shape[2], 128)
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP')
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP')
def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group,
                                  mock_pcp, mock_get_pcp_group):
    """_prefill_query_all_gather should double both the token and head
    dims when gathered across the mocked 2-way DCP and PCP groups."""
    query = torch.randn(2, 4, 128)
    # Both groups' all_gather is simulated by concatenating the tensor
    # with itself along the requested dim (world size 2).
    def mock_all_gather_func(tensor, dim):
        return torch.cat([tensor, tensor], dim=dim)
    dcp_group = MagicMock(spec=GroupCoordinator)
    dcp_group.all_gather = mock_all_gather_func
    mock_get_dcp_group.return_value = dcp_group
    pcp_group = MagicMock(spec=GroupCoordinator)
    pcp_group.all_gather = mock_all_gather_func
    mock_get_pcp_group.return_value = pcp_group
    attn_metadata = MagicMock()
    attn_metadata.prefill = MagicMock()
    attn_metadata.prefill.chunked_context = MagicMock()
    # Permutation restoring original token order after the gather.
    attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk = torch.tensor(
        [1, 2, 3, 0])
    output = self.impl._prefill_query_all_gather(attn_metadata, query)
    # Tokens doubled by the PCP gather, heads doubled by the DCP gather.
    self.assertEqual(output.shape[0], 4)
    self.assertEqual(output.shape[1], 8)
    self.assertEqual(output.shape[2], 128)
@patch('torch.ops.npu.npu_fused_infer_attention_score')
def test_compute_prefill_context(self, mock_npu_attention):
    """_compute_prefill_context with chunked-context metadata: the result
    must come back reduced to the local head count with per-head LSE."""
    block_num = 100
    block_size = 128
    kv_num_heads = 1
    head_size = 128
    # (key_cache, value_cache) pair of paged KV caches.
    kv_cache = (torch.randn(block_num, block_size, kv_num_heads,
                            head_size),
                torch.randn(block_num, block_size, kv_num_heads,
                            head_size))
    batch_size = 1024
    self.impl.head_size = head_size
    self.impl.num_heads = 4
    # Query carries the DCP-gathered head dimension (local heads * dcp).
    num_heads = self.impl.num_heads * self.impl.dcp_size
    query = torch.randn(batch_size, num_heads, head_size)
    attn_metadata = MagicMock()
    attn_metadata.prefill = MagicMock()
    attn_metadata.prefill.chunked_context = MagicMock()
    # Per-(chunk, pcp-rank, dcp-rank) context lengths.
    attn_metadata.prefill.chunked_context.local_context_lens_allranks = torch.tensor(
        [[[256, 256], [256, 256]]])
    attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint(
        0, 2, (1024, ), dtype=torch.bool)
    # KV loading is stubbed out: returns random K/V of the requested size.
    def mock_load_kv_for_chunk(attn_metadata, kv_cache,
                               local_chunked_kv_lens_rank, query,
                               total_toks):
        return torch.randn(total_toks, kv_num_heads,
                           head_size), torch.randn(total_toks,
                                                   kv_num_heads, head_size)
    self.impl._load_kv_for_chunk = MagicMock()
    self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk
    # Fused kernel stub returns (output, lse) in the gathered head count.
    mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
                                                  head_size), torch.randn(
                                                      batch_size,
                                                      num_heads, 1)
    result_output, result_lse = self.impl._compute_prefill_context(
        query, kv_cache, attn_metadata)
    # Output is reduced back to the local (un-gathered) head count.
    self.assertEqual(result_output.shape[0], batch_size)
    self.assertEqual(result_output.shape[1], self.impl.num_heads)
    self.assertEqual(result_output.shape[2], head_size)
    self.assertEqual(result_lse.shape[0], batch_size)
    self.assertEqual(result_lse.shape[1], self.impl.num_heads)
    self.assertEqual(result_lse.shape[2], 1)
@patch('torch_npu.atb.npu_paged_cache_load')
def test_load_kv_for_chunk(self, mock_npu_paged_cache_load):
    """_load_kv_for_chunk must allocate K/V buffers of (total_toks,
    kv_heads, head_size); the paged-cache load itself is mocked."""
    block_num = 100
    block_size = 128
    num_heads = 1
    head_size = 128
    kv_cache = (torch.randn(block_num, block_size, num_heads, head_size),
                torch.randn(block_num, block_size, num_heads, head_size))
    query = torch.randn(4, 8, 128)
    total_toks = 256
    # NOTE(review): random floats stand in for the per-rank KV lengths —
    # only the tensor is passed through to the mocked load op here.
    local_chunked_kv_lens_rank = torch.randn(total_toks)
    attn_metadata = MagicMock()
    key, value = self.impl._load_kv_for_chunk(attn_metadata, kv_cache,
                                              local_chunked_kv_lens_rank,
                                              query, total_toks)
    self.assertEqual(key.shape[0], total_toks)
    self.assertEqual(key.shape[1], num_heads)
    self.assertEqual(key.shape[2], head_size)
    self.assertEqual(value.shape[0], total_toks)
    self.assertEqual(value.shape[1], num_heads)
    self.assertEqual(value.shape[2], head_size)
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP')
@patch('torch_npu._npu_reshape_and_cache')
def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_pcp,
                           mock_get_pcp_group):
    """reshape_and_cache should return K/V gathered across the PCP group
    (token dim multiplied by pcp_size); the NPU cache write is mocked."""
    num_tokens = 4
    block_num = 100
    block_size = 128
    num_heads = 1
    head_size = 128
    self.impl.head_size = head_size
    kv_cache = (torch.randn(block_num, block_size, num_heads, head_size),
                torch.randn(block_num, block_size, num_heads, head_size))
    attn_metadata = MagicMock()
    attn_metadata.num_decode_tokens = 1
    attn_metadata.num_decodes = 1
    attn_metadata.num_prefills = 1
    # NOTE(review): randn slot mapping is a placeholder; the real cache
    # write that would consume it is mocked out above.
    attn_metadata.slot_mapping = torch.randn(2)
    attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size
    attn_metadata.prefill = MagicMock()
    # Permutation restoring token order after the PCP all-gather.
    attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor(
        [0, 3, 1, 2, 0, 0, 0, 0])
    key = torch.randn(num_tokens, num_heads, head_size)
    value = torch.randn(num_tokens, num_heads, head_size)
    # PCP all_gather simulated by doubling along `dim` (world size 2).
    def mock_all_gather_func(tensor, dim):
        return torch.cat([tensor, tensor], dim=dim)
    pcp_group = MagicMock(spec=GroupCoordinator)
    pcp_group.all_gather = mock_all_gather_func
    mock_get_pcp_group.return_value = pcp_group
    key, value = self.impl.reshape_and_cache(key, value, kv_cache,
                                             attn_metadata)
    self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
    self.assertEqual(key.shape[1], num_heads)
    self.assertEqual(key.shape[2], head_size)
    self.assertEqual(value.shape[0], num_tokens * self.impl.pcp_size)
    self.assertEqual(value.shape[1], num_heads)
    self.assertEqual(value.shape[2], head_size)

403
tests/ut/attention/test_mla_cp.py Executable file
View File

@@ -0,0 +1,403 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
class TestAscendMLAImpl(TestBase):
    """Unit tests for AscendMlaCPImpl with TP/DCP/PCP groups mocked 2-way."""

    # mock.patch decorators apply bottom-up: the bottom patch supplies the
    # first mock argument (ascend_config), the top patch the last (mock_pcp).
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_tensor_model_parallel_world_size",
           return_value=2)
    @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
    def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
              mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
        """Construct an AscendMlaCPImpl with mocked parallel groups, a mocked
        vllm config (speculative decoding enabled) and mocked projections."""
        mock_tp.world_size = 2
        mock_tp.rank_in_group = MagicMock()
        mock_tp.device_group = MagicMock()
        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = MagicMock()
        mock_dcp.device_group = MagicMock()
        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = MagicMock()
        mock_pcp.device_group = MagicMock()
        vllm_config = MagicMock()
        speculative_config = MagicMock()
        model_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        model_config.dtype = torch.float16
        vllm_config.model_config = model_config
        get_current_vllm_config.return_value = vllm_config
        # refresh=True forces re-initialization of the global ascend config.
        vllm_config.additional_config = {"refresh": True}
        init_ascend_config(vllm_config)
        num_heads = 256
        head_size = 1024
        scale = 0.1
        num_kv_heads = 8
        kv_cache_dtype = "auto"
        kv_a_layernorm = MagicMock()
        kv_a_layernorm.weight = torch.randn(96)
        kv_a_layernorm.variance_epsilon = 1e-6
        # MLA-specific dimensions and projection modules (all mocked).
        kwargs = {
            "kv_lora_rank": 32,
            "qk_nope_head_dim": 64,
            "qk_rope_head_dim": 32,
            "qk_head_dim": 96,
            "v_head_dim": 128,
            "q_lora_rank": 64,
            "q_proj": MagicMock(),
            "q_b_proj": MagicMock(),
            "kv_b_proj": MagicMock(),
            "o_proj": MagicMock(),
            "kv_a_proj_with_mqa": MagicMock(),
            "fused_qkv_a_proj": MagicMock(),
            "kv_a_layernorm": kv_a_layernorm,
            "rotary_emb": MagicMock(),
        }
        self.impl = AscendMlaCPImpl(num_heads=num_heads,
                                    head_size=head_size,
                                    scale=scale,
                                    num_kv_heads=num_kv_heads,
                                    alibi_slopes=None,
                                    sliding_window=None,
                                    kv_cache_dtype=kv_cache_dtype,
                                    blocksparse_params=None,
                                    logits_soft_cap=None,
                                    attn_type=None,
                                    kv_sharing_target_layer_name=None,
                                    **kwargs)
def test_init(self):
    """Constructor stores the MLA dimensions, projections and the mocked
    2-way PCP/DCP sizes supplied in setUp."""
    impl = self.impl
    for attr, value in [
        ("num_heads", 256),
        ("head_size", 1024),
        ("scale", 0.1),
        ("num_kv_heads", 8),
        ("kv_cache_dtype", "auto"),
        ("kv_lora_rank", 32),
        ("qk_nope_head_dim", 64),
        ("qk_rope_head_dim", 32),
        ("qk_head_dim", 96),
        ("v_head_dim", 128),
        ("num_queries_per_kv", 32),
        ("pcp_size", 2),
        ("dcp_size", 2),
    ]:
        self.assertEqual(getattr(impl, attr), value)
    for attr in ("q_proj", "kv_b_proj", "o_proj", "kv_a_proj_with_mqa",
                 "kv_a_layernorm"):
        self.assertIsNotNone(getattr(impl, attr))
@patch('vllm_ascend.attention.mla_cp.get_dcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
def test_mla_preprocess_dcp(self, magic_npu_fetch,
                            mock_maybe_all_gather_and_maybe_unpad,
                            mock_get_dcp_group):
    """_mla_preprocess on a decode-only batch (DCP path) must return
    decode results and no prefill results."""
    # Shrink the MLA dimensions so the mocked tensors stay small.
    self.impl.num_kv_heads = 1
    self.impl.num_heads = 16
    self.impl.qk_rope_head_dim = 64
    self.impl.kv_lora_rank = 512
    self.impl.q_lora_rank = 1536
    self.impl.dcp_size = 2
    self.impl.pcp_size = 2
    block_num = 10
    block_size = 128
    batch_size = 2
    hidden_size = 1024
    hidden_states = torch.randn(batch_size, hidden_size)
    # Split MLA cache: (latent kv, rope) parts.
    kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                            self.impl.kv_lora_rank)
    kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                            self.impl.qk_rope_head_dim)
    kv_cache = (kv_cache0, kv_cache1)
    mock_dcp_group = MagicMock()
    # DCP all_gather simulated by doubling along `dim` (world size 2).
    def mock_all_gather_func(tensor, dim):
        return torch.cat([tensor, tensor], dim=dim)
    mock_dcp_group.all_gather = mock_all_gather_func
    mock_get_dcp_group.return_value = mock_dcp_group
    # Decode-only metadata: no prefill tokens at all.
    attn_metadata = MagicMock()
    attn_metadata.num_decodes = 2
    attn_metadata.num_prefills = 0
    attn_metadata.num_prefill_tokens = 0
    attn_metadata.num_decode_tokens = 2
    attn_metadata.num_actual_tokens = 2
    attn_metadata.slot_mapping = torch.arange(4)
    attn_metadata.decode.cos = torch.randn(2, 64)
    attn_metadata.decode.sin = torch.randn(2, 64)
    # Projection mocks shaped to the configured MLA dimensions.
    self.impl.q_a_layernorm = MagicMock()
    self.impl.q_a_layernorm.return_value = torch.randn(
        attn_metadata.num_actual_tokens, self.impl.q_lora_rank)
    self.impl.kv_a_proj_with_mqa = MagicMock()
    self.impl.kv_a_proj_with_mqa.return_value = [
        torch.randn(batch_size, self.impl.num_heads,
                    self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
    ]
    self.impl.fused_qkv_a_proj = MagicMock()
    self.impl.fused_qkv_a_proj.return_value = [
        torch.randn(
            attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim +
            self.impl.kv_lora_rank + self.impl.q_lora_rank)
    ]
    # RoPE is stubbed as identity; KV execution returns opaque mocks.
    self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
    self.impl.exec_kv_decode = MagicMock()
    self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
    self.impl._q_proj_and_k_up_proj = MagicMock()
    self.impl._q_proj_and_k_up_proj.return_value = [
        torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                    self.impl.kv_lora_rank),
        torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                    self.impl.qk_rope_head_dim)
    ]
    magic_npu_fetch.return_value = MagicMock()
    mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
    decode_res, prefill_res = self.impl._mla_preprocess(
        "mock_layer",
        hidden_states,
        kv_cache,
        attn_metadata,
        need_gather_q_kv=False)
    self.assertIsNotNone(decode_res)
    self.assertIsNone(prefill_res)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.mla_cp.get_pcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
def test_mla_preprocess_pcp(self, magic_npu_fetch,
                            mock_maybe_all_gather_and_maybe_unpad,
                            mock_get_pcp_group,
                            mock_npu_reshape_and_cache):
    """_mla_preprocess on a prefill-only batch (PCP path) must return
    prefill results and no decode results."""
    # Shrink the MLA dimensions so the mocked tensors stay small.
    self.impl.num_kv_heads = 1
    self.impl.num_heads = 16
    self.impl.qk_rope_head_dim = 64
    self.impl.kv_lora_rank = 512
    self.impl.q_lora_rank = 1536
    self.impl.dcp_size = 2
    self.impl.pcp_size = 2
    block_num = 10
    block_size = 128
    batch_size = 2
    hidden_size = 1024
    hidden_states = torch.randn(batch_size, hidden_size)
    # Split MLA cache: (latent kv, rope) parts.
    kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                            self.impl.kv_lora_rank)
    kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                            self.impl.qk_rope_head_dim)
    kv_cache = (kv_cache0, kv_cache1)
    mock_pcp_group = MagicMock()
    # PCP all_gather simulated by doubling along `dim` (world size 2).
    def mock_all_gather_func(tensor, dim):
        return torch.cat([tensor, tensor], dim=dim)
    mock_pcp_group.all_gather = mock_all_gather_func
    mock_get_pcp_group.return_value = mock_pcp_group
    # Prefill-only metadata: no decode tokens at all.
    attn_metadata = MagicMock()
    attn_metadata.num_decodes = 0
    attn_metadata.num_prefills = 2
    attn_metadata.num_prefill_tokens = 2
    attn_metadata.num_decode_tokens = 0
    attn_metadata.num_actual_tokens = 2
    attn_metadata.num_actual_tokens_pcp_padded = 4
    attn_metadata.prefill.pcp_metadata = MagicMock()
    attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.arange(
        4)
    attn_metadata.slot_mapping = torch.arange(4)
    attn_metadata.prefill.cos = torch.randn(2, 64)
    attn_metadata.prefill.sin = torch.randn(2, 64)
    # Projection mocks shaped to the configured MLA dimensions.
    self.impl.q_a_layernorm = MagicMock()
    self.impl.q_a_layernorm.return_value = torch.randn(
        attn_metadata.num_actual_tokens, self.impl.q_lora_rank)
    self.impl.kv_a_proj_with_mqa = MagicMock()
    self.impl.kv_a_proj_with_mqa.return_value = [
        torch.randn(batch_size, self.impl.num_heads,
                    self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
    ]
    self.impl.fused_qkv_a_proj = MagicMock()
    self.impl.fused_qkv_a_proj.return_value = [
        torch.randn(
            attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim +
            self.impl.kv_lora_rank + self.impl.q_lora_rank)
    ]
    # RoPE is stubbed as identity; KV execution returns opaque mocks.
    # (These were assigned twice with identical mocks in an earlier
    # revision; the duplicate assignments were removed.)
    self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
    self.impl.exec_kv_decode = MagicMock()
    self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
    self.impl._q_proj_and_k_up_proj = MagicMock()
    self.impl._q_proj_and_k_up_proj.return_value = [
        torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                    self.impl.kv_lora_rank),
        torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                    self.impl.qk_rope_head_dim)
    ]
    magic_npu_fetch.return_value = MagicMock()
    mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
    self.impl.kv_a_layernorm = MagicMock()
    self.impl.kv_a_layernorm.return_value = torch.randn(
        attn_metadata.num_prefill_tokens, self.impl.num_kv_heads,
        self.impl.kv_lora_rank)
    self.impl.q_proj = MagicMock()
    self.impl.q_proj.return_value = [
        torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
                    self.impl.qk_head_dim)
    ]
    # kv_b_proj output covers the PCP-gathered token count.
    self.impl.kv_b_proj = MagicMock()
    self.impl.kv_b_proj.return_value = [
        torch.randn(attn_metadata.num_prefill_tokens * self.impl.pcp_size,
                    self.impl.num_heads,
                    self.impl.v_head_dim + self.impl.qk_nope_head_dim)
    ]
    self.impl.exec_kv_prefill = MagicMock()
    self.impl.exec_kv_prefill.return_value = [
        torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
                    self.impl.qk_rope_head_dim),
        torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
                    self.impl.kv_lora_rank)
    ]
    decode_res, prefill_res = self.impl._mla_preprocess(
        "mock_layer",
        hidden_states,
        kv_cache,
        attn_metadata,
        need_gather_q_kv=False)
    self.assertIsNone(decode_res)
    self.assertIsNotNone(prefill_res)
@patch("torch.distributed.all_gather")
@patch("torch.distributed.all_to_all_single")
def test_process_attn_out_lse(self, mock_all_to_all_single,
mock_all_gather):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
B = 2
N = self.impl.num_heads
self.impl.kv_lora_rank = 512
attn_output = torch.randn(B, N, self.impl.kv_lora_rank)
softmax_lse = torch.randn(B, N, 1)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
mock_all_gather.side_effect = mock_all_gather_func
decode_metadata = MagicMock()
decode_metadata.actual_seq_lengths_q = MagicMock()
decode_metadata.seq_lens_list = MagicMock()
decode_metadata.batch_seq_mask = torch.tensor([True, False],
dtype=torch.bool)
result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
decode_metadata)
self.assertEqual(result[0].shape[0], B)
self.assertEqual(result[0].shape[1], N / self.impl.dcp_size)
self.assertEqual(result[0].shape[2], self.impl.kv_lora_rank + 1)
@patch("torch.distributed.all_gather")
@patch("torch.distributed.all_to_all_single")
@patch('vllm_ascend.attention.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch('torch_npu.npu_attention_update')
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_npu_multi_head_latent_attention,
mock_get_forward_context,
mock_all_to_all_single, mock_all_gather):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
self.impl.num_kv_heads = 1
self.impl.num_heads = 16
self.impl.kv_lora_rank = 64
self.impl.qk_nope_head_dim = 64
self.impl.spec_token_num = 1
B = 2
N = self.impl.num_heads * self.impl.dcp_size
BS = 128
NB = 100
q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
attn_metadata.decode = MagicMock()
attn_metadata.decode.actual_seq_lengths_q = MagicMock()
attn_metadata.decode.seq_lens_list = MagicMock()
attn_metadata.decode.batch_seq_mask = torch.tensor([False, False],
dtype=torch.bool)
self.impl.enable_kv_nz = True
mock_npu_attention_update.return_value = (torch.randn(
B, self.impl.num_heads, self.impl.kv_lora_rank), None)
mock_npu_multi_head_latent_attention.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank),
torch.randn(B, N, 1)
]
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
mock_all_gather.side_effect = mock_all_gather_func
self.impl._v_up_proj = MagicMock()
self.impl._v_up_proj.return_value = torch.randn(
B, self.impl.v_head_dim)
result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe,
BS, attn_metadata)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], self.impl.v_head_dim)

52
tests/ut/attention/test_mla_v1.py Normal file → Executable file
View File

@@ -75,6 +75,12 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
padded_chunk_seq_lens_npu = torch.tensor([2, 2])
padded_local_chunk_seq_lens = [[2], [2]]
local_context_lens_allranks = [[1, 1], [1, 1]]
padded_local_cu_seq_lens = torch.tensor([0, 2, 4])
cu_seq_lens_lst = [[0, 2], [2, 4]]
chunk_size = 2
chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
@@ -83,7 +89,13 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens)
chunk_seq_lens_npu=chunk_seq_lens,
padded_chunk_seq_lens_npu=padded_chunk_seq_lens_npu,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens,
local_context_lens_allranks=local_context_lens_allranks,
padded_local_cu_seq_lens=padded_local_cu_seq_lens,
cu_seq_lens_lst=cu_seq_lens_lst,
chunk_size=chunk_size)
metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -106,6 +118,17 @@ class TestAscendMLAPrefillMetadata(TestBase):
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
chunk_seq_lens)
self.assertIs(metadata.chunked_context.padded_chunk_seq_lens_npu,
padded_chunk_seq_lens_npu)
self.assertEqual(metadata.chunked_context.padded_local_chunk_seq_lens,
padded_local_chunk_seq_lens)
self.assertEqual(metadata.chunked_context.local_context_lens_allranks,
local_context_lens_allranks)
self.assertIs(metadata.chunked_context.padded_local_cu_seq_lens,
padded_local_cu_seq_lens)
self.assertEqual(metadata.chunked_context.cu_seq_lens_lst,
cu_seq_lens_lst)
self.assertEqual(metadata.chunked_context.chunk_size, chunk_size)
class TestAscendMLADecodeMetadata(TestBase):
@@ -117,10 +140,17 @@ class TestAscendMLADecodeMetadata(TestBase):
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
cp_seq_len = torch.tensor([2, 3])
batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
metadata = AscendMLADecodeMetadata(input_positions, block_table,
seq_lens, max_seq_lens,
seq_lens_list, attn_mask)
metadata = AscendMLADecodeMetadata(input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
max_seq_lens=max_seq_lens,
seq_lens_list=seq_lens_list,
attn_mask=attn_mask,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
@@ -128,6 +158,8 @@ class TestAscendMLADecodeMetadata(TestBase):
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(attn_mask)
self.assertIs(metadata.cp_seq_len, cp_seq_len)
self.assertIs(metadata.batch_seq_mask, batch_seq_mask)
class TestAscendMLAMetadata(TestBase):
@@ -200,17 +232,19 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'
mock_dcp.world_size = 1
mock_dcp.world_size = 2
mock_dcp.rank_in_group = 0
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.world_size = 2
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_pcp.world_size = 1
mock_pcp.world_size = 2
mock_pcp.rank_in_group = 0
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.world_size = 2
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group
@@ -227,6 +261,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
self.assertEqual(builder.dcp_size, mock_dcp.world_size)
self.assertEqual(builder.pcp_size, mock_pcp.world_size)
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',