[bugfix](cp) replace None with zeros/inf tensor to avoid TypeError (#5837)

### What this PR does / why we need it?
When there is no kv cache in some devices, the `_compute_prefill_context
func` will return `None`, which is unexecpted. This PR replaces None
with full zeros/-inf tensors to avoid TypeError.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py -k test_models_chunked_prefill_with_empty_kvcache
```

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
Qiu
2026-01-14 20:57:48 +08:00
committed by GitHub
parent d450ba24c7
commit a88937f5cb
2 changed files with 82 additions and 26 deletions

View File

@@ -636,30 +636,37 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
else:
num_heads = self.num_heads
prefix_chunk_output, prefix_chunk_lse = None, None
if total_toks > 0:
prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND",
atten_mask=None,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=prefill_metadata.chunked_context.
actual_seq_lengths_kv,
actual_seq_lengths=attn_metadata.prefill.chunked_context.
actual_chunk_seq_lengths)
batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
lse_mask = batch_chunk_seq_mask[:, None,
None].expand_as(prefix_chunk_lse)
prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
prefix_chunk_lse)
if total_toks == 0:
return (torch.full((query.size(0), num_heads, self.head_size),
fill_value=0,
dtype=query.dtype,
device=query.device),
torch.full((query.size(0), num_heads, 1),
fill_value=-torch.inf,
dtype=torch.float32,
device=query.device))
prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND",
atten_mask=None,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=prefill_metadata.chunked_context.
actual_seq_lengths_kv,
actual_seq_lengths=attn_metadata.prefill.chunked_context.
actual_chunk_seq_lengths)
batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
lse_mask = batch_chunk_seq_mask[:, None,
None].expand_as(prefix_chunk_lse)
prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse)
return prefix_chunk_output, prefix_chunk_lse