xc-llm-ascend/tests/e2e/singlecard/pooling/test_embedding.py
hucong d3de7333dc [BugFix][v0.18.0][cherry-pick] Fix embedding prefix caching for APC (#7894)
## What this PR does / why we need it?
pick-from: https://github.com/vllm-project/vllm-ascend/pull/7452
### Problem
Embedding models produce inconsistent outputs when prefix caching is
enabled vs disabled.

### Root Cause
The attention router condition was too broad:
- Every request with `model_runner_type == "pooling"` was routed to
`_forward_encoder_attention()`, which calls `npu_fusion_attention`
- **But `npu_fusion_attention` does NOT support prefix caching**
- Result: numerical mismatch when the KV cache is managed by prefix caching

### Solution
Refine the router condition to check causality:

**Before**: 
```
if attn_metadata.model_runner_type == "pooling":
    → npu_fusion_attention (no prefix caching support)
```

**After**: 
```
if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
    → npu_fusion_attention (for true encoders)
else:
    → npu_fused_infer_attention_score (prefix caching support)
```
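To make the dispatch concrete, here is a minimal, self-contained sketch of the refined condition. The names `AttnMetadata` and `select_attention_path` are illustrative stand-ins for this sketch only, not the actual `vllm_ascend` classes:

```
# Toy model of the router fix; not the vllm_ascend implementation.
from dataclasses import dataclass

@dataclass
class AttnMetadata:
    model_runner_type: str  # e.g. "pooling"
    causal: bool            # True for last-token embedders like Qwen3-Embedding

def select_attention_path(meta: AttnMetadata) -> str:
    # The old condition checked only model_runner_type == "pooling", so
    # causal pooling models also landed on npu_fusion_attention, which
    # cannot consume a KV cache managed by prefix caching.
    if meta.model_runner_type == "pooling" and not meta.causal:
        return "npu_fusion_attention"  # true (bidirectional) encoders
    return "npu_fused_infer_attention_score"  # prefix caching supported

# Causal embedding models now take the prefix-caching-capable path:
assert select_attention_path(AttnMetadata("pooling", causal=True)) \
    == "npu_fused_infer_attention_score"
```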
### Changes Made

1. **Fixed router condition** (`vllm_ascend/attention/attention_v1.py`
L968)
   - Added an `and not attn_metadata.causal` check
   - Effect: non-causal embeddings now use the correct operator

2. **Simplified encoder attention**
(`vllm_ascend/attention/attention_v1.py` L864-877)
   - Removed the redundant causal branch (encoders never use a causal
mask; a toy illustration follows this list)
   - Reduced from 34 lines to 14 lines

3. **Added test** (`tests/e2e/singlecard/pooling/test_embedding.py`)
   - Validates that embedding outputs with and without prefix caching are consistent
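To see why removing the causal branch is safe, here is a toy, torch-only illustration (assuming plain scaled-dot-product attention; this is not the NPU kernel): encoder requests never set `causal=True`, so the masked branch was dead code.

```
# Toy illustration: with causal=False the masked branch never runs,
# so removing it cannot change encoder outputs.
import torch

def toy_attention(q, k, v, causal: bool = False):
    scores = q @ k.transpose(-1, -2) / q.shape[-1]**0.5
    if causal:  # dead branch for encoders: they are bidirectional
        mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), 1)
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(1, 4, 8)
assert torch.allclose(toy_attention(q, k, v, causal=False),
                      toy_attention(q, k, v))
```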
  
## Does this PR introduce _any_ user-facing change?

### Functional Changes
 **Yes** - Bug fix: Embedding models now produce consistent outputs
with prefix caching

### API Changes
 **No** - All public APIs unchanged

### Configuration Changes
 **No** - No new configuration required

### Backward Compatibility
 **Fully compatible** - Only fixes incorrect behavior

## How was this patch tested?
### New Test
Added `test_causal_embed_models_using_prefix_caching_correctness()`:
- Tests: `Qwen3-Embedding-0.6B`
- Validates numerical consistency between runs with/without prefix
caching
- Uses long sequences to activate prefix caching
- Tolerance: 1e-2
- vLLM version: v0.18.0
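
On a machine with the e2e environment set up, the new case can be run directly with `pytest tests/e2e/singlecard/pooling/test_embedding.py -k prefix_caching`.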

Signed-off-by: underfituu <hzhucong@163.com>


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
import huggingface_hub
import pytest
from modelscope import snapshot_download  # type: ignore[import-untyped]

from tests.e2e.conftest import HfRunner, VllmRunner
from tests.e2e.utils import check_embeddings_close

MODELS = [
    "Qwen/Qwen3-Embedding-0.6B",  # lasttoken
    "intfloat/multilingual-e5-small",  # mean_tokens
]


@pytest.mark.parametrize("model", MODELS)
def test_embed_models_correctness(model: str):
    # Compare vLLM embeddings against the HuggingFace
    # (sentence-transformers) reference for each pooling style.
    queries = ['What is the capital of China?', 'Explain gravity']
    model_name = snapshot_download(
        model,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
    )
    with VllmRunner(
            model_name,
            runner="pooling",
            max_model_len=None,
            cudagraph_capture_sizes=[4],
    ) as vllm_runner:
        vllm_outputs = vllm_runner.embed(queries)

    with HfRunner(
            model_name,
            dtype="float32",
            is_sentence_transformer=True,
    ) as hf_runner:
        hf_outputs = hf_runner.encode(queries)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )


def test_causal_embed_models_using_prefix_caching_correctness():
    # Verify the correctness of prefix caching for (causal) embedding
    # models: run the same queries twice against one runner with prefix
    # caching enabled and check the outputs are close enough. The first
    # query is made very long to make sure prefix caching is triggered.
    queries = ['What is the capital of China?' * 256, 'Explain gravity']
    model_name = snapshot_download(
        "Qwen/Qwen3-Embedding-0.6B",
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
    )
    with VllmRunner(
            model_name,
            runner="pooling",
            max_model_len=None,
            cudagraph_capture_sizes=[4],
            enable_prefix_caching=True,
    ) as vllm_runner_using_caching:
        # The first call populates the prefix cache; the second hits it.
        vllm_outputs_without_caching = vllm_runner_using_caching.embed(queries)
        vllm_outputs_with_caching = vllm_runner_using_caching.embed(queries)

    check_embeddings_close(
        embeddings_0_lst=vllm_outputs_without_caching,
        embeddings_1_lst=vllm_outputs_with_caching,
        name_0="without_caching",
        name_1="with_caching",
        tol=1e-2,
    )


def test_bge_m3_correctness():
    # BGE-M3 is a true (bidirectional) encoder; check both ACL-graph and
    # eager execution against the HuggingFace reference.
    queries = ['What is the capital of China?', 'Explain gravity']
    model_name = snapshot_download(
        "BAAI/bge-m3",
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
    )
    with VllmRunner(
            model_name,
            runner="pooling",
            cudagraph_capture_sizes=[4],
    ) as vllm_aclgraph_runner:
        vllm_aclgraph_outputs = vllm_aclgraph_runner.embed(queries)

    with VllmRunner(
            model_name,
            runner="pooling",
            enforce_eager=True,
    ) as vllm_runner:
        vllm_eager_outputs = vllm_runner.embed(queries)

    with HfRunner(
            model_name,
            dtype="float32",
            is_sentence_transformer=True,
    ) as hf_runner:
        hf_outputs = hf_runner.encode(queries)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_eager_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
    check_embeddings_close(
        embeddings_0_lst=vllm_eager_outputs,
        embeddings_1_lst=vllm_aclgraph_outputs,
        name_0="eager",
        name_1="aclgraph",
        tol=1e-2,
    )