[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)

### What this PR does / why we need it?

- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?

Add a new config in `--additional-config` for configuration:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0,
            },
        },
    },
}
```
This feature is enabled by default, and can be disabled through this
configuration

### How was this patch tested?


- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
This commit is contained in:
Ruri
2025-10-09 20:38:39 +08:00
committed by GitHub
parent 23db56a340
commit ff37575936
13 changed files with 264 additions and 69 deletions

View File

@@ -1,7 +1,7 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import CUDAGraphMode, VllmConfig
@@ -13,6 +13,11 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
WeightPrefetchMethod = None
class FusedMoEState(Enum):
AllGather = 0
@@ -65,7 +70,8 @@ def set_ascend_forward_context(
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None):
model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -127,6 +133,7 @@ def set_ascend_forward_context(
hasattr(model_instance.model, "start_layer"):
forward_context.layer_idx = model_instance.model.start_layer
# TODO(rjg-lyh): refactor mlp weight prefetch method
# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
@@ -138,6 +145,8 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
# TODO(yuzhup): integrate moe weight prefetch method
forward_context.weight_prefetch_method = weight_prefetch_method
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.