From 825fdfb197fd699e459a16a6c1a4ef7e05b400a6 Mon Sep 17 00:00:00 2001 From: Ruri <33858552+zhoux77899@users.noreply.github.com> Date: Mon, 27 Oct 2025 09:42:09 +0800 Subject: [PATCH] [v0.11.0][Feat] Prefetching Attention QKV Linear Weight With `AddRmsNormQuant` Custom Op (#3649) ### What this PR does / why we need it? - `qkv_proj.weight` prefetching has been implemented with the `Quant` op; when `AddRmsNormQuant` is enabled (#3465), `qkv_proj.weight` prefetching won't work - Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant`, which has been merged on the `main` branch (#3517) ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Tested on `Qwen3-235B-A22B-W8A8` (screenshot omitted) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: zhoux77899 --- tests/ut/ops/test_layernorm.py | 13 +++++++------ vllm_ascend/ops/layernorm.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index ed9dd44..7a829b6 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -117,6 +117,7 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.layer_idx = 0 mock_forward_context.num_hidden_layers = num_hidden_layers mock_forward_context.fusion_linear = "gate_up_dense" + mock_forward_context.weight_prefetch_method = None # Ensure fusion and layer_idx increment are handled correctly x = torch.randn(4, 8, dtype=torch.float16) @@ -125,13 +126,13 @@ class TestAscendRMSNorm(PytestBase): x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 1 + assert mock_get_forward_context.call_count == 2 assert mock_forward_context.fusion_linear == "qkv_dense" assert mock_forward_context.layer_idx == 1 x_out,
residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 2 + assert mock_get_forward_context.call_count == 4 assert mock_forward_context.fusion_linear == "gate_up_dense" assert mock_forward_context.layer_idx == 1 @@ -139,14 +140,14 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.fusion_linear = "gate_moe" x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 3 + assert mock_get_forward_context.call_count == 6 fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense" assert mock_forward_context.fusion_linear == fusion_linear_expected assert mock_forward_context.layer_idx == 2 x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 4 + assert mock_get_forward_context.call_count == 7 fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense" assert mock_forward_context.fusion_linear == fusion_linear_expected assert mock_forward_context.layer_idx == 2 @@ -156,13 +157,13 @@ class TestAscendRMSNorm(PytestBase): # last layer returned directly x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 5 + assert mock_get_forward_context.call_count == 8 assert mock_forward_context.fusion_linear == "qkv_moe" assert mock_forward_context.layer_idx == 3 x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 6 + assert mock_get_forward_context.call_count == 9 assert mock_forward_context.fusion_linear == "qkv_moe" assert mock_forward_context.layer_idx == 3 diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 55eab21..239eeb0 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -38,6 +38,21 @@ def _addrmsnorm_forward_oot( torch_npu_check = version_check() if layer is not None and not is_310p(): + layer_cls_name = layer.__class__.__name__ + try: + weight_prefetch_method = 
get_forward_context( + ).weight_prefetch_method + except AssertionError: + weight_prefetch_method = None + + # prefetch qkvo_proj.weight preprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( + layer_cls_name=layer_cls_name, + weight=layer.weight, + start_flag=x, + ) + # add_rms_norm_quant if torch_npu_check: x, _, residual = torch_npu.npu_add_rms_norm_quant( x, @@ -55,6 +70,13 @@ def _addrmsnorm_forward_oot( layer.aclnn_input_scale, layer.aclnn_input_offset, epsilon=self.variance_epsilon) + # prefetch qkvo_proj.weight postprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( + layer_cls_name=layer_cls_name, + stop_flag=x, + ) + else: if is_310p(): orig_dtype = residual.dtype