[Refactor][WIP] Refactor mla_v1 by moving all MLA preprocessing ops into mla_v1 attention impl (#2465)
### What this PR does / why we need it?
In order to support fused kernels, multi-stream, communication
optimization etc, it's better to aggregate all opreations in Attention
layer togather. This PR tries to refactor mla_v1 by moving all MLA
preprocessing ops into mla_v1 attention impl.
Note that new mla_v1 doesn't take torchair into consideration. So this
PR can only be merged after torchair related mla_v1 is isolated into a
new file.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
### Features Test
<img width="506" height="141" alt="image"
src="https://github.com/user-attachments/assets/f1ab2906-a1ac-4450-8433-94811cd89466"
/>
### Performance After Refact
<img width="648" height="486" alt="image"
src="https://github.com/user-attachments/assets/e33e038c-c5d9-4ba7-a8e9-1ac22f9833eb"
/>
### Performance Before Refact
<img width="618" height="494" alt="image"
src="https://github.com/user-attachments/assets/83861dc2-dc51-4af3-9310-90ab10c43bb1"
/>
- vLLM version: v0.10.1.1
- vLLM main:
e03940762b
---------
Signed-off-by: lwq <liwenquan5@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -210,6 +210,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
|
||||
builder.decode_threshold = 1
|
||||
|
||||
input_batch = MagicMock()
|
||||
input_batch.req_ids = [0, 1, 2, 3]
|
||||
@@ -303,18 +304,16 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(self.impl.num_queries_per_kv, 32)
|
||||
self.assertEqual(self.impl.tp_size, 2)
|
||||
|
||||
def test_v_up_proj_and_o_proj(self):
|
||||
def test_v_up_proj(self):
|
||||
batch_size = 4
|
||||
x = torch.randn(batch_size, self.impl.num_heads,
|
||||
self.impl.kv_lora_rank)
|
||||
|
||||
self.impl.o_proj.return_value = (torch.randn(
|
||||
batch_size, self.impl.num_heads * self.impl.v_head_dim), )
|
||||
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
|
||||
self.impl.W_UV = torch.randn(self.impl.num_heads,
|
||||
self.impl.kv_lora_rank,
|
||||
self.impl.v_head_dim)
|
||||
result = self.impl._v_up_proj_and_o_proj(x)
|
||||
result = self.impl._v_up_proj(x)
|
||||
|
||||
self.assertEqual(result.shape[0], batch_size)
|
||||
self.assertEqual(result.shape[1],
|
||||
@@ -371,8 +370,11 @@ class TestAscendMLAImpl(TestBase):
|
||||
metadata.prefill = None
|
||||
prefix_out = torch.randn(2, 16, 128)
|
||||
prefix_lse = torch.randn(2, 16, 8)
|
||||
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
|
||||
metadata, prefix_out,
|
||||
q_pe = query[..., self.impl.qk_nope_head_dim:]
|
||||
q_nope = query[..., :self.impl.qk_nope_head_dim]
|
||||
|
||||
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
|
||||
32, metadata, prefix_out,
|
||||
prefix_lse)
|
||||
|
||||
self.assertTrue(torch.equal(prefix_out, out))
|
||||
@@ -386,6 +388,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
latent_kv_dim = self.impl.kv_lora_rank
|
||||
num_blocks, block_size = 100, 20
|
||||
query = torch.randn(S, N, D)
|
||||
q_nope = query[..., :self.impl.qk_nope_head_dim]
|
||||
q_pe = query[..., self.impl.qk_nope_head_dim:]
|
||||
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
|
||||
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
|
||||
kv_cache = [kv_cache_0, kv_cache_1]
|
||||
@@ -406,9 +410,11 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
meta = MagicMock()
|
||||
meta.prefill = prefill_meta
|
||||
self.impl.prefill_mask = torch.triu(
|
||||
torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1)
|
||||
|
||||
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
|
||||
meta, prefix_out,
|
||||
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
|
||||
32, meta, prefix_out,
|
||||
prefix_lse)
|
||||
|
||||
mock_load.assert_called_once()
|
||||
@@ -417,67 +423,36 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj")
|
||||
@patch("torch_npu._npu_paged_attention_mla")
|
||||
def test_forward_decode_without_graph(self, mock_page_attention_mla,
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode_without_graph(self,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_up_proj):
|
||||
num_tokens = 100
|
||||
num_blocks = 256
|
||||
block_size = 4
|
||||
q_nope = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.qk_nope_head_dim)
|
||||
q_pe = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim)
|
||||
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
|
||||
self.impl.num_heads,
|
||||
self.impl.kv_lora_rank)
|
||||
k_nope = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.qk_nope_head_dim)
|
||||
k_pe = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim)
|
||||
metadata = MagicMock()
|
||||
metadata.decode = MagicMock()
|
||||
metadata.decode.block_table = MagicMock()
|
||||
metadata.decode.seq_lens = 10
|
||||
mock_page_attention_mla.return_value = torch.randn(
|
||||
num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
|
||||
mock_npu_fused_infer_attention_score.return_value = [
|
||||
torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.kv_lora_rank), None
|
||||
]
|
||||
mock_up_proj.return_value = torch.randn(num_tokens,
|
||||
self.impl.num_heads,
|
||||
self.impl.v_head_dim)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, None, None,
|
||||
kv_c_and_k_pe_cache, metadata)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
|
||||
block_size, metadata)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
self.assertEqual(result.shape[1], self.impl.num_heads)
|
||||
self.assertEqual(result.shape[2], self.impl.v_head_dim)
|
||||
mock_up_proj.assert_called_once()
|
||||
mock_page_attention_mla.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
def test_forward_without_graph(self, _, mock_forward_prefill):
|
||||
num_tokens = 100
|
||||
num_blocks = 256
|
||||
block_size = 4
|
||||
rotary_emb_return_value = (torch.randn(num_tokens, 16,
|
||||
self.impl.kv_lora_rank),
|
||||
torch.randn(0, 1, self.impl.kv_lora_rank))
|
||||
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
|
||||
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
|
||||
1, num_blocks, 128)
|
||||
|
||||
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
|
||||
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
|
||||
self.impl.kv_lora_rank)
|
||||
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
|
||||
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
|
||||
self.impl.kv_lora_rank),
|
||||
torch.randn(num_blocks, block_size, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim))
|
||||
output = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.v_head_dim)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = num_tokens
|
||||
mock_forward_prefill.return_value = torch.randn(
|
||||
0, self.impl.num_heads * self.impl.v_head_dim)
|
||||
result = self.impl.forward(None, hidden_states_or_q_c,
|
||||
hidden_states_or_kv_c_normed, k_pe,
|
||||
kv_cache, metadata, output, False)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user