[bugfix](cp) replace None with zeros/inf tensor to avoid TypeError (#5837)
### What this PR does / why we need it?
When some devices hold no kv cache, the `_compute_prefill_context` function returns `None`, which is unexpected. This PR replaces the `None` results with a zero-filled output tensor and a `-inf` LSE tensor to avoid a `TypeError`.
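For intuition on why a zero-filled output plus a `-inf` LSE is a safe stand-in: chunked prefill combines per-chunk attention partials through their log-sum-exp (LSE) values, and a partial whose LSE is `-inf` receives zero weight in that merge, so the placeholder behaves like an empty chunk instead of breaking the combine step. A minimal sketch of that merge, not the vLLM-Ascend implementation (`merge_attn_partials` and all shapes here are illustrative):

```python
import torch

def merge_attn_partials(out_a, lse_a, out_b, lse_b):
    # Standard log-sum-exp merge of two attention partials: weight each
    # partial by exp(lse - max_lse), then renormalize. Assumes at least one
    # partial has a finite LSE (both being -inf would produce NaNs).
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)
    w_b = torch.exp(lse_b - lse_max)
    return (w_a * out_a + w_b * out_b) / (w_a + w_b)

num_tokens, num_heads, head_size = 4, 2, 8

# Placeholder pair in the spirit of this fix: zero output, -inf LSE.
empty_out = torch.zeros(num_tokens, num_heads, head_size)
empty_lse = torch.full((num_tokens, num_heads, 1), -torch.inf)

# A "real" partial from another chunk.
real_out = torch.randn(num_tokens, num_heads, head_size)
real_lse = torch.randn(num_tokens, num_heads, 1)

merged = merge_attn_partials(empty_out, empty_lse, real_out, real_lse)
assert torch.allclose(merged, real_out)  # the placeholder contributes nothing
```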
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py -k test_models_chunked_prefill_with_empty_kvcache
```
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
```diff
@@ -636,30 +636,37 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         else:
             num_heads = self.num_heads

-            prefix_chunk_output, prefix_chunk_lse = None, None
-            if total_toks > 0:
-                prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
-                    query,
-                    key,
-                    value,
-                    num_heads=num_heads,
-                    num_key_value_heads=self.num_kv_heads,
-                    input_layout="TND",
-                    atten_mask=None,
-                    scale=self.scale,
-                    sparse_mode=0,
-                    antiquant_mode=0,
-                    antiquant_scale=None,
-                    softmax_lse_flag=True,
-                    actual_seq_lengths_kv=prefill_metadata.chunked_context.
-                    actual_seq_lengths_kv,
-                    actual_seq_lengths=attn_metadata.prefill.chunked_context.
-                    actual_chunk_seq_lengths)
-                batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
-                lse_mask = batch_chunk_seq_mask[:, None,
-                                                None].expand_as(prefix_chunk_lse)
-                prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
-                                               prefix_chunk_lse)
+            if total_toks == 0:
+                return (torch.full((query.size(0), num_heads, self.head_size),
+                                   fill_value=0,
+                                   dtype=query.dtype,
+                                   device=query.device),
+                        torch.full((query.size(0), num_heads, 1),
+                                   fill_value=-torch.inf,
+                                   dtype=torch.float32,
+                                   device=query.device))
+
+            prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
+                query,
+                key,
+                value,
+                num_heads=num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                atten_mask=None,
+                scale=self.scale,
+                sparse_mode=0,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                softmax_lse_flag=True,
+                actual_seq_lengths_kv=prefill_metadata.chunked_context.
+                actual_seq_lengths_kv,
+                actual_seq_lengths=attn_metadata.prefill.chunked_context.
+                actual_chunk_seq_lengths)
+            batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
+            lse_mask = batch_chunk_seq_mask[:, None,
+                                            None].expand_as(prefix_chunk_lse)
+            prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse)

         return prefix_chunk_output, prefix_chunk_lse
```
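As a side note on the `torch.where(lse_mask, -torch.inf, ...)` line above: it pushes batch entries that have no valid chunked context out of the downstream merge in the same way. A tiny illustration with made-up shapes and mask values (only the broadcasting pattern mirrors the diff):

```python
import torch

num_tokens, num_heads = 4, 2
prefix_chunk_lse = torch.randn(num_tokens, num_heads, 1)

# Hypothetical per-token flag: True where a token has no valid chunked context.
batch_chunk_seq_mask = torch.tensor([False, True, False, True])

# Broadcast the flag over heads, then overwrite the LSE with -inf so those
# positions receive zero weight when the partials are combined later.
lse_mask = batch_chunk_seq_mask[:, None, None].expand_as(prefix_chunk_lse)
prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse)

print(prefix_chunk_lse.squeeze(-1))  # rows 1 and 3 are all -inf
```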