[BugFix][v0.18.0][cherry-pick] Fix embedding prefix caching for APC (#7894)
## What this PR does / why we need it?

pick-from: https://github.com/vllm-project/vllm-ascend/pull/7452

### Problem

Embedding models produce inconsistent outputs when prefix caching is enabled vs disabled.

### Root Cause

The attention router condition was too broad:

- All `model_runner_type == "pooling"` requests → `_forward_encoder_attention()` → uses `npu_fusion_attention`
- **But `npu_fusion_attention` does NOT support prefix caching**
- Result: numerical mismatch when the KV cache is managed by prefix caching

### Solution

Refine the router condition to check causality:

**Before**:
```
if attn_metadata.model_runner_type == "pooling":
    → npu_fusion_attention (no prefix caching support)
```

**After**:
```
if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
    → npu_fusion_attention (for true encoders)
else:
    → npu_fused_infer_attention_score (prefix caching support)
```

### Changes Made

1. **Fixed router condition** (`vllm_ascend/attention/attention_v1.py` L968)
   - Added the `and not attn_metadata.causal` check
   - Effect: causal embedding/pooling requests now use the prefix-caching-capable operator, while true non-causal encoders keep `npu_fusion_attention`
2. **Simplified encoder attention** (`vllm_ascend/attention/attention_v1.py` L864-877)
   - Removed the redundant causal branch (encoders never use a causal mask)
   - Reduced from 34 lines to 14 lines
3. **Added test** (`tests/e2e/singlecard/pooling/test_embedding.py`)
   - Validates that embedding outputs with and without prefix caching are consistent

## Does this PR introduce _any_ user-facing change?

### Functional Changes

✅ **Yes** - Bug fix: embedding models now produce consistent outputs with prefix caching

### API Changes

❌ **No** - All public APIs unchanged

### Configuration Changes

❌ **No** - No new configuration required

### Backward Compatibility

✅ **Fully compatible** - Only fixes previously incorrect behavior

## How was this patch tested?

### New Test

Added `test_causal_embed_models_using_prefix_caching_correctness()`:

- Tests: `Qwen3-Embedding-0.6B`
- Validates numerical consistency between runs with and without prefix caching
- Uses long input sequences to make sure prefix caching is triggered
- Tolerance: 1e-2

- vLLM version: v0.18.0

Signed-off-by: underfituu <hzhucong@163.com>
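For quick orientation, the refined routing rule boils down to the sketch below. It is only illustrative: it reuses the attribute names from the diff (`model_runner_type`, `causal`), but `FakeAttnMetadata` and `select_attention_op` are made-up stand-ins, not the real `AscendMetadata` class or vllm-ascend code.

```python
from dataclasses import dataclass


# Made-up stand-in so the routing rule can be exercised in isolation.
# The real class is AscendMetadata in vllm_ascend/attention/attention_v1.py;
# only the two fields used by the router condition are modeled here.
@dataclass
class FakeAttnMetadata:
    model_runner_type: str
    causal: bool


def select_attention_op(attn_metadata: FakeAttnMetadata) -> str:
    """Sketch of the refined router: only truly non-causal (encoder-style)
    pooling requests take the npu_fusion_attention path, because that
    operator does not read the paged KV cache and cannot honor prefix
    caching."""
    if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
        return "npu_fusion_attention"          # true encoders, no prefix caching involved
    return "npu_fused_infer_attention_score"   # causal / decode paths, prefix caching supported


# Causal embedding models (e.g. Qwen3-Embedding) now avoid the encoder path:
assert select_attention_op(FakeAttnMetadata("pooling", causal=True)) == "npu_fused_infer_attention_score"
assert select_attention_op(FakeAttnMetadata("pooling", causal=False)) == "npu_fusion_attention"
```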
`tests/e2e/singlecard/pooling/test_embedding.py`:

```diff
@@ -57,6 +57,32 @@ def test_embed_models_correctness(model: str):
         tol=1e-2,
     )
 
 
+def test_causal_embed_models_using_prefix_caching_correctness():
+    # This test is to verify the correctness of prefix caching for embedding models.
+    # We compare the outputs of vLLM with and without prefix caching enabled, and check if they are close enough.
+    # We set the input query to be very long to make sure prefix caching is triggered.
+    queries = ['What is the capital of China?' * 256, 'Explain gravity']
+
+    model_name = snapshot_download(
+        "Qwen/Qwen3-Embedding-0.6B",
+        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+    )
+    with VllmRunner(
+            model_name,
+            runner="pooling",
+            max_model_len=None,
+            cudagraph_capture_sizes=[4],
+            enable_prefix_caching=True,
+    ) as vllm_runner_using_caching:
+        vllm_outputs_without_caching = vllm_runner_using_caching.embed(queries)
+        vllm_outputs_with_caching = vllm_runner_using_caching.embed(queries)
+
+    check_embeddings_close(
+        embeddings_0_lst=vllm_outputs_without_caching,
+        embeddings_1_lst=vllm_outputs_with_caching,
+        name_0="without_caching",
+        name_1="with_caching",
+        tol=1e-2,
+    )
+
+
 def test_bge_m3_correctness():
     queries = ['What is the capital of China?', 'Explain gravity']
```
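The consistency check above relies on vLLM's `check_embeddings_close` test helper. As a rough, hedged sketch of the kind of comparison this amounts to (pairwise cosine similarity within a tolerance), something like the following could be used; `embeddings_close` is a made-up name, not vLLM's actual helper:

```python
import torch
import torch.nn.functional as F


def embeddings_close(embeddings_a, embeddings_b, tol: float = 1e-2) -> bool:
    """Illustrative stand-in: each pair of embedding vectors should have a
    cosine similarity within `tol` of 1.0."""
    for a, b in zip(embeddings_a, embeddings_b):
        sim = F.cosine_similarity(torch.tensor(a), torch.tensor(b), dim=0).item()
        if abs(1.0 - sim) > tol:
            return False
    return True
```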
`vllm_ascend/attention/attention_v1.py`:

```diff
@@ -867,23 +867,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
         attn_metadata: AscendMetadata,
         _: torch.Tensor,
     ) -> torch.Tensor:
         assert attn_metadata is not None
 
-        if attn_metadata.causal:
-            # use sparse_mode 3 in causal scenario
-            return torch_npu.npu_fusion_attention(
-                query=query,
-                key=key,
-                value=value,
-                head_num=self.num_heads,
-                input_layout="TND",
-                scale=self.scale,
-                sparse_mode=3,
-                atten_mask=attn_metadata.attn_mask,
-                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
-                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
-            )[0]
-        else:
         # use default sparse_mode 0 in normal scenario, which means no mask works on it
         return torch_npu.npu_fusion_attention(
             query=query,
@@ -985,7 +968,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             query, key, value, kv_cache, attn_metadata, output
         )
         # pooling model branch
-        if attn_metadata.model_runner_type == "pooling":
+        if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
             attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]
             return output
```
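For intuition on why the old routing produced a mismatch: with a prefix-cache hit, only the un-cached suffix of a prompt is recomputed, and the keys/values for the cached prefix have to be read from the paged KV cache. An operator that only sees the freshly computed tokens therefore attends over an incomplete context. The toy sketch below (simplified, assumed behavior; not vllm-ascend code) just makes that arithmetic explicit.

```python
def effective_context_length(num_cached_tokens: int,
                             num_new_tokens: int,
                             reads_paged_kv_cache: bool) -> int:
    """Toy illustration (assumed/simplified, not real kernel behavior):
    with a prefix-cache hit, only the new suffix is recomputed, so an
    attention op that ignores the paged KV cache sees too few keys/values."""
    if reads_paged_kv_cache:
        # prefix-caching-capable path: cached prefix + new suffix
        return num_cached_tokens + num_new_tokens
    # encoder-style path: only the freshly computed suffix is visible
    return num_new_tokens


# Example: a 512-token prompt whose first 384 tokens hit the prefix cache.
assert effective_context_length(384, 128, reads_paged_kv_cache=True) == 512
assert effective_context_length(384, 128, reads_paged_kv_cache=False) == 128
```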