[Refactor] Refactor MLA/SFA weight prefetch to be consistent with MoE weight prefetch (#6629)

### What this PR does / why we need it?
1. [Refactor] Refactor the MLA/SFA weight prefetch so it is consistent with the MoE weight prefetch (a minimal sketch of the pattern follows below)
2. Remove the duplicated o_proj weight prefetch in forward for MLA/SFA
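
For context, a runnable sketch of the pattern this refactor moves to. Only the names `maybe_npu_prefetch` and `get_weight_prefetch_method` come from this PR's diffs; the class and method bodies below are hypothetical simplifications, not the actual vllm-ascend implementation:

```python
from unittest.mock import MagicMock

# Hypothetical stand-in: in the real code the prefetch strategy object
# comes from get_weight_prefetch_method(), the same entry point the MoE
# weight prefetch already uses.
def get_weight_prefetch_method():
    return MagicMock(name="weight_prefetch_method")

class MLAAttentionSketch:
    """Illustrative only; not the actual vllm-ascend MLA/SFA impl."""

    def __init__(self):
        self.o_proj_weight = object()
        # Before this PR, MLA/SFA called maybe_npu_prefetch directly;
        # now the strategy is resolved once, as in the MoE path.
        self.weight_prefetch = get_weight_prefetch_method()

    def _preprocess(self, hidden_states):
        # Single prefetch site for the o_proj weight.
        self.weight_prefetch.prefetch(self.o_proj_weight, hidden_states)

    def forward(self, hidden_states):
        self._preprocess(hidden_states)
        # ...attention computation; the duplicated o_proj prefetch that
        # used to live here is what this PR removes.
        return hidden_states
```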

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?

1) Performance test results:
*) MLA:

| | 1st test | 2nd test | Output Token Throughput (Avg) | Performance improvement |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 11.9669 token/s | 12.0287 token/s | 11.9978 token/s | |
| o_proj no duplicate prefetch | 12.5594 token/s | 12.6216 token/s | 12.5905 token/s | 4.94% |

Single-layer performance improvement: 5%~8%.

*) SFA:

| | 1st test | 2nd test | Output Token Throughput (Avg) | Performance improvement |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 13.0523 token/s | 13.1084 token/s | 13.08035 token/s | |
| o_proj no duplicate prefetch | 13.9844 token/s | 14.1678 token/s | 14.0761 token/s | 7.6% |
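
As a plain-arithmetic sanity check, the improvement percentages follow from the averaged throughputs in the two tables:

```python
# Recompute the reported improvements from the table averages.
mla_base, mla_new = 11.9978, 12.5905   # token/s, MLA averages
sfa_base, sfa_new = 13.08035, 14.0761  # token/s, SFA averages

print(f"MLA: {(mla_new - mla_base) / mla_base:.2%}")  # 4.94%
print(f"SFA: {(sfa_new - sfa_base) / sfa_base:.2%}")  # 7.61% (~7.6%)
```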

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
15 changed files with 98 additions and 56 deletions


```diff
@@ -248,9 +248,10 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(self.impl.dcp_size, 2)
 
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
+           return_value=MagicMock())
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
-    def test_mla_preprocess_dcp(self, magic_npu_fetch,
+    def test_mla_preprocess_dcp(self, mock_get_weight_prefetch_method,
                                 mock_maybe_all_gather_and_maybe_unpad):
         self.impl.num_kv_heads = 1
@@ -309,7 +310,6 @@ class TestAscendMLAImpl(TestBase):
                 self.impl.qk_rope_head_dim)
         ]
-        magic_npu_fetch.return_value = MagicMock()
         mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         decode_res, prefill_res = self.impl._mla_preprocess(
@@ -324,9 +324,10 @@ class TestAscendMLAImpl(TestBase):
     @patch('torch_npu._npu_reshape_and_cache')
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
+           return_value=MagicMock())
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
-    def test_mla_preprocess_pcp(self, magic_npu_fetch,
+    def test_mla_preprocess_pcp(self, mock_get_weight_prefetch_method,
                                 mock_maybe_all_gather_and_maybe_unpad,
                                 mock_npu_reshape_and_cache):
         self.impl.num_kv_heads = 1
@@ -389,7 +390,6 @@ class TestAscendMLAImpl(TestBase):
                 self.impl.qk_rope_head_dim)
         ]
-        magic_npu_fetch.return_value = MagicMock()
         mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         self.impl.kv_a_layernorm = MagicMock()
```


```diff
@@ -967,10 +967,10 @@ class TestAscendMLAImpl(TestBase):
         mock_npu_fused_infer_attention_score.assert_called_once()
 
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
-    def test_mla_preprocess(self, magic_npu_fetch,
+    @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
+           return_value=MagicMock())
+    def test_mla_preprocess(self, mock_get_weight_prefetch_method,
                             mock_maybe_all_gather_and_maybe_unpad):
-        magic_npu_fetch.return_value = MagicMock()
         mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         batch_size = 4
         seq_len = 8
```
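
Taken together, the hunks swap the patch target from `maybe_npu_prefetch` to `get_weight_prefetch_method` and drop the per-test `magic_npu_fetch.return_value = MagicMock()` setup. A self-contained sketch of the new test shape (the module layout here is a hypothetical reduction so the example runs on its own; only the patch-the-resolver pattern is taken from the diff, where the real target is `vllm_ascend.attention.mla_v1.get_weight_prefetch_method`):

```python
import unittest
from unittest.mock import MagicMock, patch

# Stand-ins for the module under test, reduced for illustration.
def get_weight_prefetch_method():
    raise NotImplementedError("replaced by the mock in tests")

def mla_preprocess(weight):
    # Code under test resolves the prefetch strategy and invokes it.
    get_weight_prefetch_method().prefetch(weight)

class TestPrefetchPatching(unittest.TestCase):

    # New style: patch the resolver itself; no return_value assignment
    # is needed in the test body, unlike the removed magic_npu_fetch lines.
    @patch(f"{__name__}.get_weight_prefetch_method",
           return_value=MagicMock())
    def test_preprocess_resolves_prefetch_method(self, mock_get_method):
        mla_preprocess(weight=object())
        mock_get_method.assert_called_once()
        mock_get_method.return_value.prefetch.assert_called_once()

if __name__ == "__main__":
    unittest.main()
```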