[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)
### What this PR does / why we need it?
- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators improves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Adds a new `weight_prefetch_config` option under `--additional-config`:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0
            }
        }
    }
}
```
This feature is enabled by default, and can be disabled by setting
`enabled` to `false` in this configuration.
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
@@ -13,6 +13,11 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import enable_sp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
else:
|
||||
WeightPrefetchMethod = None
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
AllGather = 0
|
||||
@@ -65,7 +70,8 @@ def set_ascend_forward_context(
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||
prefetch_stream: torch.npu.Stream = None,
|
||||
model_instance: torch.nn.Module = None):
|
||||
model_instance: torch.nn.Module = None,
|
||||
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
We add some additional param into forward_context.
|
||||
@@ -127,6 +133,7 @@ def set_ascend_forward_context(
|
||||
hasattr(model_instance.model, "start_layer"):
|
||||
forward_context.layer_idx = model_instance.model.start_layer
|
||||
|
||||
# TODO(rjg-lyh): refactor mlp weight prefetch method
|
||||
# set for mlp weight prefetch
|
||||
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
|
||||
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
|
||||
@@ -138,6 +145,8 @@ def set_ascend_forward_context(
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||
# TODO(yuzhup): integrate moe weight prefetch method
|
||||
forward_context.weight_prefetch_method = weight_prefetch_method
|
||||
|
||||
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||
# It will be improved later by implementing operator fusion through the FX graph.
|
||||
|
||||
Reference in New Issue
Block a user