From a88937f5cb564ace26135abfd400b9a57919b2d3 Mon Sep 17 00:00:00 2001
From: Qiu
Date: Wed, 14 Jan 2026 20:57:48 +0800
Subject: [PATCH] [bugfix](cp) replace None with zeros/inf tensor to avoid
 TypeError (#5837)

### What this PR does / why we need it?
When some devices hold no KV cache, `_compute_prefill_context` returns `None`, which is unexpected. This PR replaces the `None` values with an all-zeros output tensor and an all-`-inf` LSE tensor to avoid the resulting TypeError.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py -k test_models_chunked_prefill_with_empty_kvcache
```

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d

---------

Signed-off-by: QiuChunshuo
---
 .../long_sequence/test_chunked_prefill.py | 53 +++++++++++++++++-
 .../context_parallel/attention_cp.py      | 55 +++++++++++--------
 2 files changed, 82 insertions(+), 26 deletions(-)

diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py b/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py
index b2536460..3021c9d5 100644
--- a/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py
+++ b/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py
@@ -21,10 +21,16 @@ import random
 import string
 from unittest.mock import patch
 
+import pytest
 from vllm import SamplingParams
 
 from tests.e2e.conftest import VllmRunner
 
+MODELS = [
+    "vllm-ascend/Qwen3-30B-A3B-W8A8",
+    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
+]
+
 
 def generate_prompts(input_len, batchsize):
     prompts = [
@@ -41,7 +47,9 @@ def generate_prompts(input_len, batchsize):
         "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1",
         "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
     })
-def test_models_chunked_prefill_mixed_length_prompts_including_1_token():
+@pytest.mark.parametrize("model", MODELS)
+def test_models_chunked_prefill_mixed_length_prompts_including_1_token(
+        model: str):
     TEST_ROPE_PARAMETERS = {
         "rope_theta": 1000000,
         "rope_type": "yarn",
@@ -55,7 +63,6 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token():
     ]
     sampling_params = SamplingParams(max_tokens=1, temperature=0.0)
 
-    model = "vllm-ascend/Qwen3-30B-A3B-W8A8"
     with VllmRunner(
             model,
             enforce_eager=True,
@@ -71,3 +78,45 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token():
             hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
     ) as runner:
         runner.model.generate(prompts, sampling_params)
+
+
+@patch.dict(
+    os.environ, {
+        "HCCL_BUFFSIZE": "768",
+        "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1",
+        "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
+    })
+@pytest.mark.parametrize("model", MODELS)
+def test_models_chunked_prefill_with_empty_kvcache(model: str):
+    TEST_ROPE_PARAMETERS = {
+        "rope_theta": 1000000,
+        "rope_type": "yarn",
+        "factor": 4,
+        "original_max_position_embeddings": 32768
+    }
+    # Note(qcs): we use chunk_size=50, kv_cache_interleave_size=128
+    # to simulate certain edge cases.
+    prompts = [
+        generate_prompts(128, 1)[0],
+        generate_prompts(1, 1)[0],
+        generate_prompts(130, 1)[0],
+        generate_prompts(51, 1)[0],
+        generate_prompts(129, 1)[0],
+    ]
+    sampling_params = SamplingParams(max_tokens=1, temperature=0.0)
+
+    with VllmRunner(
+            model,
+            enforce_eager=True,
+            max_num_seqs=2,
+            tensor_parallel_size=2,
+            prefill_context_parallel_size=2,
+            decode_context_parallel_size=1,
+            enable_expert_parallel=True,
+            long_prefill_token_threshold=50,
+            block_size=128,
+            cp_kv_cache_interleave_size=128,
+            quantization="ascend",
+            hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
+    ) as runner:
+        runner.model.generate(prompts, sampling_params)
diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py
index 088cd0e4..4f0716a5 100644
--- a/vllm_ascend/attention/context_parallel/attention_cp.py
+++ b/vllm_ascend/attention/context_parallel/attention_cp.py
@@ -636,30 +636,37 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         else:
             num_heads = self.num_heads
 
-        prefix_chunk_output, prefix_chunk_lse = None, None
-        if total_toks > 0:
-            prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
-                query,
-                key,
-                value,
-                num_heads=num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout="TND",
-                atten_mask=None,
-                scale=self.scale,
-                sparse_mode=0,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                softmax_lse_flag=True,
-                actual_seq_lengths_kv=prefill_metadata.chunked_context.
-                actual_seq_lengths_kv,
-                actual_seq_lengths=attn_metadata.prefill.chunked_context.
-                actual_chunk_seq_lengths)
-            batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
-            lse_mask = batch_chunk_seq_mask[:, None,
-                                            None].expand_as(prefix_chunk_lse)
-            prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
-                                           prefix_chunk_lse)
+        if total_toks == 0:
+            return (torch.full((query.size(0), num_heads, self.head_size),
+                               fill_value=0,
+                               dtype=query.dtype,
+                               device=query.device),
+                    torch.full((query.size(0), num_heads, 1),
+                               fill_value=-torch.inf,
+                               dtype=torch.float32,
+                               device=query.device))
+
+        prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
+            query,
+            key,
+            value,
+            num_heads=num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout="TND",
+            atten_mask=None,
+            scale=self.scale,
+            sparse_mode=0,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            softmax_lse_flag=True,
+            actual_seq_lengths_kv=prefill_metadata.chunked_context.
+            actual_seq_lengths_kv,
+            actual_seq_lengths=attn_metadata.prefill.chunked_context.
+            actual_chunk_seq_lengths)
+        batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
+        lse_mask = batch_chunk_seq_mask[:, None,
+                                        None].expand_as(prefix_chunk_lse)
+        prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse)
         return prefix_chunk_output, prefix_chunk_lse
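Why an all-zeros output with an all-`-inf` LSE is a safe fallback: when partial attention results from several context chunks are later merged with log-sum-exp weighting, a chunk whose LSE is `-inf` gets zero weight, so the "empty" chunk leaves the merged result unchanged. The sketch below illustrates that property only; `merge_attn_chunks` is a hypothetical helper written for this note, not the actual merge kernel used by vllm-ascend.

```python
import torch


def merge_attn_chunks(out_a, lse_a, out_b, lse_b):
    # Hypothetical helper: combine two partial attention results via their
    # log-sum-exp (LSE) weights, the standard chunked-attention merge trick.
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse)
    w_b = torch.exp(lse_b - lse)
    return out_a * w_a + out_b * w_b, lse


num_tokens, num_heads, head_size = 4, 8, 128
out = torch.randn(num_tokens, num_heads, head_size)
lse = torch.randn(num_tokens, num_heads, 1)

# "Empty" chunk shaped like the fallback returned by the patched
# _compute_prefill_context: zero output and -inf LSE.
empty_out = torch.zeros(num_tokens, num_heads, head_size)
empty_lse = torch.full((num_tokens, num_heads, 1), -torch.inf)

merged_out, merged_lse = merge_attn_chunks(out, lse, empty_out, empty_lse)
assert torch.allclose(merged_out, out)  # the empty chunk contributes nothing
assert torch.allclose(merged_lse, lse)
```

Under this assumption, returning zeros/`-inf` (instead of `None`) keeps the downstream code path type-stable without changing the numerical result.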