[Refact]Refact MLA/SFA weight prefetch to consist with moe weight prefetch (#6629)

### What this PR does / why we need it?
1. [Refact] Refact MLA/SFA weight prefetch to consist with moe weight
prefetch
2. Remove duplicated o_proj weight prefetch in forward for MLA/SFA

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?

1) Performance result:
Perf test data:
*) MLA:

| | 1st test | 2nd test | Output Token Throughput(Avg) | Performance
improvement percentage |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 11.9669 token/s | 12.0287 token/s |
11.9978 |
| o_proj no duplicate prefetch | 12.5594 token/s | 12.6216 token/s |
12.5905 | 4.94%| |

single layer performace improve: 5%~8%

*) SFA:

| | 1st test | 2nd test | Output Token Throughput(Avg) | Performance
improvement percentage |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 13.0523 token/s | 13.1084 token/s |
13.08035 | |
| o_proj no duplicate prefetch | 13.9844 token/s | 14.1678 token/s |
14.0761 | 7.6% |

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
Nengjun Ma
2026-02-10 14:14:37 +08:00
committed by GitHub
parent 2a826b5fad
commit 66b60c9440
15 changed files with 98 additions and 56 deletions

View File

@@ -43,9 +43,13 @@ from vllm_ascend.ops.layer_shard_linear import (
register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, weak_ref_tensors
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
get_weight_prefetch_method,
maybe_trans_nz,
weak_ref_tensors,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
@@ -703,7 +707,6 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.enable_kv_nz
self.ring_mla_mask_size = 512
@@ -1412,8 +1415,9 @@ class AscendMLAImpl(MLAAttentionImpl):
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
if self.fused_qkv_a_proj is not None:
maybe_npu_prefetch(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
@@ -1545,14 +1549,13 @@ class AscendMLAImpl(MLAAttentionImpl):
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
maybe_npu_prefetch(
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch,
linear_layer=self.o_proj,
)
output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0]
del o_proj_input

View File

@@ -37,7 +37,6 @@ from vllm_ascend.ops.layer_shard_linear import (
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
@@ -45,6 +44,7 @@ from vllm_ascend.utils import (
dispose_layer,
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
get_weight_prefetch_method,
maybe_trans_nz,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -410,7 +410,6 @@ class AscendSFAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
# In sfa, prefill and decode have the same calculation formula,
# so do not distinguish between prefill and decode here.
@@ -800,8 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl):
)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
@@ -917,11 +917,12 @@ class AscendSFAImpl(MLAAttentionImpl):
)
attn_output = self._v_up_proj(attn_output)
maybe_npu_prefetch(
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch,
linear_layer=self.o_proj,
)
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only: