[Feat] Prefetching Attention QKV Linear Weight With AddRmsNormQuant Custom Op (#3517)
### What this PR does / why we need it? - `qkv_proj.weight` prefetching has been implemented with the `Quant` op; when `AddRmsNormQuant` is enabled (#3465), `qkv_proj.weight` prefetching does not work. - Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant`. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Tested on `Qwen3-235B-A22B-W8A8` <img width="1868" height="109" alt="image" src="https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36" /> - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
This commit is contained in:
@@ -117,6 +117,7 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
mock_forward_context.layer_idx = 0
|
mock_forward_context.layer_idx = 0
|
||||||
mock_forward_context.num_hidden_layers = num_hidden_layers
|
mock_forward_context.num_hidden_layers = num_hidden_layers
|
||||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||||
|
mock_forward_context.weight_prefetch_method = None
|
||||||
|
|
||||||
# Ensure fusion and layer_idx increment are handled correctly
|
# Ensure fusion and layer_idx increment are handled correctly
|
||||||
x = torch.randn(4, 8, dtype=torch.float16)
|
x = torch.randn(4, 8, dtype=torch.float16)
|
||||||
@@ -125,13 +126,13 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 1
|
assert mock_get_forward_context.call_count == 2
|
||||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||||
assert mock_forward_context.layer_idx == 1
|
assert mock_forward_context.layer_idx == 1
|
||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 2
|
assert mock_get_forward_context.call_count == 4
|
||||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||||
assert mock_forward_context.layer_idx == 1
|
assert mock_forward_context.layer_idx == 1
|
||||||
|
|
||||||
@@ -139,14 +140,14 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
mock_forward_context.fusion_linear = "gate_moe"
|
mock_forward_context.fusion_linear = "gate_moe"
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 3
|
assert mock_get_forward_context.call_count == 6
|
||||||
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
|
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
|
||||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||||
assert mock_forward_context.layer_idx == 2
|
assert mock_forward_context.layer_idx == 2
|
||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 4
|
assert mock_get_forward_context.call_count == 7
|
||||||
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
|
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
|
||||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||||
assert mock_forward_context.layer_idx == 2
|
assert mock_forward_context.layer_idx == 2
|
||||||
@@ -156,13 +157,13 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
# last layer returned directly
|
# last layer returned directly
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 5
|
assert mock_get_forward_context.call_count == 8
|
||||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||||
assert mock_forward_context.layer_idx == 3
|
assert mock_forward_context.layer_idx == 3
|
||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 6
|
assert mock_get_forward_context.call_count == 9
|
||||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||||
assert mock_forward_context.layer_idx == 3
|
assert mock_forward_context.layer_idx == 3
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,21 @@ def _addrmsnorm_forward_oot(
|
|||||||
|
|
||||||
torch_npu_check = version_check()
|
torch_npu_check = version_check()
|
||||||
if layer is not None and not is_310p():
|
if layer is not None and not is_310p():
|
||||||
|
layer_cls_name = layer.__class__.__name__
|
||||||
|
try:
|
||||||
|
weight_prefetch_method = get_forward_context(
|
||||||
|
).weight_prefetch_method
|
||||||
|
except AssertionError:
|
||||||
|
weight_prefetch_method = None
|
||||||
|
|
||||||
|
# prefetch qkvo_proj.weight preprocess
|
||||||
|
if weight_prefetch_method:
|
||||||
|
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||||
|
layer_cls_name=layer_cls_name,
|
||||||
|
weight=layer.weight,
|
||||||
|
start_flag=x,
|
||||||
|
)
|
||||||
|
# add_rms_norm_quant
|
||||||
if torch_npu_check:
|
if torch_npu_check:
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||||
x,
|
x,
|
||||||
@@ -55,6 +70,13 @@ def _addrmsnorm_forward_oot(
|
|||||||
layer.aclnn_input_scale,
|
layer.aclnn_input_scale,
|
||||||
layer.aclnn_input_offset,
|
layer.aclnn_input_offset,
|
||||||
epsilon=self.variance_epsilon)
|
epsilon=self.variance_epsilon)
|
||||||
|
# prefetch qkvo_proj.weight postprocess
|
||||||
|
if weight_prefetch_method:
|
||||||
|
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
||||||
|
layer_cls_name=layer_cls_name,
|
||||||
|
stop_flag=x,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if is_310p():
|
if is_310p():
|
||||||
orig_dtype = residual.dtype
|
orig_dtype = residual.dtype
|
||||||
|
|||||||
Reference in New Issue
Block a user