[bugfix](cp) replace None with zeros/-inf tensors to avoid TypeError (#5837)

### What this PR does / why we need it?
When some devices hold no KV cache for the chunked context, the `_compute_prefill_context` function returns `None`, which is unexpected. This PR replaces `None` with tensors filled with zeros (attention output) and `-inf` (LSE) to avoid a downstream TypeError.
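As a minimal illustration (not the repo's actual merge code; `merge_partial_attention` below is a hypothetical helper), an all-zeros output paired with an all `-inf` log-sum-exp (LSE) acts as a neutral element when partial attention results from different context chunks are combined with an LSE-weighted merge, so a rank that holds no prefix KV contributes nothing:

```python
import torch

def merge_partial_attention(out_a, lse_a, out_b, lse_b):
    # LSE-weighted combination of two partial attention results.
    lse = torch.logaddexp(lse_a, lse_b)   # combined normalizer
    w_a = torch.exp(lse_a - lse)          # weight of chunk A
    w_b = torch.exp(lse_b - lse)          # weight of chunk B (0 when lse_b == -inf)
    return w_a * out_a + w_b * out_b, lse

num_tokens, num_heads, head_size = 4, 2, 8
# A "real" partial result computed from local KV cache.
out_local = torch.randn(num_tokens, num_heads, head_size)
lse_local = torch.randn(num_tokens, num_heads, 1)
# The placeholder returned when a rank has no chunked-context KV.
out_empty = torch.zeros(num_tokens, num_heads, head_size)
lse_empty = torch.full((num_tokens, num_heads, 1), -torch.inf)

merged, _ = merge_partial_attention(out_local, lse_local, out_empty, lse_empty)
assert torch.allclose(merged, out_local)  # the empty chunk has zero weight
```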

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py -k test_models_chunked_prefill_with_empty_kvcache
```

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Qiu, 2026-01-14 20:57:48 +08:00, committed by GitHub
parent d450ba24c7, commit a88937f5cb
2 changed files with 82 additions and 26 deletions


@@ -21,10 +21,16 @@ import random
 import string
 from unittest.mock import patch
 
+import pytest
 from vllm import SamplingParams
 
 from tests.e2e.conftest import VllmRunner
 
+MODELS = [
+    "vllm-ascend/Qwen3-30B-A3B-W8A8",
+    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
+]
+
 
 def generate_prompts(input_len, batchsize):
     prompts = [
@@ -41,7 +47,9 @@ def generate_prompts(input_len, batchsize):
         "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1",
         "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
     })
-def test_models_chunked_prefill_mixed_length_prompts_including_1_token():
+@pytest.mark.parametrize("model", MODELS)
+def test_models_chunked_prefill_mixed_length_prompts_including_1_token(
+        model: str):
     TEST_ROPE_PARAMETERS = {
         "rope_theta": 1000000,
         "rope_type": "yarn",
@@ -55,7 +63,6 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token():
     ]
 
     sampling_params = SamplingParams(max_tokens=1, temperature=0.0)
-    model = "vllm-ascend/Qwen3-30B-A3B-W8A8"
     with VllmRunner(
             model,
             enforce_eager=True,
@@ -71,3 +78,45 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token():
             hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
     ) as runner:
         runner.model.generate(prompts, sampling_params)
+
+
+@patch.dict(
+    os.environ, {
+        "HCCL_BUFFSIZE": "768",
+        "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1",
+        "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
+    })
+@pytest.mark.parametrize("model", MODELS)
+def test_models_chunked_prefill_with_empty_kvcache(model: str):
+    TEST_ROPE_PARAMETERS = {
+        "rope_theta": 1000000,
+        "rope_type": "yarn",
+        "factor": 4,
+        "original_max_position_embeddings": 32768
+    }
+
+    # Note(qcs): we use chunk_size=50, kv_cache_interleave_size=128
+    # to simulate certain edge cases.
+    prompts = [
+        generate_prompts(128, 1)[0],
+        generate_prompts(1, 1)[0],
+        generate_prompts(130, 1)[0],
+        generate_prompts(51, 1)[0],
+        generate_prompts(129, 1)[0],
+    ]
+    sampling_params = SamplingParams(max_tokens=1, temperature=0.0)
+    with VllmRunner(
+            model,
+            enforce_eager=True,
+            max_num_seqs=2,
+            tensor_parallel_size=2,
+            prefill_context_parallel_size=2,
+            decode_context_parallel_size=1,
+            enable_expert_parallel=True,
+            long_prefill_token_threshold=50,
+            block_size=128,
+            cp_kv_cache_interleave_size=128,
+            quantization="ascend",
+            hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
+    ) as runner:
+        runner.model.generate(prompts, sampling_params)


@@ -636,30 +636,37 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         else:
             num_heads = self.num_heads
-        prefix_chunk_output, prefix_chunk_lse = None, None
-        if total_toks > 0:
-            prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
-                query,
-                key,
-                value,
-                num_heads=num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout="TND",
-                atten_mask=None,
-                scale=self.scale,
-                sparse_mode=0,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                softmax_lse_flag=True,
-                actual_seq_lengths_kv=prefill_metadata.chunked_context.
-                actual_seq_lengths_kv,
-                actual_seq_lengths=attn_metadata.prefill.chunked_context.
-                actual_chunk_seq_lengths)
-            batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
-            lse_mask = batch_chunk_seq_mask[:, None,
-                                            None].expand_as(prefix_chunk_lse)
-            prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
-                                           prefix_chunk_lse)
+        if total_toks == 0:
+            return (torch.full((query.size(0), num_heads, self.head_size),
+                               fill_value=0,
+                               dtype=query.dtype,
+                               device=query.device),
+                    torch.full((query.size(0), num_heads, 1),
+                               fill_value=-torch.inf,
+                               dtype=torch.float32,
+                               device=query.device))
+
+        prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
+            query,
+            key,
+            value,
+            num_heads=num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout="TND",
+            atten_mask=None,
+            scale=self.scale,
+            sparse_mode=0,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            softmax_lse_flag=True,
+            actual_seq_lengths_kv=prefill_metadata.chunked_context.
+            actual_seq_lengths_kv,
+            actual_seq_lengths=attn_metadata.prefill.chunked_context.
+            actual_chunk_seq_lengths)
+        batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
+        lse_mask = batch_chunk_seq_mask[:, None,
+                                        None].expand_as(prefix_chunk_lse)
+        prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse)
         return prefix_chunk_output, prefix_chunk_lse