Add ut for mla (#2637)
### What this PR does / why we need it?
Update UT for MLA case
- vLLM version: v0.10.1.1
- vLLM main:
14b4326b94
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -456,3 +456,170 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result.shape[2], self.impl.v_head_dim)
|
||||
mock_up_proj.assert_called_once()
|
||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.npu_prefetch")
|
||||
def test_mla_preprocess(self, magic_npu_fetch):
|
||||
magic_npu_fetch.return_value = MagicMock()
|
||||
batch_size = 4
|
||||
seq_len = 8
|
||||
hidden_size = 1024
|
||||
hidden_states = torch.randn(batch_size * seq_len, hidden_size)
|
||||
|
||||
kv_cache = MagicMock()
|
||||
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.num_decodes = 2
|
||||
attn_metadata.num_prefills = 2
|
||||
attn_metadata.num_decode_tokens = 2
|
||||
attn_metadata.num_actual_tokens = 4
|
||||
num_prefill_tokens = 2
|
||||
attn_metadata.slot_mapping = torch.arange(4)
|
||||
attn_metadata.decode.cos = torch.randn(2, 64)
|
||||
attn_metadata.decode.sin = torch.randn(2, 64)
|
||||
attn_metadata.prefill.cos = torch.randn(2, 64)
|
||||
attn_metadata.prefill.sin = torch.randn(2, 64)
|
||||
|
||||
self.impl.q_a_proj = MagicMock()
|
||||
self.impl.q_a_layernorm = MagicMock()
|
||||
self.impl.q_a_layernorm.return_value = torch.randn(
|
||||
attn_metadata.num_actual_tokens, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim)
|
||||
self.impl.kv_a_proj_with_mqa = MagicMock()
|
||||
self.impl.kv_a_proj_with_mqa.return_value = [
|
||||
torch.randn(num_prefill_tokens, self.impl.num_heads,
|
||||
self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)
|
||||
]
|
||||
self.impl.q_proj = MagicMock()
|
||||
self.impl.q_proj.return_value = [
|
||||
torch.randn(num_prefill_tokens, self.impl.num_heads,
|
||||
self.impl.qk_head_dim)
|
||||
]
|
||||
self.impl.kv_b_proj = MagicMock()
|
||||
self.impl.kv_b_proj.return_value = [
|
||||
torch.randn(num_prefill_tokens, self.impl.num_heads,
|
||||
self.impl.v_head_dim + self.impl.qk_nope_head_dim)
|
||||
]
|
||||
self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
|
||||
self.impl.exec_kv_decode = MagicMock()
|
||||
self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
|
||||
self.impl.exec_kv_prefill = MagicMock()
|
||||
self.impl.exec_kv_prefill.return_value = [
|
||||
torch.randn(num_prefill_tokens, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim),
|
||||
torch.randn(num_prefill_tokens, self.impl.num_heads,
|
||||
self.impl.kv_lora_rank)
|
||||
]
|
||||
self.impl._q_proj_and_k_up_proj = MagicMock()
|
||||
self.impl._q_proj_and_k_up_proj.return_value = [
|
||||
MagicMock(), MagicMock()
|
||||
]
|
||||
self.impl.num_kv_heads = self.impl.num_heads
|
||||
|
||||
decode_res, prefill_res = self.impl._mla_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
|
||||
|
||||
self.assertIsNotNone(decode_res)
|
||||
self.assertIsNotNone(prefill_res)
|
||||
|
||||
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
|
||||
def test_exec_kv_prefill(self, mock_kv_rmsnorm_rope_cache):
|
||||
B = 2
|
||||
N = self.impl.num_kv_heads
|
||||
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
|
||||
kv_no_split = torch.randn(B, N, D)
|
||||
self.impl.enable_kv_nz = None
|
||||
self.impl.kv_a_layernorm.weight = MagicMock()
|
||||
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
|
||||
cos = MagicMock()
|
||||
sin = MagicMock()
|
||||
slots = MagicMock()
|
||||
kv_cache = [MagicMock(), MagicMock()]
|
||||
|
||||
mock_kv_rmsnorm_rope_cache.return_value = [
|
||||
None, None,
|
||||
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
|
||||
torch.randn(B, N, 1, self.impl.kv_lora_rank)
|
||||
]
|
||||
|
||||
k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
|
||||
kv_cache, slots)
|
||||
|
||||
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
|
||||
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
||||
|
||||
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
|
||||
def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
|
||||
B = 2
|
||||
N = self.impl.num_kv_heads
|
||||
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
|
||||
kv_no_split = torch.randn(B, N, D)
|
||||
self.impl.enable_kv_nz = None
|
||||
self.impl.kv_a_layernorm.weight = MagicMock()
|
||||
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
|
||||
cos = MagicMock()
|
||||
sin = MagicMock()
|
||||
slots = MagicMock()
|
||||
kv_cache = [MagicMock(), MagicMock()]
|
||||
|
||||
mock_kv_rmsnorm_rope_cache.return_value = [
|
||||
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
|
||||
torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
|
||||
]
|
||||
|
||||
k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
|
||||
kv_cache, slots)
|
||||
|
||||
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
|
||||
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
||||
|
||||
@patch("torch.npu.stream")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
|
||||
mock_get_multistream_comm_context,
|
||||
mock_npu_stream):
|
||||
B = 2
|
||||
N = self.impl.num_kv_heads
|
||||
BS = 100
|
||||
HD = self.impl.v_head_dim
|
||||
self.impl.kv_lora_rank = 256
|
||||
self.impl.spec_token_num = 1
|
||||
self.impl._v_up_proj = MagicMock()
|
||||
self.impl._v_up_proj.return_value = torch.randn(B, N, HD)
|
||||
q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
|
||||
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
|
||||
k_nope = torch.randn(BS, N, self.impl.kv_lora_rank)
|
||||
k_pe = torch.randn(BS, N, self.impl.qk_rope_head_dim)
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
|
||||
attn_metadata.decode = MagicMock()
|
||||
attn_metadata.decode.actual_seq_lengths_q = MagicMock()
|
||||
attn_metadata.decode.seq_lens_list = MagicMock()
|
||||
self.impl.enable_kv_nz = True
|
||||
|
||||
mock_npu_fused_infer_attention_score.return_value = [
|
||||
torch.randn(B, N, self.impl.kv_lora_rank), None
|
||||
]
|
||||
mock_get_multistream_comm_context.return_value = None
|
||||
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], N)
|
||||
self.assertEqual(result.shape[2], HD)
|
||||
|
||||
self.impl.enable_kv_nz = False
|
||||
attn_metadata.attn_state = None
|
||||
mock_return_value = MagicMock()
|
||||
mock_get_multistream_comm_context.return_value = mock_return_value
|
||||
mock_return_value.before_comm_event = MagicMock()
|
||||
mock_return_value.comm_stream = MagicMock()
|
||||
mock_npu_stream.return_value = MagicMock()
|
||||
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], N)
|
||||
self.assertEqual(result.shape[2], HD)
|
||||
Reference in New Issue
Block a user