diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index ed9dd44..7a829b6 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -117,6 +117,7 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.layer_idx = 0 mock_forward_context.num_hidden_layers = num_hidden_layers mock_forward_context.fusion_linear = "gate_up_dense" + mock_forward_context.weight_prefetch_method = None # Ensure fusion and layer_idx increment are handled correctly x = torch.randn(4, 8, dtype=torch.float16) @@ -125,13 +126,13 @@ class TestAscendRMSNorm(PytestBase): x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 1 + assert mock_get_forward_context.call_count == 2 assert mock_forward_context.fusion_linear == "qkv_dense" assert mock_forward_context.layer_idx == 1 x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 2 + assert mock_get_forward_context.call_count == 4 assert mock_forward_context.fusion_linear == "gate_up_dense" assert mock_forward_context.layer_idx == 1 @@ -139,14 +140,14 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.fusion_linear = "gate_moe" x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 3 + assert mock_get_forward_context.call_count == 6 fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense" assert mock_forward_context.fusion_linear == fusion_linear_expected assert mock_forward_context.layer_idx == 2 x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 4 + assert mock_get_forward_context.call_count == 7 fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense" assert mock_forward_context.fusion_linear == fusion_linear_expected assert mock_forward_context.layer_idx == 2 @@ -156,13 +157,13 @@ class TestAscendRMSNorm(PytestBase): # last layer returned directly x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 5 + assert mock_get_forward_context.call_count == 8 assert mock_forward_context.fusion_linear == "qkv_moe" assert mock_forward_context.layer_idx == 3 x_out, residual_out = layer.forward_oot(x, residual) - assert mock_get_forward_context.call_count == 6 + assert mock_get_forward_context.call_count == 9 assert mock_forward_context.fusion_linear == "qkv_moe" assert mock_forward_context.layer_idx == 3 diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 55eab21..239eeb0 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -38,6 +38,21 @@ def _addrmsnorm_forward_oot( torch_npu_check = version_check() if layer is not None and not is_310p(): + layer_cls_name = layer.__class__.__name__ + try: + weight_prefetch_method = get_forward_context( + ).weight_prefetch_method + except AssertionError: + weight_prefetch_method = None + + # prefetch qkvo_proj.weight preprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( + layer_cls_name=layer_cls_name, + weight=layer.weight, + start_flag=x, + ) + # add_rms_norm_quant if torch_npu_check: x, _, residual = torch_npu.npu_add_rms_norm_quant( x, @@ -55,6 +70,13 @@ def _addrmsnorm_forward_oot( layer.aclnn_input_scale, layer.aclnn_input_offset, epsilon=self.variance_epsilon) + # prefetch qkvo_proj.weight postprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( + layer_cls_name=layer_cls_name, + stop_flag=x, + ) + else: if is_310p(): orig_dtype = residual.dtype