[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)
### What this PR does / why we need it?
- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Add a new config in `--additional-config` for configuration:
```json
{
"weight_prefetch_config": {
"enabled": false,
"prefetch_ratio": {
"attn": {
"qkv": 1.0,
"o": 1.0,
},
},
},
}
```
This feature is enabled by default, and can be disabled through this
configuration
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
This commit is contained in:
@@ -11,6 +11,8 @@ from vllm.utils import direct_register_custom_op
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||
@@ -148,6 +150,33 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
|
||||
max_weight_size: int) -> None:
|
||||
calculation_stream = torch_npu.npu.current_stream()
|
||||
weight_prefetch_stream = prefetch_stream()
|
||||
weight_prefetch_stream.wait_stream(calculation_stream)
|
||||
with npu_stream_switch(weight_prefetch_stream):
|
||||
maybe_npu_prefetch(inputs=weight,
|
||||
dependency=start_flag,
|
||||
max_size=max_weight_size)
|
||||
|
||||
|
||||
def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
|
||||
start_flag: torch.Tensor,
|
||||
max_weight_size: int) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
|
||||
calculation_stream = torch_npu.npu.current_stream()
|
||||
weight_prefetch_stream = prefetch_stream()
|
||||
calculation_stream.wait_stream(weight_prefetch_stream)
|
||||
|
||||
|
||||
def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
@@ -194,6 +223,18 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="prefetch_preprocess",
|
||||
op_func=_prefetch_preprocess_impl,
|
||||
fake_impl=_prefetch_preprocess_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="prefetch_postprocess",
|
||||
op_func=_prefetch_postprocess_impl,
|
||||
fake_impl=_prefetch_postprocess_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
|
||||
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
|
||||
fake_impl=lambda x: x,
|
||||
|
||||
75
vllm_ascend/ops/weight_prefetch.py
Normal file
75
vllm_ascend/ops/weight_prefetch.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
|
||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleWeightPrefetchConfig:
|
||||
module_name: str
|
||||
enable: bool = False
|
||||
prefetch_ratio: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.prefetch_ratio = {
|
||||
prefix: ratio
|
||||
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
|
||||
}
|
||||
|
||||
assert self.module_name in SUPPORTED_MODULES, (
|
||||
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
|
||||
)
|
||||
|
||||
if self.module_name in SUPPORTED_MODULES:
|
||||
self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
|
||||
|
||||
|
||||
class WeightPrefetchMethod:
|
||||
"""
|
||||
Unified weight prefetch method.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
|
||||
self.attn = ModuleWeightPrefetchConfig(
|
||||
module_name="attn",
|
||||
enable=weight_prefetch_config.enabled,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"attn", {}))
|
||||
|
||||
def maybe_prefetch_attn_weight_preprocess(
|
||||
self, prefix: str, weight: torch.Tensor,
|
||||
start_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable:
|
||||
return
|
||||
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.attn.prefetch_ratio.get(prefix, 0)
|
||||
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=start_flag,
|
||||
max_weight_size=int(weight_size))
|
||||
|
||||
def maybe_prefetch_attn_weight_postprocess(
|
||||
self, stop_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable:
|
||||
return
|
||||
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
|
||||
|
||||
def maybe_npu_prefetch(inputs: torch.Tensor,
|
||||
dependency: torch.Tensor,
|
||||
max_size: int = 0,
|
||||
offset: int = 0,
|
||||
*,
|
||||
enabled: bool = True) -> None:
|
||||
if not enabled:
|
||||
return
|
||||
input_size = inputs.element_size() * inputs.numel()
|
||||
if max_size <= 0 or max_size > input_size:
|
||||
max_size = input_size
|
||||
torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
|
||||
Reference in New Issue
Block a user