[Attention] Temporarily add back pa for small batch sizes. (#4765)

### What this PR does / why we need it?
This PR adds back pa (paged attention) for small batch sizes for performance
reasons. pa will be removed once fia performs better than pa in all scenarios.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with existing tests.


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Author: whx
Date: 2025-12-15 20:35:50 +08:00 (committed by GitHub)
Parent: 95e6400128
Commit: a9625851ef
4 changed files with 163 additions and 6 deletions

@@ -1,13 +1,29 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional
import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm_ascend.utils import get_ascend_config


@lru_cache
def using_paged_attention(runtime_shape: int) -> bool:
    """Return True if paged attention (pa) should be used for this decode shape."""
    vllm_config = get_current_vllm_config()
    # pa is not used together with speculative decoding.
    if vllm_config.speculative_config is not None:
        return False
    from vllm.config.compilation import CUDAGraphMode
    # pa is only used in full-decode-only graph mode.
    if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
        return False
    # Only the batch shapes listed in pa_shape_list fall back to pa.
    return runtime_shape in get_ascend_config().pa_shape_list

@dataclass
class AscendCommonLongSequenceMetadata:
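
For context, a minimal, hypothetical sketch (not part of this PR) of how a decode-time attention path could dispatch on `using_paged_attention`; the helpers `_paged_attention_decode` and `_fused_infer_attention` are placeholder names, not actual vllm-ascend APIs:

```python
import torch

# Assumes using_paged_attention() from the file above is in scope.


def _paged_attention_decode(q: torch.Tensor, k_cache: torch.Tensor,
                            v_cache: torch.Tensor) -> torch.Tensor:
    """Placeholder standing in for the paged-attention (pa) kernel."""
    raise NotImplementedError


def _fused_infer_attention(q: torch.Tensor, k_cache: torch.Tensor,
                           v_cache: torch.Tensor) -> torch.Tensor:
    """Placeholder standing in for the fused-infer-attention (fia) kernel."""
    raise NotImplementedError


def decode_attention(q: torch.Tensor, k_cache: torch.Tensor,
                     v_cache: torch.Tensor, runtime_shape: int) -> torch.Tensor:
    # Small decode batch shapes listed in pa_shape_list take the pa path;
    # every other shape stays on the fia path.
    if using_paged_attention(runtime_shape):
        return _paged_attention_decode(q, k_cache, v_cache)
    return _fused_infer_attention(q, k_cache, v_cache)
```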