[Feat] Prefetching Attention QKV Linear Weight With AddRmsNormQuant Custom Op (#3517)
### What this PR does / why we need it? - `qkv_proj.weight` prefetching has been implemented with `Quant` op, when `AddRmsNormQuant` is enabled (#3465) `qkv_proj.weight` prefetching won't work - Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant` ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Tested on `Qwen3-235B-A22B-W8A8` <img width="1868" height="109" alt="image" src="https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36" /> - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
This commit is contained in:
@@ -117,6 +117,7 @@ class TestAscendRMSNorm(PytestBase):
|
||||
mock_forward_context.layer_idx = 0
|
||||
mock_forward_context.num_hidden_layers = num_hidden_layers
|
||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||
mock_forward_context.weight_prefetch_method = None
|
||||
|
||||
# Ensure fusion and layer_idx increment are handled correctly
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
@@ -125,13 +126,13 @@ class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 1
|
||||
assert mock_get_forward_context.call_count == 2
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 2
|
||||
assert mock_get_forward_context.call_count == 4
|
||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
@@ -139,14 +140,14 @@ class TestAscendRMSNorm(PytestBase):
|
||||
mock_forward_context.fusion_linear = "gate_moe"
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 3
|
||||
assert mock_get_forward_context.call_count == 6
|
||||
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
|
||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 4
|
||||
assert mock_get_forward_context.call_count == 7
|
||||
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
|
||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
@@ -156,13 +157,13 @@ class TestAscendRMSNorm(PytestBase):
|
||||
# last layer returned directly
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 5
|
||||
assert mock_get_forward_context.call_count == 8
|
||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||
assert mock_forward_context.layer_idx == 3
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 6
|
||||
assert mock_get_forward_context.call_count == 9
|
||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||
assert mock_forward_context.layer_idx == 3
|
||||
|
||||
|
||||
Reference in New Issue
Block a user