diff --git a/tests/e2e/singlecard/pooling/test_embedding.py b/tests/e2e/singlecard/pooling/test_embedding.py
index 50dc9ee9..e85b14d8 100644
--- a/tests/e2e/singlecard/pooling/test_embedding.py
+++ b/tests/e2e/singlecard/pooling/test_embedding.py
@@ -57,6 +57,32 @@ def test_embed_models_correctness(model: str):
         tol=1e-2,
     )
 
+
+def test_causal_embed_models_using_prefix_caching_correctness():
+    # Verify the correctness of prefix caching for causal embedding models.
+    # The same queries are embedded twice with a caching-enabled runner and the
+    # two outputs are checked for closeness; the first query is made long
+    # enough to ensure prefix caching is triggered.
+    queries = ['What is the capital of China?' * 256, 'Explain gravity']
+
+    model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B", local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE)
+    with VllmRunner(
+            model_name,
+            runner="pooling",
+            max_model_len=None,
+            cudagraph_capture_sizes=[4],
+            enable_prefix_caching=True,
+    ) as vllm_runner_using_caching:
+        vllm_outputs_without_caching = vllm_runner_using_caching.embed(queries)
+        vllm_outputs_with_caching = vllm_runner_using_caching.embed(queries)
+
+    check_embeddings_close(
+        embeddings_0_lst=vllm_outputs_without_caching,
+        embeddings_1_lst=vllm_outputs_with_caching,
+        name_0="without_caching",
+        name_1="with_caching",
+        tol=1e-2,
+    )
 
 def test_bge_m3_correctness():
     queries = ['What is the capital of China?', 'Explain gravity']
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 62ea6e8c..689084f4 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -867,34 +867,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
         attn_metadata: AscendMetadata,
         _: torch.Tensor,
     ) -> torch.Tensor:
-        assert attn_metadata is not None
-
-        if attn_metadata.causal:
-            # use sparse_mode 3 in causal scenario
-            return torch_npu.npu_fusion_attention(
-                query=query,
-                key=key,
-                value=value,
-                head_num=self.num_heads,
-                input_layout="TND",
-                scale=self.scale,
-                sparse_mode=3,
-                atten_mask=attn_metadata.attn_mask,
-                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
-                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
-            )[0]
-        else:
-            # use default sparse_mode 0 in normal scenario, which means no mask works on it
-            return torch_npu.npu_fusion_attention(
-                query=query,
-                key=key,
-                value=value,
-                head_num=self.num_heads,
-                input_layout="TND",
-                scale=self.scale,
-                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
-                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
-            )[0]
+        # use the default sparse_mode 0, i.e. no attention mask is applied
+        return torch_npu.npu_fusion_attention(
+            query=query,
+            key=key,
+            value=value,
+            head_num=self.num_heads,
+            input_layout="TND",
+            scale=self.scale,
+            actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
+            actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
+        )[0]
 
     def reshape_and_cache(
         self,
@@ -985,7 +968,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query, key, value, kv_cache, attn_metadata, output
             )
         # pooling model branch
-        if attn_metadata.model_runner_type == "pooling":
+        if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
             attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]
             return output