[Refactor][WIP] Refactor mla_v1 by moving all MLA preprocessing ops into mla_v1 attention impl (#2465)

### What this PR does / why we need it?
To support fused kernels, multi-stream execution, communication optimization, etc., it is better to aggregate all operations of the Attention layer in one place. This PR refactors mla_v1 by moving all MLA preprocessing ops into the mla_v1 attention impl.
Note that the new mla_v1 does not take torchair into consideration, so this PR can only be merged after the torchair-related mla_v1 is isolated into a new file.
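
The general shape of the change, as reflected in the updated tests, is that the attention impl now owns the query split into `q_nope`/`q_pe` and the value up-projection (`_v_up_proj`, decoupled from `o_proj`). The sketch below only illustrates that idea with plain PyTorch on an absorbed MLA decode path; the class name, shapes, and everything except the `_v_up_proj`/`forward` method names are assumptions for illustration, not the actual vllm-ascend code (which dispatches to NPU kernels such as `torch_npu.npu_fused_infer_attention_score`).

```python
# Minimal sketch, NOT the vllm-ascend implementation: it only illustrates keeping
# MLA preprocessing (q_nope/q_pe split, latent absorption, value up-projection)
# inside the attention impl.  All names and shapes here are illustrative.
import torch


class ToyMLAImpl:

    def __init__(self, num_heads, qk_nope_head_dim, qk_rope_head_dim,
                 kv_lora_rank, v_head_dim):
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.kv_lora_rank = kv_lora_rank
        self.v_head_dim = v_head_dim
        self.scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5
        # Up/down projections of the latent kv: W_UK absorbs the key side into
        # the query, W_UV maps attention output back to per-head value vectors.
        self.W_UK = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
        self.W_UV = torch.randn(num_heads, kv_lora_rank, v_head_dim)

    def _v_up_proj(self, x: torch.Tensor) -> torch.Tensor:
        # [tokens, num_heads, kv_lora_rank] -> [tokens, num_heads, v_head_dim]
        return torch.einsum("tnk,nkv->tnv", x, self.W_UV)

    def forward(self, query, kv_c_cache, k_pe_cache):
        # Preprocessing now lives in the attention impl, not the model layer:
        # split the query into its non-rope and rope parts.
        q_nope = query[..., :self.qk_nope_head_dim]
        q_pe = query[..., self.qk_nope_head_dim:]
        # Absorb W_UK into the query so attention runs in the latent space.
        q_latent = torch.einsum("tnd,ndk->tnk", q_nope, self.W_UK)
        # Plain (unpaged, unmasked) attention over the cached latents; the real
        # impl calls fused NPU kernels here instead.
        scores = (torch.einsum("tnk,sk->tns", q_latent, kv_c_cache) +
                  torch.einsum("tnr,sr->tns", q_pe, k_pe_cache)) * self.scale
        probs = torch.softmax(scores, dim=-1)
        attn_latent = torch.einsum("tns,sk->tnk", probs, kv_c_cache)
        return self._v_up_proj(attn_latent)


# Toy usage with made-up sizes:
impl = ToyMLAImpl(num_heads=16, qk_nope_head_dim=128, qk_rope_head_dim=64,
                  kv_lora_rank=512, v_head_dim=128)
out = impl.forward(torch.randn(4, 16, 192),   # query: [tokens, heads, 128 + 64]
                   torch.randn(100, 512),     # cached kv latents
                   torch.randn(100, 64))      # cached shared rope keys
assert out.shape == (4, 16, 128)
```

Keeping the split and the up-projection inside the impl is what makes it possible to later fuse these steps with the attention kernel or overlap them on a second stream, which is the motivation stated above.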
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?

### Feature Tests

<img width="506" height="141" alt="image"
src="https://github.com/user-attachments/assets/f1ab2906-a1ac-4450-8433-94811cd89466"
/>

### Performance After Refactor
<img width="648" height="486" alt="image"
src="https://github.com/user-attachments/assets/e33e038c-c5d9-4ba7-a8e9-1ac22f9833eb"
/>

### Performance Before Refactor
<img width="618" height="494" alt="image"
src="https://github.com/user-attachments/assets/83861dc2-dc51-4af3-9310-90ab10c43bb1"
/>


- vLLM version: v0.10.1.1
- vLLM main: e03940762b

---------

Signed-off-by: lwq <liwenquan5@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
Commit c8d1df3a3f (parent 320edde2df), authored by LeeWenquan on 2025-08-28 10:35:57 +08:00 and committed via GitHub.
5 changed files with 410 additions and 345 deletions


@@ -210,6 +210,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                    return_value=ascend_config):
             builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
+            builder.decode_threshold = 1
         input_batch = MagicMock()
         input_batch.req_ids = [0, 1, 2, 3]
@@ -303,18 +304,16 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(self.impl.num_queries_per_kv, 32)
         self.assertEqual(self.impl.tp_size, 2)
 
-    def test_v_up_proj_and_o_proj(self):
+    def test_v_up_proj(self):
         batch_size = 4
         x = torch.randn(batch_size, self.impl.num_heads,
                         self.impl.kv_lora_rank)
-        self.impl.o_proj.return_value = (torch.randn(
-            batch_size, self.impl.num_heads * self.impl.v_head_dim), )
         if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
             self.impl.W_UV = torch.randn(self.impl.num_heads,
                                          self.impl.kv_lora_rank,
                                          self.impl.v_head_dim)
-        result = self.impl._v_up_proj_and_o_proj(x)
+        result = self.impl._v_up_proj(x)
         self.assertEqual(result.shape[0], batch_size)
         self.assertEqual(result.shape[1],
@@ -371,8 +370,11 @@ class TestAscendMLAImpl(TestBase):
         metadata.prefill = None
         prefix_out = torch.randn(2, 16, 128)
         prefix_lse = torch.randn(2, 16, 8)
-        out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
-                                                      metadata, prefix_out,
+        q_pe = query[..., self.impl.qk_nope_head_dim:]
+        q_nope = query[..., :self.impl.qk_nope_head_dim]
+        out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
+                                                      32, metadata, prefix_out,
                                                       prefix_lse)
         self.assertTrue(torch.equal(prefix_out, out))
@@ -386,6 +388,8 @@ class TestAscendMLAImpl(TestBase):
         latent_kv_dim = self.impl.kv_lora_rank
         num_blocks, block_size = 100, 20
         query = torch.randn(S, N, D)
+        q_nope = query[..., :self.impl.qk_nope_head_dim]
+        q_pe = query[..., self.impl.qk_nope_head_dim:]
         kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
         kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
         kv_cache = [kv_cache_0, kv_cache_1]
@@ -406,9 +410,11 @@ class TestAscendMLAImpl(TestBase):
         meta = MagicMock()
         meta.prefill = prefill_meta
+        self.impl.prefill_mask = torch.triu(
+            torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1)
-        out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
-                                                      meta, prefix_out,
+        out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
+                                                      32, meta, prefix_out,
                                                       prefix_lse)
         mock_load.assert_called_once()
@@ -417,67 +423,36 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(out.shape, prefix_out.shape)
         self.assertEqual(lse.shape, prefix_lse.shape)
 
-    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj")
-    @patch("torch_npu._npu_paged_attention_mla")
-    def test_forward_decode_without_graph(self, mock_page_attention_mla,
+    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
+    @patch("torch_npu.npu_fused_infer_attention_score")
+    def test_forward_decode_without_graph(self,
+                                          mock_npu_fused_infer_attention_score,
                                           mock_up_proj):
         num_tokens = 100
         num_blocks = 256
         block_size = 4
         q_nope = torch.randn(num_tokens, self.impl.num_heads,
                              self.impl.qk_nope_head_dim)
         q_pe = torch.randn(num_tokens, self.impl.num_heads,
                            self.impl.qk_rope_head_dim)
-        kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
-                                          self.impl.num_heads,
-                                          self.impl.kv_lora_rank)
+        k_nope = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.qk_nope_head_dim)
+        k_pe = torch.randn(num_tokens, self.impl.num_heads,
+                           self.impl.qk_rope_head_dim)
         metadata = MagicMock()
         metadata.decode = MagicMock()
         metadata.decode.block_table = MagicMock()
         metadata.decode.seq_lens = 10
-        mock_page_attention_mla.return_value = torch.randn(
-            num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
+        mock_npu_fused_infer_attention_score.return_value = [
+            torch.randn(num_tokens, self.impl.num_heads,
+                        self.impl.kv_lora_rank), None
+        ]
         mock_up_proj.return_value = torch.randn(num_tokens,
                                                 self.impl.num_heads,
                                                 self.impl.v_head_dim)
-        result = self.impl._forward_decode(q_nope, q_pe, None, None,
-                                           kv_c_and_k_pe_cache, metadata)
+        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
+                                           block_size, metadata)
         self.assertEqual(result.shape[0], num_tokens)
         self.assertEqual(result.shape[1], self.impl.num_heads)
         self.assertEqual(result.shape[2], self.impl.v_head_dim)
         mock_up_proj.assert_called_once()
-        mock_page_attention_mla.assert_called_once()
-
-    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
-    @patch("torch_npu._npu_reshape_and_cache")
-    def test_forward_without_graph(self, _, mock_forward_prefill):
-        num_tokens = 100
-        num_blocks = 256
-        block_size = 4
-        rotary_emb_return_value = (torch.randn(num_tokens, 16,
-                                               self.impl.kv_lora_rank),
-                                   torch.randn(0, 1, self.impl.kv_lora_rank))
-        self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
-        self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
-            1, num_blocks, 128)
-        hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
-        hidden_states_or_kv_c_normed = torch.randn(num_tokens,
-                                                   self.impl.kv_lora_rank)
-        k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
-        kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
-                                self.impl.kv_lora_rank),
-                    torch.randn(num_blocks, block_size, self.impl.num_heads,
-                                self.impl.qk_rope_head_dim))
-        output = torch.randn(num_tokens, self.impl.num_heads,
-                             self.impl.v_head_dim)
-        metadata = MagicMock()
-        metadata.num_decodes = 0
-        metadata.num_prefills = num_tokens
-        mock_forward_prefill.return_value = torch.randn(
-            0, self.impl.num_heads * self.impl.v_head_dim)
-        result = self.impl.forward(None, hidden_states_or_q_c,
-                                   hidden_states_or_kv_c_normed, k_pe,
-                                   kv_cache, metadata, output, False)
-        self.assertEqual(result.shape[0], num_tokens)
+        mock_npu_fused_infer_attention_score.assert_called_once()