[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)
### What this PR does / why we need it?
- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators improves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Add a new config in `--additional-config` for configuration:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0
            }
        }
    }
}
```
This feature is enabled by default and can be disabled through this
configuration by setting `enabled` to `false`, as in the example above.
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
This commit is contained in:
@@ -70,11 +70,12 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.models.layers.sfa import Indexer
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
TorchairAscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import dispose_tensor, npu_prefetch, oproj_tp_enable
|
||||
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable
|
||||
|
||||
|
||||
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
|
||||
@@ -589,9 +590,9 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
and attn_metadata.num_decodes > 0)
|
||||
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
||||
if self.q_lora_rank is not None:
|
||||
npu_prefetch(self.q_a_proj.weight,
|
||||
hidden_states,
|
||||
enabled=enable_multistream_mla)
|
||||
maybe_npu_prefetch(self.q_a_proj.weight,
|
||||
hidden_states,
|
||||
enabled=enable_multistream_mla)
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
forward_kwargs['ckq'] = ckq
|
||||
|
||||
@@ -23,9 +23,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
npu_stream_switch, npu_wait_tensor)
|
||||
from vllm_ascend.utils import npu_prefetch
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -684,10 +684,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
if hasattr(self, "running_in_graph") and not self.running_in_graph:
|
||||
return x
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
|
||||
npu_prefetch(self.o_proj.weight,
|
||||
x,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
maybe_npu_prefetch(self.o_proj.weight,
|
||||
x,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
return self.o_proj(x, is_prefill=False)[0]
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
@@ -1281,10 +1281,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
|
||||
if current_ms_metadata is None:
|
||||
npu_prefetch(self.o_proj.weight,
|
||||
o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
maybe_npu_prefetch(self.o_proj.weight,
|
||||
o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
|
||||
output[...] = self.o_proj(
|
||||
o_proj_input,
|
||||
@@ -1292,10 +1292,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||
else:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
npu_prefetch(self.o_proj.weight,
|
||||
o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
maybe_npu_prefetch(self.o_proj.weight,
|
||||
o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
output[...] = self.o_proj(
|
||||
o_proj_input,
|
||||
is_prefill=True,
|
||||
|
||||
Reference in New Issue
Block a user