[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)

### What this PR does / why we need it?

- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators improves
performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?

Add a new config in `--additional-config` for configuration:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0
            }
        }
    }
}
```
This feature is enabled by default, and can be disabled through this
configuration

### How was this patch tested?


- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
This commit is contained in:
Ruri
2025-10-09 20:38:39 +08:00
committed by GitHub
parent 23db56a340
commit ff37575936
13 changed files with 264 additions and 69 deletions

View File

@@ -24,7 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@@ -493,7 +493,7 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.enable_prefetch
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
vllm_config = get_current_vllm_config()
@@ -877,9 +877,9 @@ class AscendMLAImpl(MLAAttentionImpl):
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if self.q_a_proj is not None:
npu_prefetch(self.q_a_proj.weight,
hidden_states,
enabled=self.enable_prefetch)
maybe_npu_prefetch(inputs=self.q_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
ckq = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(ckq)
else:
@@ -1005,10 +1005,10 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
@@ -1016,10 +1016,10 @@ class AscendMLAImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,