## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton via the `@singleton` decorator, the class constructor requires a `device` parameter. However, two initialization sites still used the old parameterless constructor, causing failures.

### Changes

1. **Fix singleton initialization** (a hedged sketch of this fix follows the sign-off below)
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()`
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()`

2. **Remove unused field**
   - Removed the `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in the codebase)
   - Updated the related test assertions

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- ✅ Local testing: no linter errors
- ✅ Unit tests for the attention modules verified
- ⏳ CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
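For illustration, a minimal sketch of the failure mode and the call-site fix. The `singleton` decorator and the class bodies below are simplified, hypothetical stand-ins (the real implementations live in vllm-ascend); the only detail taken from this PR is that the constructor must now receive `self.device`:

```python
import torch


def singleton(cls):
    """Sketch of a @singleton decorator: one shared instance per process."""
    _instances = {}

    def get_instance(*args, **kwargs):
        if cls not in _instances:
            _instances[cls] = cls(*args, **kwargs)
        return _instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:
    """Simplified stand-in: since PR #4779 the constructor takes a device."""

    def __init__(self, device: torch.device):
        self.device = device


class AscendMLAMetadataBuilder:
    """Hypothetical excerpt of one of the two fixed initialization sites."""

    def __init__(self, device: torch.device):
        self.device = device
        # Before this PR: AttentionMaskBuilder()  ->  TypeError (no `device`)
        self.attn_mask_builder = AttentionMaskBuilder(self.device)
```

The same one-line change is applied in `AscendAttentionMetadataBuilder.__init__()`; the `pcp_prefill_mask` removal is a pure field deletion with no behavior to sketch.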
from unittest.mock import MagicMock, patch

import torch
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import (
    CPChunkedContextMetadata, _npu_attention_update, _process_attn_out_lse)
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata


def get_pcp_split_info(pcp_rank, pcp_size, seq_lens):
    """Compute the head/tail query index splits for one PCP rank, plus the
    matching masked/unmasked KV index splits and their sequence lengths."""
    q_head_idx, q_tail_idx = [], []
    kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
    kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
    chunk_seqlens = []
    kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
    q_req_offset = 0
    kv_req_offset = 0
    q_head_chunk_id = pcp_rank
    q_tail_chunk_id = pcp_size * 2 - 1 - pcp_rank
    for i, seq_len in enumerate(seq_lens):
        chunk_len = seq_len // 2
        chunk_seqlens.append(chunk_len)
        q_head_idx.extend(list(range(q_req_offset, q_req_offset + chunk_len)))
        kv_with_q_head_nomask_idx.extend(
            list(
                range(kv_req_offset,
                      kv_req_offset + chunk_len * q_head_chunk_id)))
        kv_with_q_head_mask_idx.extend(
            list(
                range(kv_req_offset + chunk_len * q_head_chunk_id,
                      kv_req_offset + chunk_len * (q_head_chunk_id + 1))))
        kv_with_q_head_nomask_seqlens.append(chunk_len * q_head_chunk_id)

        q_tail_idx.extend(
            list(range(q_req_offset + chunk_len,
                       q_req_offset + chunk_len * 2)))
        kv_with_q_tail_nomask_idx.extend(
            list(
                range(kv_req_offset,
                      kv_req_offset + chunk_len * q_tail_chunk_id)))
        kv_with_q_tail_mask_idx.extend(
            list(
                range(kv_req_offset + chunk_len * q_tail_chunk_id,
                      kv_req_offset + chunk_len * (q_tail_chunk_id + 1))))
        kv_with_q_tail_nomask_seqlens.append(chunk_len * q_tail_chunk_id)

        q_req_offset += seq_len
        kv_req_offset += seq_len * pcp_size
    return (
        torch.tensor(q_head_idx),
        torch.tensor(q_tail_idx),
        torch.tensor(kv_with_q_head_nomask_idx),
        torch.tensor(kv_with_q_head_mask_idx),
        torch.tensor(kv_with_q_tail_nomask_idx),
        torch.tensor(kv_with_q_tail_mask_idx),
        chunk_seqlens,
        kv_with_q_head_nomask_seqlens,
        kv_with_q_tail_nomask_seqlens,
    )


def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
                       block_size, num_computed_tokens_cpu, num_reqs,
                       chunked_prefill_workspace_size,
                       num_computed_tokens_of_pcp_dcp, cp_local_block_size):
    """Construct chunked prefill context metadata for the given CP/DCP layout."""
    reqs_start = num_decodes
    context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
    max_context_len_cpu = context_lens_cpu.max().item()
    num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
    max_context_chunk = (chunked_prefill_workspace_size //
                         num_prefills_with_context_cpu)
    max_context_chunk = max_context_chunk // block_size * block_size

    assert max_context_chunk > 0
    num_chunks = (max_context_len_cpu + max_context_chunk -
                  1) // max_context_chunk
    chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
        .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
    chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                           chunk_starts + max_context_chunk)
    chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
    cu_seq_lens_cpu = torch.zeros(num_chunks,
                                  num_prefills + 1,
                                  dtype=torch.int32)
    torch.cumsum(chunk_seq_lens,
                 dim=1,
                 out=cu_seq_lens_cpu[:, 1:],
                 dtype=torch.int32)

    def cdiv(a, b):
        return (a + b - 1) // b

    if dcp_size * pcp_size > 1:
        if num_computed_tokens_of_pcp_dcp is not None:
            local_context_lens_allranks = torch.tensor(
                num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]).reshape(
                    -1, dcp_size * pcp_size)
            # Note(qcs): The max local context lengths
            # padded to `cp_local_block_size`.
            padded_local_context_lens_cpu = (cdiv(
                context_lens_cpu,
                cp_local_block_size * pcp_size * dcp_size,
            ) * cp_local_block_size)
            padded_local_max_context_chunk_across_ranks = (cdiv(
                max_context_chunk,
                cp_local_block_size * pcp_size * dcp_size,
            ) * cp_local_block_size)
            local_chunk_starts = (
                torch.arange(num_chunks, dtype=torch.int32).unsqueeze(1).expand(
                    -1, num_prefills) *
                padded_local_max_context_chunk_across_ranks)
            local_chunk_ends = torch.min(
                padded_local_context_lens_cpu.unsqueeze(0),
                local_chunk_starts + padded_local_max_context_chunk_across_ranks,
            )
            padded_local_chunk_seq_lens = (local_chunk_ends -
                                           local_chunk_starts).clamp(min=0)
            padded_local_cu_chunk_seq_lens_cpu = torch.zeros(num_chunks,
                                                             num_prefills + 1,
                                                             dtype=torch.int32)
            torch.cumsum(
                padded_local_chunk_seq_lens,
                dim=1,
                out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
                dtype=torch.int32,
            )
            chunked_context_metadata = CPChunkedContextMetadata(
                cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
                starts=local_chunk_starts.to(non_blocking=True),
                seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                chunk_seq_lens=chunk_seq_lens,
                chunk_seq_lens_npu=chunk_seq_lens,
                workspace=None,
                padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens,
                padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
                local_context_lens_allranks=local_context_lens_allranks.tolist(),
                padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
                    non_blocking=True),
                cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                chunk_size=padded_local_max_context_chunk_across_ranks,
            )
    else:
        chunked_context_metadata = ChunkedContextMetadata(
            cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
            starts=chunk_starts.to(non_blocking=True),
            seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
            max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
            chunk_seq_lens=chunk_seq_lens,
            chunk_seq_lens_npu=chunk_seq_lens,
            workspace=None,
        )
    return chunked_context_metadata


class TestAscendMLAImpl(TestBase):

    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_tensor_model_parallel_world_size",
           return_value=2)
    @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
              mock_tp):
        mock_tp.world_size = 2
        mock_tp.rank_in_group = MagicMock()
        mock_tp.device_group = MagicMock()
        vllm_config = MagicMock()
        speculative_config = MagicMock()
        model_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        model_config.dtype = torch.float16
        vllm_config.model_config = model_config
        get_current_vllm_config.return_value = vllm_config
        vllm_config.additional_config = {"refresh": True}
        init_ascend_config(vllm_config)

        num_heads = 256
        head_size = 1024
        scale = 0.1
        num_kv_heads = 8
        kv_cache_dtype = "auto"

        kv_a_layernorm = MagicMock()
        kv_a_layernorm.weight = torch.randn(96)
        kv_a_layernorm.variance_epsilon = 1e-6
        kwargs = {
            "kv_lora_rank": 32,
            "qk_nope_head_dim": 64,
            "qk_rope_head_dim": 32,
            "qk_head_dim": 96,
            "v_head_dim": 128,
            "q_lora_rank": 64,
            "q_proj": MagicMock(),
            "q_b_proj": MagicMock(),
            "kv_b_proj": MagicMock(),
            "o_proj": MagicMock(),
            "kv_a_proj_with_mqa": MagicMock(),
            "fused_qkv_a_proj": MagicMock(),
            "kv_a_layernorm": kv_a_layernorm,
            "rotary_emb": MagicMock(),
        }

        self.impl = AscendMlaCPImpl(num_heads=num_heads,
                                    head_size=head_size,
                                    scale=scale,
                                    num_kv_heads=num_kv_heads,
                                    alibi_slopes=None,
                                    sliding_window=None,
                                    kv_cache_dtype=kv_cache_dtype,
                                    blocksparse_params=None,
                                    logits_soft_cap=None,
                                    attn_type=None,
                                    kv_sharing_target_layer_name=None,
                                    **kwargs)

    def test_init(self):
        self.assertEqual(self.impl.num_heads, 256)
        self.assertEqual(self.impl.head_size, 1024)
        self.assertEqual(self.impl.scale, 0.1)
        self.assertEqual(self.impl.num_kv_heads, 8)
        self.assertEqual(self.impl.kv_cache_dtype, "auto")
        self.assertEqual(self.impl.kv_lora_rank, 32)
        self.assertEqual(self.impl.qk_nope_head_dim, 64)
        self.assertEqual(self.impl.qk_rope_head_dim, 32)
        self.assertEqual(self.impl.qk_head_dim, 96)
        self.assertEqual(self.impl.v_head_dim, 128)
        self.assertIsNotNone(self.impl.q_proj)
        self.assertIsNotNone(self.impl.kv_b_proj)
        self.assertIsNotNone(self.impl.o_proj)
        self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
        self.assertIsNotNone(self.impl.kv_a_layernorm)
        self.assertEqual(self.impl.num_queries_per_kv, 32)
        self.assertEqual(self.impl.pcp_size, 2)
        self.assertEqual(self.impl.dcp_size, 2)

    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_mla_preprocess_dcp(self, magic_npu_fetch,
                                mock_maybe_all_gather_and_maybe_unpad):

        self.impl.num_kv_heads = 1
        self.impl.num_heads = 16
        self.impl.qk_rope_head_dim = 64
        self.impl.kv_lora_rank = 512
        self.impl.q_lora_rank = 1536
        self.impl.dcp_size = 2
        self.impl.pcp_size = 2
        block_num = 10
        block_size = 128
        batch_size = 2
        hidden_size = 1024
        hidden_states = torch.randn(batch_size, hidden_size)

        kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                                self.impl.kv_lora_rank)
        kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                                self.impl.qk_rope_head_dim)
        kv_cache = (kv_cache0, kv_cache1)

        attn_metadata = MagicMock()
        attn_metadata.num_decodes = 2
        attn_metadata.num_prefills = 0
        attn_metadata.num_prefill_tokens = 0
        attn_metadata.num_decode_tokens = 2
        attn_metadata.num_actual_tokens = 2
        attn_metadata.slot_mapping = torch.arange(4)
        attn_metadata.decode.cos = torch.randn(2, 64)
        attn_metadata.decode.sin = torch.randn(2, 64)

        self.impl.q_a_layernorm = MagicMock()
        self.impl.q_a_layernorm.return_value = torch.randn(
            attn_metadata.num_actual_tokens, self.impl.q_lora_rank)
        self.impl.kv_a_proj_with_mqa = MagicMock()
        self.impl.kv_a_proj_with_mqa.return_value = [
            torch.randn(batch_size, self.impl.num_heads,
                        self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
        ]
        self.impl.fused_qkv_a_proj = MagicMock()
        self.impl.fused_qkv_a_proj.return_value = [
            torch.randn(
                attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim +
                self.impl.kv_lora_rank + self.impl.q_lora_rank)
        ]

        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
        self.impl.exec_kv_decode = MagicMock()
        self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]

        self.impl._q_proj_and_k_up_proj = MagicMock()
        self.impl._q_proj_and_k_up_proj.return_value = [
            torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                        self.impl.kv_lora_rank),
            torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                        self.impl.qk_rope_head_dim)
        ]

        magic_npu_fetch.return_value = MagicMock()
        mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x

        decode_res, prefill_res = self.impl._mla_preprocess(
            "mock_layer",
            hidden_states,
            kv_cache,
            attn_metadata,
            need_gather_q_kv=False)

        self.assertIsNotNone(decode_res)
        self.assertIsNone(prefill_res)

    @patch('torch_npu._npu_reshape_and_cache')
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_mla_preprocess_pcp(self, magic_npu_fetch,
                                mock_maybe_all_gather_and_maybe_unpad,
                                mock_npu_reshape_and_cache):
        self.impl.num_kv_heads = 1
        self.impl.num_heads = 16
        self.impl.qk_rope_head_dim = 64
        self.impl.kv_lora_rank = 512
        self.impl.q_lora_rank = 1536
        self.impl.dcp_size = 2
        self.impl.pcp_size = 2
        block_num = 10
        block_size = 128
        batch_size = 2
        hidden_size = 1024
        hidden_states = torch.randn(batch_size, hidden_size)

        kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                                self.impl.kv_lora_rank)
        kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
                                self.impl.qk_rope_head_dim)
        kv_cache = (kv_cache0, kv_cache1)

        attn_metadata = MagicMock()
        attn_metadata.num_decodes = 0
        attn_metadata.num_prefills = 2
        attn_metadata.num_prefill_tokens = 2
        attn_metadata.num_decode_tokens = 0
        attn_metadata.num_actual_tokens = 2
        attn_metadata.num_actual_tokens_pcp_padded = 4
        attn_metadata.prefill.pcp_metadata = MagicMock()
        attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.arange(
            4)
        attn_metadata.slot_mapping = torch.arange(4)
        attn_metadata.prefill.cos = torch.randn(2, 64)
        attn_metadata.prefill.sin = torch.randn(2, 64)

        self.impl.q_a_layernorm = MagicMock()
        self.impl.q_a_layernorm.return_value = torch.randn(
            attn_metadata.num_actual_tokens, self.impl.q_lora_rank)
        self.impl.kv_a_proj_with_mqa = MagicMock()
        self.impl.kv_a_proj_with_mqa.return_value = [
            torch.randn(batch_size, self.impl.num_heads,
                        self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
        ]
        self.impl.fused_qkv_a_proj = MagicMock()
        self.impl.fused_qkv_a_proj.return_value = [
            torch.randn(
                attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim +
                self.impl.kv_lora_rank + self.impl.q_lora_rank)
        ]

        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
        self.impl.exec_kv_decode = MagicMock()
        self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]

        self.impl._q_proj_and_k_up_proj = MagicMock()
        self.impl._q_proj_and_k_up_proj.return_value = [
            torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                        self.impl.kv_lora_rank),
            torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
                        self.impl.qk_rope_head_dim)
        ]

        magic_npu_fetch.return_value = MagicMock()
        mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x

        self.impl.kv_a_layernorm = MagicMock()
        self.impl.kv_a_layernorm.return_value = torch.randn(
            attn_metadata.num_prefill_tokens, self.impl.num_kv_heads,
            self.impl.kv_lora_rank)

        self.impl.q_proj = MagicMock()
        self.impl.q_proj.return_value = [
            torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
                        self.impl.qk_head_dim)
        ]
        self.impl.kv_b_proj = MagicMock()
        self.impl.kv_b_proj.return_value = [
            torch.randn(attn_metadata.num_prefill_tokens * self.impl.pcp_size,
                        self.impl.num_heads,
                        self.impl.v_head_dim + self.impl.qk_nope_head_dim)
        ]
        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
        self.impl.exec_kv_decode = MagicMock()
        self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
        self.impl.exec_kv_prefill = MagicMock()
        self.impl.exec_kv_prefill.return_value = [
            torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
                        self.impl.qk_rope_head_dim),
            torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
                        self.impl.kv_lora_rank)
        ]

        decode_res, prefill_res = self.impl._mla_preprocess(
            "mock_layer",
            hidden_states,
            kv_cache,
            attn_metadata,
            need_gather_q_kv=False)
        self.assertIsNone(decode_res)
        self.assertIsNotNone(prefill_res)

    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_process_attn_out_lse(self):
        self.impl.dcp_size = 2
        self.impl.pcp_size = 2

        B = 2
        N = self.impl.num_heads
        self.impl.kv_lora_rank = 512

        attn_output = torch.randn(B, N, self.impl.kv_lora_rank)
        softmax_lse = torch.randn(B, N, 1)

        decode_metadata = MagicMock()
        decode_metadata.actual_seq_lengths_q = MagicMock()
        decode_metadata.seq_lens_list = MagicMock()
        decode_metadata.batch_seq_mask = torch.tensor([True, False],
                                                      dtype=torch.bool)

        result = _process_attn_out_lse(attn_output, softmax_lse,
                                       decode_metadata.batch_seq_mask)

        self.assertEqual(result.shape[0], B * self.impl.pcp_size)
        self.assertEqual(result.shape[1], N)
        self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

    @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
    @patch("torch_npu.atb.npu_multi_head_latent_attention")
    @patch('torch_npu.npu_attention_update')
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
                                    mock_npu_multi_head_latent_attention,
                                    mock_get_forward_context):
        self.impl.dcp_size = 2
        self.impl.pcp_size = 2
        self.impl.num_kv_heads = 1
        self.impl.num_heads = 16
        self.impl.kv_lora_rank = 64
        self.impl.qk_nope_head_dim = 64
        self.impl.spec_token_num = 1
        B = 2
        N = self.impl.num_heads * self.impl.dcp_size
        BS = 128
        NB = 100

        q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
        q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
        k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
        k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.SpecDecoding
        attn_metadata.decode = MagicMock()
        attn_metadata.decode.actual_seq_lengths_q = MagicMock()
        attn_metadata.decode.seq_lens_list = MagicMock()
        attn_metadata.decode.batch_seq_mask = torch.tensor([False, False],
                                                           dtype=torch.bool)

        self.impl.enable_kv_nz = True

        mock_npu_attention_update.return_value = (torch.randn(
            B, self.impl.num_heads, self.impl.kv_lora_rank), None)
        mock_npu_multi_head_latent_attention.return_value = [
            torch.randn(B, N, self.impl.kv_lora_rank),
            torch.randn(B, N, 1)
        ]
        mock_get_forward_context.return_value = MagicMock(capturing=False)

        self.impl._v_up_proj = MagicMock()
        self.impl._v_up_proj.return_value = torch.randn(
            B, self.impl.v_head_dim)

        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
                                           attn_metadata)

        self.assertEqual(result.shape[0], B)
        self.assertEqual(result.shape[1], self.impl.v_head_dim)

    @patch("torch_npu.atb.npu_paged_cache_load")
    @patch("torch_npu.atb.npu_ring_mla")
    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
                                                  mock_pcp, mock_ring,
                                                  mock_load):

        def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
                           head_num, kv_head_num, pre_out, prev_lse, qk_scale,
                           kernel_type, mask_type, input_layout, calc_type,
                           output, softmax_lse):
            return torch.randn(q_rope.shape[0], value.shape[1], value.shape[2])

        mock_ring.side_effect = mock_ring_attn

        def mock_kv_b_proj(kv_c_normed):
            return (torch.randn(kv_c_normed.shape[0],
                                self.impl.num_heads,
                                self.impl.v_head_dim +
                                self.impl.qk_nope_head_dim,
                                dtype=torch.float16), )

        def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
                               allgatered_k_pe: torch.Tensor,
                               chunked_context: CPChunkedContextMetadata,
                               chunk_idx: int, toks: int):
            return torch.randn(
                chunked_context.cu_seq_lens_lst[chunk_idx][-1],
                allgatered_kv_c_normed.shape[1],
                allgatered_kv_c_normed.shape[2]), torch.randn(
                    chunked_context.cu_seq_lens_lst[chunk_idx][-1],
                    allgatered_k_pe.shape[1], allgatered_k_pe.shape[2])

        # mock proj
        self.impl.kv_b_proj.side_effect = mock_kv_b_proj
        NUM_BLOCKS, BLOCK_SIZE = 10, 32  # fixed
        USED_BLOCKS = 3
        # Each test case:
        # (pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context,
        #  num_prefills, num_decodes, num_seqs, cp_local_block_size,
        #  num_computed_tokens_of_pcp_dcp)
        test_cases = [
            (2, 2, [4], [128], 1, 0, 1, 1, [[[32, 32], [32, 32]]]),
            (1, 2, [4], [128], 1, 0, 1, 1, [[[64, 64]]]),
            (2, 1, [4], [128], 1, 0, 1, 1, [[[64], [64]]]),
            (2, 2, [4, 7], [128, 128], 2, 0, 2, 1, [[[32, 32], [32, 32]],
                                                    [[32, 32], [32, 32]]]),
        ]
        # kv cache tensor
        kv_cache_0 = torch.randn(NUM_BLOCKS,
                                 BLOCK_SIZE,
                                 self.impl.num_heads,
                                 self.impl.kv_lora_rank,
                                 dtype=torch.float16)
        kv_cache_1 = torch.randn(NUM_BLOCKS,
                                 BLOCK_SIZE,
                                 self.impl.num_heads,
                                 self.impl.v_head_dim,
                                 dtype=torch.float16)
        kv_cache = [kv_cache_0, kv_cache_1]
        max_model_len = 4096
        max_num_seqs = 25
        # create chunk context
        chunked_prefill_workspace_size = min(
            max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE), 128 * 1024)
        self.impl.prefill_mask = torch.triu(
            torch.ones(10, 10, dtype=torch.float16), 1)
        for test_case in test_cases:
            (pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context,
             num_prefills, num_decodes, num_seqs, cp_local_block_size,
             num_computed_tokens_of_pcp_dcp) = test_case
            mock_dcp.world_size = dcp_size
            mock_pcp.world_size = pcp_size
            assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
            nums_context_per_rank = []
            for num_all_rank_context in nums_all_rank_context:
                assert num_all_rank_context % (pcp_size * dcp_size) == 0
                nums_context_per_rank.append(num_all_rank_context //
                                             (pcp_size * dcp_size))
            self.impl.dcp_size = dcp_size
            self.impl.pcp_size = pcp_size
            # create input
            query = torch.randn(sum(nums_tokens_per_rank),
                                self.impl.num_heads,
                                self.impl.qk_head_dim,
                                dtype=torch.float16)
            q_nope = query[..., :self.impl.qk_nope_head_dim]
            q_pe = query[..., self.impl.qk_nope_head_dim:]
            prefix_out = torch.randn(sum(nums_tokens_per_rank),
                                     self.impl.num_heads,
                                     self.impl.v_head_dim,
                                     dtype=torch.float16)
            prefix_lse = torch.randn(sum(nums_tokens_per_rank),
                                     self.impl.num_heads,
                                     dtype=torch.float16)
            chunk_ctx = get_chunk_metadata(
                pcp_size,
                dcp_size,
                num_prefills=num_prefills,
                num_decodes=num_decodes,
                block_size=BLOCK_SIZE,
                num_computed_tokens_cpu=torch.tensor(nums_all_rank_context),
                num_reqs=num_seqs,
                chunked_prefill_workspace_size=chunked_prefill_workspace_size,
                num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp,
                cp_local_block_size=cp_local_block_size)
            meta = MagicMock()
            prefill_meta = MagicMock()
            prefill_meta.query_lens = nums_tokens_per_rank
            prefill_meta.block_table = torch.randint(
                0, USED_BLOCKS, (1, 64))  # (batch, max_blocks)
            prefill_meta.chunked_context = chunk_ctx
            meta.prefill = prefill_meta

            with patch.object(self.impl, '_reorg_kvcache') as mock_reorg:
                mock_reorg.side_effect = mock_reorg_kvcache

                out, lse = self.impl._compute_prefill_context(
                    q_nope, q_pe, kv_cache, self.impl.qk_rope_head_dim, meta,
                    prefix_out, prefix_lse)

            iters = len(chunk_ctx.seq_tot)
            self.impl.dcp_size = 1
            self.impl.pcp_size = 1
            self.assertEqual(mock_reorg.call_count,
                             iters * (1 if dcp_size * pcp_size > 1 else 0))
            self.assertEqual(mock_load.call_count, iters)
            self.assertEqual(mock_ring.call_count, iters)
            mock_reorg.reset_mock()
            mock_load.reset_mock()
            mock_ring.reset_mock()
            mock_dcp.reset_mock()
            mock_pcp.reset_mock()
            self.assertEqual(out.shape, prefix_out.shape)
            self.assertEqual(lse.shape, prefix_lse.shape)

    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,
                                        mock_pcp):
        BLOCK_SIZE = 128  # fixed
        max_model_len = 4096
        max_num_seqs = 25
        test_cases = [
            (2, 2, [4], [128], 1, 0, 1, 1, [[[32, 32], [32, 32]]]),
            (1, 2, [4], [128], 1, 0, 1, 1, [[[64, 64]]]),
            (2, 1, [4], [128], 1, 0, 1, 1, [[[64], [64]]]),
            (2, 2, [4, 7], [128, 128], 2, 0, 2, 1, [[[32, 32], [32, 32]],
                                                    [[32, 32], [32, 32]]]),
        ]
        for test_case in test_cases:
            (pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context,
             num_prefills, num_decodes, num_seqs, cp_local_block_size,
             num_computed_tokens_of_pcp_dcp) = test_case
            if pcp_size * dcp_size == 1:
                continue
            self.impl.dcp_size = dcp_size
            mock_dcp.world_size = dcp_size
            mock_dcp.all_gather.reset_mock()
            self.impl.pcp_size = pcp_size
            mock_pcp.world_size = pcp_size
            mock_pcp.all_gather.reset_mock()

            chunked_prefill_workspace_size = min(
                max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
                128 * 1024)
            chunked_context = get_chunk_metadata(
                pcp_size,
                dcp_size,
                num_prefills=num_prefills,
                num_decodes=num_decodes,
                block_size=BLOCK_SIZE,
                num_computed_tokens_cpu=torch.tensor(nums_all_rank_context),
                num_reqs=num_seqs,
                chunked_prefill_workspace_size=chunked_prefill_workspace_size,
                num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp,
                cp_local_block_size=cp_local_block_size)

            for i in range(len(chunked_context.seq_tot)):
                allgatered_kv_c_normed = torch.randn(
                    chunked_context.seq_tot[i], self.impl.num_heads,
                    self.impl.kv_lora_rank)
                allgatered_k_pe = torch.randn(chunked_context.seq_tot[i],
                                              self.impl.num_heads,
                                              self.impl.qk_rope_head_dim)
                result_kv, result_k_pe = self.impl._reorg_kvcache(
                    allgatered_kv_c_normed,
                    allgatered_k_pe,
                    chunked_context,
                    chunk_idx=i,
                    toks=chunked_context.seq_tot[i],
                )
                self.assertEqual(result_kv.shape,
                                 (chunked_context.cu_seq_lens_lst[i][-1],
                                  self.impl.num_heads, self.impl.kv_lora_rank))
                self.assertEqual(
                    result_k_pe.shape,
                    (chunked_context.cu_seq_lens_lst[i][-1],
                     self.impl.num_heads, self.impl.qk_rope_head_dim))

                self.assertEqual(result_kv.shape[0],
                                 chunked_context.cu_seq_lens_lst[i][-1])
                self.assertEqual(result_k_pe.shape[0],
                                 chunked_context.cu_seq_lens_lst[i][-1])

            self.assertEqual(mock_dcp.all_gather.call_count,
                             (1 if dcp_size > 1 else 0))
            self.assertEqual(mock_pcp.all_gather.call_count,
                             (1 if pcp_size > 1 else 0))

    def test_out_lse_reshape(self):
        test_cases = [10, 1, 128, 512]
        for test_case in test_cases:
            num_tokens = test_case
            num_heads, head_dim = self.impl.num_heads, self.impl.v_head_dim
            attn_out = torch.randn(num_tokens, num_heads, head_dim)
            attn_lse = torch.randn(num_tokens, num_heads, 1)

            out, lse = self.impl._out_lse_reshape(attn_out, attn_lse)

            assert out.shape == (num_tokens * num_heads, head_dim)
            assert out.is_contiguous()

            assert lse.shape == (num_tokens * num_heads, )
            assert lse.is_contiguous()

            expected_out = attn_out.contiguous().view(-1, head_dim)
            expected_lse = attn_lse.contiguous().view(-1)

            assert torch.allclose(out, expected_out)
            assert torch.allclose(lse, expected_lse)

    @patch('torch_npu.npu_attention_update')
    @patch('vllm_ascend.attention.context_parallel.common_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.attention.context_parallel.common_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_npu_attention_update_with_dcp_pcp(self, mock_dcp,
                                               mock_get_dcp_group, mock_pcp,
                                               mock_get_pcp_group,
                                               mock_npu_attention_update):
        NUM_TOKENS = 10  # fixed
        test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (2, 3)]
        for test_case in test_cases:
            self.impl.dcp_size = test_case[0]
            self.impl.pcp_size = test_case[1]
            num_heads, head_dim = self.impl.num_heads, self.impl.kv_lora_rank + 1

            def mock_out_lse_reshape(attn_out, attn_lse):
                attn_out = attn_out.contiguous().view(
                    attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
                attn_lse = attn_lse.contiguous().view(
                    attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
                return attn_out, attn_lse

            self.impl._out_lse_reshape = MagicMock()
            self.impl._out_lse_reshape.side_effect = mock_out_lse_reshape

            def mock_npu_attention_update_effect(attn_lse_split_cp,
                                                 attn_out_split_cp,
                                                 update_type):
                return torch.randn_like(
                    attn_out_split_cp[0]), torch.randn_like(
                        attn_lse_split_cp[0])

            mock_npu_attention_update.side_effect = mock_npu_attention_update_effect

            mock_pcp_group = MagicMock()
            mock_pcp_group.world_size = self.impl.pcp_size
            mock_get_pcp_group.return_value = mock_pcp_group

            mock_dcp.world_size = self.impl.dcp_size
            mock_dcp_group = MagicMock()
            # mock_dcp_group.world_size = self.impl.dcp_size
            mock_get_dcp_group.return_value = mock_dcp_group
            attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
                                       self.impl.dcp_size * num_heads,
                                       head_dim)
            out = _npu_attention_update(self.impl.kv_lora_rank, attn_out_lse)
            self.impl.dcp_size = 1
            self.impl.pcp_size = 1
            assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)

    @patch('torch_npu.atb.npu_ring_mla')
    def test_attention_with_mask_and_nomask_with_dcp_pcp(
            self, mock_npu_ring_mla):
        num_heads = self.impl.num_heads
        v_head_dim = self.impl.v_head_dim
        qk_nope_head_dim = self.impl.qk_nope_head_dim
        qk_rope_head_dim = self.impl.qk_rope_head_dim

        def mock_npu_ring_mla_effect(q_nope, q_rope, k_nope, k_rope, value,
                                     mask, seqlen, head_num, kv_head_num,
                                     pre_out, prev_lse, qk_scale, kernel_type,
                                     mask_type, input_layout, calc_type,
                                     output, softmax_lse):
            return torch.randn(q_nope.shape[0], value.shape[1],
                               value.shape[-1])

        mock_npu_ring_mla.side_effect = mock_npu_ring_mla_effect
        test_cases = [([8], 2, 2), ([8], 2, 1), ([8], 1, 2), ([8], 2, 2),
                      ([8, 12], 2, 2)]
        for test_case in test_cases:
            scheduled_tokens, pcp_size, dcp_size = test_case
            nums_tokens_per_rank = []
            for num_tokens in scheduled_tokens:
                assert num_tokens % (2 * pcp_size) == 0
                nums_tokens_per_rank.append(num_tokens // pcp_size)
            seq_len_q, seq_len_k = sum(nums_tokens_per_rank), sum(
                scheduled_tokens)
            q_nope = torch.randn(seq_len_q,
                                 num_heads,
                                 qk_nope_head_dim,
                                 dtype=torch.float16)
            q_pe = torch.randn(seq_len_q,
                               num_heads,
                               qk_rope_head_dim,
                               dtype=torch.float16)
            k_nope = torch.randn(seq_len_k,
                                 num_heads,
                                 qk_nope_head_dim,
                                 dtype=torch.float16)
            k_pe = torch.randn(seq_len_k,
                               num_heads,
                               qk_rope_head_dim,
                               dtype=torch.float16)
            value = torch.randn(seq_len_k,
                                num_heads,
                                v_head_dim,
                                dtype=torch.float16)
            mask = torch.triu(torch.ones(10, 10, dtype=torch.float16), 1)
            for rank in range(pcp_size):
                q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
                    kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
                        rank, pcp_size, nums_tokens_per_rank)
                kv_with_q_head_nomask_idx = [kv_with_q_head_nomask_idx]
                output_head, lse_head = self.impl._attention_with_mask_and_nomask(
                    q_nope=torch.index_select(q_nope, 0, q_head_idx),
                    q_pe=torch.index_select(q_pe, 0, q_head_idx),
                    k_nope=k_nope,
                    k_pe=k_pe,
                    value=value,
                    kv_mask_idx=kv_with_q_head_mask_idx,
                    kv_nomask_idx=kv_with_q_head_nomask_idx,
                    attn_mask_seqlens=torch.tensor(
                        [chunk_seqlens, chunk_seqlens], dtype=torch.int32),
                    attn_nomask_seqlens=[kv_with_q_head_nomask_seqlens],
                    mask=mask)
                self.assertEqual(output_head.shape,
                                 (q_head_idx.shape[0], num_heads, v_head_dim))
                self.assertEqual(lse_head.shape,
                                 (num_heads, q_head_idx.shape[0]))
                self.assertEqual(mock_npu_ring_mla.call_count,
                                 1 + (len(kv_with_q_head_nomask_idx[0]) != 0))
                mock_npu_ring_mla.reset_mock()
                kv_with_q_tail_nomask_idx = [kv_with_q_tail_nomask_idx]
                output_tail, lse_tail = self.impl._attention_with_mask_and_nomask(
                    q_nope=torch.index_select(q_nope, 0, q_tail_idx),
                    q_pe=torch.index_select(q_pe, 0, q_tail_idx),
                    k_nope=k_nope,
                    k_pe=k_pe,
                    value=value,
                    kv_mask_idx=kv_with_q_tail_mask_idx,
                    kv_nomask_idx=kv_with_q_tail_nomask_idx,
                    attn_mask_seqlens=torch.tensor(
                        [chunk_seqlens, chunk_seqlens], dtype=torch.int32),
                    attn_nomask_seqlens=[kv_with_q_tail_nomask_seqlens],
                    mask=mask)

                self.assertEqual(output_tail.shape,
                                 (q_tail_idx.shape[0], num_heads, v_head_dim))
                self.assertEqual(lse_tail.shape,
                                 (num_heads, q_tail_idx.shape[0]))
                self.assertEqual(mock_npu_ring_mla.call_count,
                                 1 + (len(kv_with_q_tail_nomask_idx[0]) != 0))
                mock_npu_ring_mla.reset_mock()

    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp,
                                               mock_pcp):
        B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim
        test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
        for test_case in test_cases:
            self.impl.dcp_size = test_case[0]
            self.impl.pcp_size = test_case[1]
            mock_dcp.world_size = test_case[0]
            mock_pcp.world_size = test_case[1]
            # Inputs
            attn_output = torch.randn(B, H, D)
            softmax_lse = torch.randn(B, H, 1)
            batch_seq_mask = torch.tensor([False, True, False, False])  # [B]
            decode_meta = MagicMock()
            decode_meta.batch_seq_mask = batch_seq_mask

            result = _process_attn_out_lse(attn_output, softmax_lse,
                                           batch_seq_mask)
            # [PCP * S, DCP * H, D + 1]
            self.assertIsInstance(result, torch.Tensor)
            assert result.shape == (B * self.impl.pcp_size, H, D + 1)
            self.impl.dcp_size = 1
            self.impl.pcp_size = 1

    @patch('torch_npu.atb.npu_ring_mla')
    def test_forward_prefill_cp_with_dcp_pcp(self, mock_npu_ring_mla):

        def mock_attention_with_nomask_and_mask(
                q_nope: torch.Tensor, q_pe: torch.Tensor, k_nope: torch.Tensor,
                k_pe: torch.Tensor, value: torch.Tensor,
                kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
                attn_mask_seqlens: torch.Tensor,
                attn_nomask_seqlens: torch.Tensor, mask: torch.Tensor):
            mock_output = torch.randn(q_nope.shape[0],
                                      self.impl.num_heads,
                                      self.impl.v_head_dim,
                                      dtype=k_pe.dtype,
                                      device=k_pe.device)
            mock_lse = torch.randn(self.impl.num_heads,
                                   q_pe.shape[0],
                                   dtype=torch.float32,
                                   device=k_pe.device)
            return mock_output, mock_lse

        def mock_compute_prefill_context(q_nope, q_pe, kv_c_and_k_pe_cache,
                                         rope_dim, attn_metadata,
                                         prefix_output, prefix_lse):
            mock_output = torch.randn_like(prefix_output)
            mock_lse = torch.randn_like(prefix_lse)
            return mock_output, mock_lse

        def mock_npu_ring_mla_effect(q_nope, q_rope, k_nope, k_rope, value,
                                     mask, seqlen, head_num, kv_head_num,
                                     pre_out, prev_lse, qk_scale, kernel_type,
                                     mask_type, input_layout, calc_type,
                                     output, softmax_lse):
            return torch.randn(q_nope.shape[0], value.shape[1],
                               value.shape[-1])

        self.impl._attention_with_mask_and_nomask = MagicMock()
        self.impl._attention_with_mask_and_nomask.side_effect = mock_attention_with_nomask_and_mask
        self.impl._compute_prefill_context = MagicMock()
        self.impl._compute_prefill_context.side_effect = mock_compute_prefill_context
        mock_npu_ring_mla.side_effect = mock_npu_ring_mla_effect
        block_num = 10
        block_size = 32
        kv_c_and_k_pe_cache = (torch.randn(block_num,
                                           block_size,
                                           1,
                                           self.impl.q_lora_rank,
                                           dtype=torch.float16),
                               torch.randn(block_num,
                                           block_size,
                                           1,
                                           self.impl.qk_rope_head_dim,
                                           dtype=torch.float16))
        test_cases = [([8], 2, 2), ([8], 2, 1), ([8], 1, 2), ([8], 2, 2),
                      ([8, 16], 2, 2)]
        for test_case in test_cases:
            scheduled_tokens, pcp_size, dcp_size = test_case
            nums_tokens_per_rank = []
            for num_tokens in scheduled_tokens:
                assert num_tokens % (
                    2 * pcp_size) == 0  # padded head&tail compute balance
                nums_tokens_per_rank.append(num_tokens // pcp_size)
            seq_len_q, seq_len_k = sum(nums_tokens_per_rank), sum(
                scheduled_tokens)

            q_nope = torch.randn(seq_len_q,
                                 self.impl.num_heads,
                                 self.impl.qk_nope_head_dim,
                                 dtype=torch.float16)
            q_pe = torch.randn(seq_len_q,
                               self.impl.num_heads,
                               self.impl.qk_rope_head_dim,
                               dtype=torch.float16)
            k_nope = torch.randn(seq_len_k,
                                 self.impl.num_heads,
                                 self.impl.qk_nope_head_dim,
                                 dtype=torch.float16)
            k_pe = torch.randn(seq_len_k,
                               self.impl.num_heads,
                               self.impl.qk_rope_head_dim,
                               dtype=torch.float16)
            value = torch.randn(seq_len_k,
                                self.impl.num_heads,
                                self.impl.v_head_dim,
                                dtype=torch.float16)
            # exercise the split for every pcp rank
            for rank in range(pcp_size):
                q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
                    kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
                        rank, pcp_size, nums_tokens_per_rank)
                attn_metadata = MagicMock()
                attn_metadata.prefill = MagicMock()
                attn_metadata.prefill.pcp_metadata.q_head_idx = q_head_idx
                attn_metadata.prefill.pcp_metadata.q_tail_idx = q_tail_idx
                attn_metadata.prefill.pcp_metadata.q_full_idx = torch.cat([
                    attn_metadata.prefill.pcp_metadata.q_head_idx,
                    attn_metadata.prefill.pcp_metadata.q_tail_idx
                ])
                attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx = kv_with_q_head_nomask_idx

                attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx = kv_with_q_head_mask_idx
                attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx = kv_with_q_tail_nomask_idx
                attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = kv_with_q_tail_mask_idx
                attn_metadata.prefill.pcp_metadata.attn_mask_seqlens = torch.tensor(
                    [chunk_seqlens, chunk_seqlens], dtype=torch.int32)
                attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens = kv_with_q_head_nomask_seqlens
                attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens = kv_with_q_tail_nomask_seqlens

                output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
                                                    value, kv_c_and_k_pe_cache,
                                                    attn_metadata)
                self.assertEqual(
                    output.shape,
                    (seq_len_q, self.impl.num_heads * self.impl.v_head_dim))