[Feat] Prefetching Attention QKV Linear Weight With AddRmsNormQuant Custom Op (#3517)

### What this PR does / why we need it?

- `qkv_proj.weight` prefetching was originally implemented for the plain `Quant` op; when the fused `AddRmsNormQuant` op is enabled (#3465), `qkv_proj.weight` prefetching does not take effect.
- This PR implements `qkv_proj.weight` prefetching for the `AddRmsNormQuant` custom op as well (see the sketch below this list).
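
For orientation, here is a minimal sketch of the control flow this enables, reconstructed from the fields the updated test exercises (`fusion_linear`, `layer_idx`, `weight_prefetch_method`). The `ForwardContext` dataclass, the `add_rms_norm_quant_forward` function, and the prefetch callable are hypothetical stand-ins, not the actual vllm-ascend implementation:

```python
# Hypothetical sketch of the prefetch hook; not the real vllm-ascend code.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ForwardContext:
    """Stand-in for the per-step forward context the test mocks."""
    layer_idx: int = 0
    num_hidden_layers: int = 4
    fusion_linear: str = "gate_up_dense"
    # None disables prefetching; otherwise a callable that warms the
    # qkv_proj.weight cache ahead of the fused kernel.
    weight_prefetch_method: Optional[Callable[[str, int], None]] = None


def add_rms_norm_quant_forward(ctx: ForwardContext) -> None:
    """Sketch: issue the qkv weight prefetch before the fused kernel runs."""
    if ctx.fusion_linear in ("qkv_dense", "qkv_moe") and ctx.weight_prefetch_method:
        # Start copying qkv_proj.weight early so it is resident by the time
        # the linear layer that consumes the quantized output executes.
        ctx.weight_prefetch_method("qkv_proj.weight", ctx.layer_idx)
    # ... fused Add + RMSNorm + Quant kernel would execute here ...
```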

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Tested on `Qwen3-235B-A22B-W8A8`
![Test result on Qwen3-235B-A22B-W8A8](https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36)


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Author: Ruri
Date: 2025-10-23 10:07:37 +08:00
Committed by: GitHub
Parent: 72695c97d0
Commit: dd7a25063c
2 changed files with 29 additions and 6 deletions


```diff
@@ -117,6 +117,7 @@ class TestAscendRMSNorm(PytestBase):
         mock_forward_context.layer_idx = 0
         mock_forward_context.num_hidden_layers = num_hidden_layers
         mock_forward_context.fusion_linear = "gate_up_dense"
+        mock_forward_context.weight_prefetch_method = None
         # Ensure fusion and layer_idx increment are handled correctly
         x = torch.randn(4, 8, dtype=torch.float16)
@@ -125,13 +126,13 @@ class TestAscendRMSNorm(PytestBase):
         x_out, residual_out = layer.forward_oot(x, residual)
-        assert mock_get_forward_context.call_count == 1
+        assert mock_get_forward_context.call_count == 2
         assert mock_forward_context.fusion_linear == "qkv_dense"
         assert mock_forward_context.layer_idx == 1
         x_out, residual_out = layer.forward_oot(x, residual)
-        assert mock_get_forward_context.call_count == 2
+        assert mock_get_forward_context.call_count == 4
         assert mock_forward_context.fusion_linear == "gate_up_dense"
         assert mock_forward_context.layer_idx == 1
@@ -139,14 +140,14 @@ class TestAscendRMSNorm(PytestBase):
         mock_forward_context.fusion_linear = "gate_moe"
         x_out, residual_out = layer.forward_oot(x, residual)
-        assert mock_get_forward_context.call_count == 3
+        assert mock_get_forward_context.call_count == 6
         fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
         assert mock_forward_context.fusion_linear == fusion_linear_expected
         assert mock_forward_context.layer_idx == 2
         x_out, residual_out = layer.forward_oot(x, residual)
-        assert mock_get_forward_context.call_count == 4
+        assert mock_get_forward_context.call_count == 7
         fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
         assert mock_forward_context.fusion_linear == fusion_linear_expected
         assert mock_forward_context.layer_idx == 2
@@ -156,13 +157,13 @@ class TestAscendRMSNorm(PytestBase):
         # last layer returned directly
         x_out, residual_out = layer.forward_oot(x, residual)
-        assert mock_get_forward_context.call_count == 5
+        assert mock_get_forward_context.call_count == 8
         assert mock_forward_context.fusion_linear == "qkv_moe"
         assert mock_forward_context.layer_idx == 3
         x_out, residual_out = layer.forward_oot(x, residual)
-        assert mock_get_forward_context.call_count == 6
+        assert mock_get_forward_context.call_count == 9
         assert mock_forward_context.fusion_linear == "qkv_moe"
         assert mock_forward_context.layer_idx == 3
```
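
One reading of the mock setup (an inference from the diff, not stated in the patch): `weight_prefetch_method` must be pinned to `None` because `MagicMock` auto-creates truthy attributes on first access, which would otherwise switch the prefetch branch on in every test case; the higher `call_count` expectations are consistent with `forward_oot` fetching the forward context additional times on the prefetch path. The MagicMock pitfall in isolation:

```python
from unittest.mock import MagicMock

ctx = MagicMock()
# Attribute access on a MagicMock fabricates a child mock, and mocks are
# truthy, so an `if ctx.weight_prefetch_method:` guard would fire by accident.
assert bool(ctx.weight_prefetch_method)

ctx.weight_prefetch_method = None
assert not ctx.weight_prefetch_method  # prefetch branch now stays disabled
```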