[Feat] Prefetching Attention QKV Linear Weight With AddRmsNormQuant Custom Op (#3517)
### What this PR does / why we need it? - `qkv_proj.weight` prefetching has been implemented with `Quant` op, when `AddRmsNormQuant` is enabled (#3465) `qkv_proj.weight` prefetching won't work - Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant` ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Tested on `Qwen3-235B-A22B-W8A8` <img width="1868" height="109" alt="image" src="https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36" /> - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
This commit is contained in:
@@ -117,6 +117,7 @@ class TestAscendRMSNorm(PytestBase):
|
||||
mock_forward_context.layer_idx = 0
|
||||
mock_forward_context.num_hidden_layers = num_hidden_layers
|
||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||
mock_forward_context.weight_prefetch_method = None
|
||||
|
||||
# Ensure fusion and layer_idx increment are handled correctly
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
@@ -125,13 +126,13 @@ class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 1
|
||||
assert mock_get_forward_context.call_count == 2
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 2
|
||||
assert mock_get_forward_context.call_count == 4
|
||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
@@ -139,14 +140,14 @@ class TestAscendRMSNorm(PytestBase):
|
||||
mock_forward_context.fusion_linear = "gate_moe"
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 3
|
||||
assert mock_get_forward_context.call_count == 6
|
||||
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
|
||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 4
|
||||
assert mock_get_forward_context.call_count == 7
|
||||
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
|
||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
@@ -156,13 +157,13 @@ class TestAscendRMSNorm(PytestBase):
|
||||
# last layer returned directly
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 5
|
||||
assert mock_get_forward_context.call_count == 8
|
||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||
assert mock_forward_context.layer_idx == 3
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 6
|
||||
assert mock_get_forward_context.call_count == 9
|
||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||
assert mock_forward_context.layer_idx == 3
|
||||
|
||||
|
||||
@@ -38,6 +38,21 @@ def _addrmsnorm_forward_oot(
|
||||
|
||||
torch_npu_check = version_check()
|
||||
if layer is not None and not is_310p():
|
||||
layer_cls_name = layer.__class__.__name__
|
||||
try:
|
||||
weight_prefetch_method = get_forward_context(
|
||||
).weight_prefetch_method
|
||||
except AssertionError:
|
||||
weight_prefetch_method = None
|
||||
|
||||
# prefetch qkvo_proj.weight preprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||
layer_cls_name=layer_cls_name,
|
||||
weight=layer.weight,
|
||||
start_flag=x,
|
||||
)
|
||||
# add_rms_norm_quant
|
||||
if torch_npu_check:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
@@ -55,6 +70,13 @@ def _addrmsnorm_forward_oot(
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
# prefetch qkvo_proj.weight postprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
||||
layer_cls_name=layer_cls_name,
|
||||
stop_flag=x,
|
||||
)
|
||||
|
||||
else:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
|
||||
Reference in New Issue
Block a user