[Feat] Prefetching Attention QKV Linear Weight With AddRmsNormQuant Custom Op (#3517)

### What this PR does / why we need it?

- `qkv_proj.weight` prefetching is currently implemented through the `Quant` op.
When `AddRmsNormQuant` is enabled (#3465), `qkv_proj.weight` prefetching
no longer works.
- This PR implements `qkv_proj.weight` prefetching in the `AddRmsNormQuant` path.
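
The hook placement can be illustrated with a minimal, framework-free sketch. It is not the code in this PR: `RecordingPrefetcher`, `run_with_prefetch`, `get_prefetcher`, and the `"HypotheticalNorm"` class name are hypothetical stand-ins, and the fused op is replaced by a plain Python callable. Only the two hook names and their keyword arguments are taken from the diff.

```python
from typing import Any, Callable, List


class RecordingPrefetcher:
    """Hypothetical stand-in for the prefetch object held by the forward context."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def maybe_prefetch_attn_weight_preprocess(self, layer_cls_name: str,
                                              weight: Any,
                                              start_flag: Any) -> None:
        # A real implementation would start an asynchronous prefetch here;
        # this sketch only records that the hook was called.
        self.events.append(f"pre:{layer_cls_name}")

    def maybe_prefetch_attn_weight_postprocess(self, layer_cls_name: str,
                                               stop_flag: Any) -> None:
        # A real implementation would finish or synchronize the prefetch here.
        self.events.append(f"post:{layer_cls_name}")


def run_with_prefetch(get_prefetcher: Callable[[], Any],
                      layer_cls_name: str,
                      weight: Any,
                      x: Any,
                      fused_op: Callable[[Any], Any]) -> Any:
    """Call the pre hook, run the fused op, then call the post hook."""
    try:
        prefetcher = get_prefetcher()
    except AssertionError:
        # No forward context available (assumption: signalled by AssertionError,
        # as in the diff), so run without prefetching.
        prefetcher = None

    if prefetcher:
        prefetcher.maybe_prefetch_attn_weight_preprocess(
            layer_cls_name=layer_cls_name, weight=weight, start_flag=x)

    y = fused_op(x)  # placeholder for the fused add + RMSNorm + quant kernel

    if prefetcher:
        prefetcher.maybe_prefetch_attn_weight_postprocess(
            layer_cls_name=layer_cls_name, stop_flag=y)
    return y


if __name__ == "__main__":
    prefetcher = RecordingPrefetcher()
    result = run_with_prefetch(lambda: prefetcher, "HypotheticalNorm",
                               weight=[1.0], x=[1, 2],
                               fused_op=lambda v: [i * 2 for i in v])
    print(result, prefetcher.events)
```

In this sketch the pre hook receives the op's input as `start_flag` and the post hook receives its output as `stop_flag`, mirroring the keyword arguments in the diff. When no forward context exists, the op still runs and both hooks are skipped.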

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Tested on `Qwen3-235B-A22B-W8A8`
<img width="1868" height="109" alt="image"
src="https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36"
/>


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Ruri
2025-10-23 10:07:37 +08:00
committed by GitHub
parent 72695c97d0
commit dd7a25063c
2 changed files with 29 additions and 6 deletions


@@ -38,6 +38,21 @@ def _addrmsnorm_forward_oot(
    torch_npu_check = version_check()
    if layer is not None and not is_310p():
        layer_cls_name = layer.__class__.__name__
        try:
            weight_prefetch_method = get_forward_context(
            ).weight_prefetch_method
        except AssertionError:
            weight_prefetch_method = None
        # prefetch qkvo_proj.weight preprocess
        if weight_prefetch_method:
            weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
                layer_cls_name=layer_cls_name,
                weight=layer.weight,
                start_flag=x,
            )
        # add_rms_norm_quant
        if torch_npu_check:
            x, _, residual = torch_npu.npu_add_rms_norm_quant(
                x,
@@ -55,6 +70,13 @@ def _addrmsnorm_forward_oot(
                layer.aclnn_input_scale,
                layer.aclnn_input_offset,
                epsilon=self.variance_epsilon)
        # prefetch qkvo_proj.weight postprocess
        if weight_prefetch_method:
            weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
                layer_cls_name=layer_cls_name,
                stop_flag=x,
            )
    else:
        if is_310p():
            orig_dtype = residual.dtype