From d3de7333dc82563055031a502c56c821955cd7a9 Mon Sep 17 00:00:00 2001
From: hucong <33891520+underfituu@users.noreply.github.com>
Date: Wed, 1 Apr 2026 16:57:33 +0800
Subject: [PATCH] [BugFix][v0.18.0][cherry-pick] Fix embedding prefix caching for APC (#7894)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What this PR does / why we need it?

pick-from: https://github.com/vllm-project/vllm-ascend/pull/7452

### Problem

Embedding models produce inconsistent outputs when prefix caching is enabled vs disabled.

### Root Cause

The attention router condition was too broad:
- All `model_runner_type == "pooling"` requests → `_forward_encoder_attention()` → `npu_fusion_attention`
- **But `npu_fusion_attention` does NOT support prefix caching**
- Result: numerical mismatch whenever the KV cache is managed by prefix caching

### Solution

Refine the router condition to also check causality:

**Before**:
```
if attn_metadata.model_runner_type == "pooling":
    → npu_fusion_attention (no prefix caching support)
```

**After**:
```
if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
    → npu_fusion_attention (for true encoders)
else:
    → npu_fused_infer_attention_score (prefix caching support)
```

### Changes Made

1. **Fixed router condition** (`vllm_ascend/attention/attention_v1.py` L968)
   - Added an `and not attn_metadata.causal` check
   - Effect: only non-causal (true encoder) requests keep the encoder operator; causal embeddings now take the prefix-caching-capable path
2. **Simplified encoder attention** (`vllm_ascend/attention/attention_v1.py` L864-877)
   - Removed the redundant causal branch (encoders never use a causal mask)
   - Reduced from 34 lines to 14 lines
3. **Added test** (`tests/e2e/singlecard/pooling/test_embedding.py`)
   - Validates that embedding outputs with and without prefix caching are consistent

## Does this PR introduce _any_ user-facing change?

### Functional Changes
✅ **Yes** - Bug fix: embedding models now produce consistent outputs with prefix caching

### API Changes
❌ **No** - All public APIs unchanged

### Configuration Changes
❌ **No** - No new configuration required

### Backward Compatibility
✅ **Fully compatible** - Only fixes incorrect behavior

## How was this patch tested?

### New Test

Added `test_causal_embed_models_using_prefix_caching_correctness()`:
- Tests: `Qwen3-Embedding-0.6B`
- Validates numerical consistency between runs with and without prefix caching (see the sketch below)
- Uses long input sequences to ensure prefix caching is actually triggered
- Tolerance: 1e-2
- vLLM version: v0.18.0
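For reference, here is a minimal sketch of the kind of comparison the test relies on. It does not reproduce the repository's `check_embeddings_close` helper; it assumes embeddings come back as plain float lists and uses cosine distance with the same 1e-2 tolerance:

```
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    # Plain-Python cosine similarity, to keep the sketch dependency-free.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


def embeddings_close(run_a: list[list[float]], run_b: list[list[float]],
                     tol: float = 1e-2) -> bool:
    # Every query's embedding from the prefix-cached run must stay within
    # `tol` (in cosine distance) of the embedding from the uncached run.
    return all(
        1.0 - cosine_similarity(e_a, e_b) <= tol
        for e_a, e_b in zip(run_a, run_b))


# Toy usage: identical outputs trivially satisfy the check.
assert embeddings_close([[0.1, 0.2, 0.3]], [[0.1, 0.2, 0.3]])
```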
Signed-off-by: underfituu
---
 .../e2e/singlecard/pooling/test_embedding.py | 26 ++++++++++++
 vllm_ascend/attention/attention_v1.py        | 41 ++++++-------------
 2 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/tests/e2e/singlecard/pooling/test_embedding.py b/tests/e2e/singlecard/pooling/test_embedding.py
index 50dc9ee9..e85b14d8 100644
--- a/tests/e2e/singlecard/pooling/test_embedding.py
+++ b/tests/e2e/singlecard/pooling/test_embedding.py
@@ -57,6 +57,32 @@ def test_embed_models_correctness(model: str):
         tol=1e-2,
     )
 
+def test_causal_embed_models_using_prefix_caching_correctness():
+    # This test is to verify the correctness of prefix caching for embedding models.
+    # We compare the outputs of vLLM with and without prefix caching enabled, and check if they are close enough.
+    # We set the input query to be very long to make sure prefix caching is triggered.
+    queries = ['What is the capital of China?' * 256, 'Explain gravity']
+
+    model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B", local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,)
+    with VllmRunner(
+            model_name,
+            runner="pooling",
+            max_model_len=None,
+            cudagraph_capture_sizes=[4],
+            enable_prefix_caching=True,
+    ) as vllm_runner_using_caching:
+        vllm_outputs_without_caching = vllm_runner_using_caching.embed(queries)
+        vllm_outputs_with_caching = vllm_runner_using_caching.embed(queries)
+
+
+    check_embeddings_close(
+        embeddings_0_lst=vllm_outputs_without_caching,
+        embeddings_1_lst=vllm_outputs_with_caching,
+        name_0="without_caching",
+        name_1="with_caching",
+        tol=1e-2,
+    )
+
 
 def test_bge_m3_correctness():
     queries = ['What is the capital of China?', 'Explain gravity']
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 62ea6e8c..689084f4 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -867,34 +867,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
         attn_metadata: AscendMetadata,
         _: torch.Tensor,
     ) -> torch.Tensor:
-        assert attn_metadata is not None
-
-        if attn_metadata.causal:
-            # use sparse_mode 3 in causal scenario
-            return torch_npu.npu_fusion_attention(
-                query=query,
-                key=key,
-                value=value,
-                head_num=self.num_heads,
-                input_layout="TND",
-                scale=self.scale,
-                sparse_mode=3,
-                atten_mask=attn_metadata.attn_mask,
-                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
-                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
-            )[0]
-        else:
-            # use default sparse_mode 0 in normal scenario, which means no mask works on it
-            return torch_npu.npu_fusion_attention(
-                query=query,
-                key=key,
-                value=value,
-                head_num=self.num_heads,
-                input_layout="TND",
-                scale=self.scale,
-                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
-                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
-            )[0]
+        # use default sparse_mode 0 in normal scenario, which means no mask works on it
+        return torch_npu.npu_fusion_attention(
+            query=query,
+            key=key,
+            value=value,
+            head_num=self.num_heads,
+            input_layout="TND",
+            scale=self.scale,
+            actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
+            actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
+        )[0]
 
     def reshape_and_cache(
         self,
@@ -985,7 +968,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query, key, value, kv_cache, attn_metadata, output
             )
         # pooling model branch
-        if attn_metadata.model_runner_type == "pooling":
+        if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
             attn_output = self._forward_encoder_attention(query, key, value,
                                                           attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]
             return output
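As an aside on the routing change in `vllm_ascend/attention/attention_v1.py`: the fixed condition can be read as the small predicate sketched below. This is an illustrative restatement, not code from the patch, and the helper name `uses_encoder_attention` is invented for the example:

```
def uses_encoder_attention(model_runner_type: str, causal: bool) -> bool:
    # Restatement of the fixed router condition: only non-causal pooling
    # (true encoder) requests take the npu_fusion_attention path, which
    # cannot serve KV blocks managed by prefix caching.
    return model_runner_type == "pooling" and not causal


# A causal embedding model such as Qwen3-Embedding no longer hits the
# encoder path, so its KV cache can be managed by prefix caching without
# a numerical mismatch; non-pooling runner types are unaffected.
assert uses_encoder_attention("pooling", causal=False) is True
assert uses_encoder_attention("pooling", causal=True) is False
assert uses_encoder_attention("generate", causal=True) is False  # any non-pooling runner
```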