From 825fdfb197fd699e459a16a6c1a4ef7e05b400a6 Mon Sep 17 00:00:00 2001
From: Ruri <33858552+zhoux77899@users.noreply.github.com>
Date: Mon, 27 Oct 2025 09:42:09 +0800
Subject: [PATCH] [v0.11.0][Feat] Prefetching Attention QKV Linear Weight With
`AddRmsNormQuant` Custom Op (#3649)
### What this PR does / why we need it?
- `qkv_proj.weight` prefetching was implemented with the `Quant` op;
when `AddRmsNormQuant` is enabled (#3465), `qkv_proj.weight` prefetching
won't work
- Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant`, which
has been merged on the `main` branch (#3517)
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Tested on `Qwen3-235B-A22B-W8A8`
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
---------
Signed-off-by: zhoux77899 <33858552+zhoux77899@users.noreply.github.com>
---
tests/ut/ops/test_layernorm.py | 13 +++++++------
vllm_ascend/ops/layernorm.py | 22 ++++++++++++++++++++++
2 files changed, 29 insertions(+), 6 deletions(-)
diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py
index ed9dd44..7a829b6 100644
--- a/tests/ut/ops/test_layernorm.py
+++ b/tests/ut/ops/test_layernorm.py
@@ -117,6 +117,7 @@ class TestAscendRMSNorm(PytestBase):
mock_forward_context.layer_idx = 0
mock_forward_context.num_hidden_layers = num_hidden_layers
mock_forward_context.fusion_linear = "gate_up_dense"
+ mock_forward_context.weight_prefetch_method = None
# Ensure fusion and layer_idx increment are handled correctly
x = torch.randn(4, 8, dtype=torch.float16)
@@ -125,13 +126,13 @@ class TestAscendRMSNorm(PytestBase):
x_out, residual_out = layer.forward_oot(x, residual)
- assert mock_get_forward_context.call_count == 1
+ assert mock_get_forward_context.call_count == 2
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 1
x_out, residual_out = layer.forward_oot(x, residual)
- assert mock_get_forward_context.call_count == 2
+ assert mock_get_forward_context.call_count == 4
assert mock_forward_context.fusion_linear == "gate_up_dense"
assert mock_forward_context.layer_idx == 1
@@ -139,14 +140,14 @@ class TestAscendRMSNorm(PytestBase):
mock_forward_context.fusion_linear = "gate_moe"
x_out, residual_out = layer.forward_oot(x, residual)
- assert mock_get_forward_context.call_count == 3
+ assert mock_get_forward_context.call_count == 6
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
assert mock_forward_context.fusion_linear == fusion_linear_expected
assert mock_forward_context.layer_idx == 2
x_out, residual_out = layer.forward_oot(x, residual)
- assert mock_get_forward_context.call_count == 4
+ assert mock_get_forward_context.call_count == 7
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
assert mock_forward_context.fusion_linear == fusion_linear_expected
assert mock_forward_context.layer_idx == 2
@@ -156,13 +157,13 @@ class TestAscendRMSNorm(PytestBase):
# last layer returned directly
x_out, residual_out = layer.forward_oot(x, residual)
- assert mock_get_forward_context.call_count == 5
+ assert mock_get_forward_context.call_count == 8
assert mock_forward_context.fusion_linear == "qkv_moe"
assert mock_forward_context.layer_idx == 3
x_out, residual_out = layer.forward_oot(x, residual)
- assert mock_get_forward_context.call_count == 6
+ assert mock_get_forward_context.call_count == 9
assert mock_forward_context.fusion_linear == "qkv_moe"
assert mock_forward_context.layer_idx == 3
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 55eab21..239eeb0 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -38,6 +38,21 @@ def _addrmsnorm_forward_oot(
torch_npu_check = version_check()
if layer is not None and not is_310p():
+ layer_cls_name = layer.__class__.__name__
+ try:
+ weight_prefetch_method = get_forward_context(
+ ).weight_prefetch_method
+ except AssertionError:
+ weight_prefetch_method = None
+
+ # prefetch qkvo_proj.weight preprocess
+ if weight_prefetch_method:
+ weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
+ layer_cls_name=layer_cls_name,
+ weight=layer.weight,
+ start_flag=x,
+ )
+ # add_rms_norm_quant
if torch_npu_check:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
@@ -55,6 +70,13 @@ def _addrmsnorm_forward_oot(
layer.aclnn_input_scale,
layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
+ # prefetch qkvo_proj.weight postprocess
+ if weight_prefetch_method:
+ weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
+ layer_cls_name=layer_cls_name,
+ stop_flag=x,
+ )
+
else:
if is_310p():
orig_dtype = residual.dtype