[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)
### What this PR does / why we need it?
- Refactor and integrate a unified `WeightPrefetchMethod`
- Prefetch `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators improves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Adds a new entry under `--additional-config`:
```json
{
  "weight_prefetch_config": {
    "enabled": false,
    "prefetch_ratio": {
      "attn": {
        "qkv": 1.0,
        "o": 1.0
      }
    }
  }
}
```
This feature is enabled by default and can be disabled by setting
`"enabled": false` as shown above.
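
For reference, a hedged example of toggling this from the Python API (assumption: vLLM's `additional_config` engine argument mirrors the `--additional-config` CLI flag; the model name is just a placeholder):

```python
from vllm import LLM

# Disable weight prefetching via the new additional config entry.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    additional_config={
        "weight_prefetch_config": {
            "enabled": False,
        },
    },
)
```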
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
```diff
@@ -51,6 +51,7 @@ _CUSTOM_OP_ENABLED = None
 _IS_310P = None
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
+_PREFETCH_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 
 
```
```diff
@@ -241,6 +242,15 @@ def current_stream() -> torch.npu.Stream:
     return _CURRENT_STREAM
 
 
+def prefetch_stream() -> torch.npu.Stream:
+    global _PREFETCH_STREAM
+    if _PREFETCH_STREAM is None:
+        # lazily create a dedicated side stream for weight
+        # prefetching the first time it is requested
+        _PREFETCH_STREAM = torch_npu.npu.Stream()
+    return _PREFETCH_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401
```
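
To illustrate, a minimal sketch (not code from this PR; `linear_with_prefetch`, `layer`, and `hidden_states` are hypothetical stand-ins, and it assumes the helper lives in `vllm_ascend.utils` and that torch_npu's stream context manager mirrors `torch.cuda.stream`) of how the new side stream can overlap weight prefetching with compute:

```python
import torch
import torch_npu

from vllm_ascend.utils import prefetch_stream  # helper added in this diff


def linear_with_prefetch(layer: torch.nn.Linear,
                         hidden_states: torch.Tensor) -> torch.Tensor:
    side_stream = prefetch_stream()
    with torch.npu.stream(side_stream):
        # Issue the prefetch on the side stream; `hidden_states` is the
        # dependency tensor, and the full weight size is requested here.
        weight = layer.weight
        torch_npu.npu_prefetch(weight, hidden_states,
                               weight.element_size() * weight.numel())
    # The matmul runs on the current stream while the weight bytes move
    # toward L2 cache in the background.
    return layer(hidden_states)
```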
```diff
@@ -446,20 +456,6 @@ class ProfileExecuteDuration:
         return durations
 
 
-# TODO(wxy): Move to ops module
-def npu_prefetch(input: torch.Tensor,
-                 dependency: torch.Tensor,
-                 max_size: int = 0,
-                 *,
-                 enabled: bool = True):
-    if not enabled:
-        return
-    input_size = input.element_size() * input.numel()
-    if max_size <= 0 or max_size > input_size:
-        max_size = input_size
-    torch_npu.npu_prefetch(input, dependency, max_size)
-
-
 # TODO(ttanzhiqiang): rm_router_logits
 # dp>1 will trigger
 # In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
```
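
The PR description says this ad-hoc helper is replaced by a unified `WeightPrefetchMethod`. As a rough sketch only (the class body below is illustrative, not the PR's code; only the class name and the ratio config come from the description), here is how a per-weight `prefetch_ratio` could map to the `max_size` argument of `torch_npu.npu_prefetch`:

```python
import torch
import torch_npu


class WeightPrefetchMethod:
    """Illustrative sketch of a ratio-driven weight prefetcher."""

    def __init__(self, enabled: bool, prefetch_ratio: dict[str, float]):
        self.enabled = enabled
        # e.g. {"qkv": 1.0, "o": 1.0} from weight_prefetch_config["attn"]
        self.prefetch_ratio = prefetch_ratio

    def maybe_prefetch(self, name: str, weight: torch.Tensor,
                       dependency: torch.Tensor) -> None:
        if not self.enabled:
            return
        ratio = self.prefetch_ratio.get(name, 0.0)
        if ratio <= 0.0:
            return
        # Prefetch only the leading `ratio` fraction of the weight bytes;
        # a ratio of 1.0 requests the whole tensor.
        max_size = int(weight.element_size() * weight.numel() * ratio)
        torch_npu.npu_prefetch(weight, dependency, max_size)
```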