[BugFix] convert pa get_workspace back to capturing (#5833)

### What this PR does / why we need it?

This fixes a bug in the paged attention (pa) get_workspace path. In the
earlier implementation, we called `_npu_paged_attention_get_workspace`
inside `_update_pa_attn_params`. However, this could cause memory
problems, because it dynamically allocates new workspace memory every
time the API is called. Therefore, we move the workspace query back to
graph capture time and use a fixed `SEQ_LEN_WITH_MAX_PA_WORKSPACE` to
obtain the maximum workspace.
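
A minimal sketch of the allocation pattern being fixed, using a toy
`get_workspace(seq_len)` helper as a stand-in for the real
`_npu_paged_attention_get_workspace` (the names and sizing below are
illustrative assumptions, not the actual API):

```python
import torch

SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144  # fixed seq_len assumed to yield the maximal workspace

def get_workspace(seq_len: int) -> torch.Tensor:
    # Toy stand-in: pretend the workspace size grows with seq_len.
    return torch.empty(seq_len * 64, dtype=torch.uint8)

# Before (problematic): queried inside _update_pa_attn_params on every step,
# so a new, differently sized buffer may be allocated mid-inference.
def workspace_per_call(max_query_len: int) -> torch.Tensor:
    return get_workspace(max_query_len)

# After (this PR): query once at graph capture time, sized by the fixed
# maximum, so every later replay reuses the same buffer.
def workspace_at_capture() -> torch.Tensor:
    return get_workspace(SEQ_LEN_WITH_MAX_PA_WORKSPACE)
```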

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

Signed-off-by: Angazenn <supperccell@163.com>

@@ -76,7 +76,7 @@ from vllm.v1.worker.utils import AttentionGroup
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
@@ -133,6 +133,9 @@ if get_ascend_device_type() == AscendDeviceType._310P:
     torch_npu.npu.set_compile_mode(jit_compile=False)
+
+SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.npu.Stream
@@ -1919,6 +1922,7 @@ class NPUModelRunner(GPUModelRunner):
         num_scheduled_tokens: np.ndarray,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
+        is_graph_capturing: bool = False,
     ) -> Optional[dict[str, Any]]:
         attn_metadata: Optional[dict[str, Any]] = None
@@ -1928,7 +1932,12 @@ class NPUModelRunner(GPUModelRunner):
             attn_metadata = {}
-            seq_lens = max_query_len
+            # The reason why we use a fixed seq_len rather than max_query_len is that
+            # _npu_paged_attention_get_workspace only returns max workspace with specific
+            # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len
+            # in inference. This will be removed once npu_fused_infer_attention_score
+            # outperforms _npu_paged_attention on all cases.
+            seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) else max_query_len
             self.seq_lens.np[:num_reqs] = seq_lens
             self.seq_lens.np[num_reqs:] = 0
             self.seq_lens.copy_to_gpu()
@@ -2177,6 +2186,7 @@ class NPUModelRunner(GPUModelRunner):
             max_query_len=max_query_len,
             aclgraph_runtime_mode=cudagraph_runtime_mode,
             force_attention=force_attention,
+            is_graph_capturing=is_graph_capturing,
             num_scheduled_tokens=num_scheduled_tokens,
         )
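
For readers following the diff, here is a hedged sketch of the gating logic
the new flag enables. `pick_seq_lens` and the inline `using_paged_attention`
predicate are simplified stand-ins, not vllm_ascend's actual API (the real
helper also consults the vLLM config):

```python
SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144

def pick_seq_lens(num_tokens: int, max_query_len: int,
                  is_graph_capturing: bool) -> int:
    def using_paged_attention(n: int) -> bool:
        # Stand-in predicate; assume paged attention is used for any tokens.
        return n > 0
    # Pin seq_lens to the fixed maximum only while capturing a graph that
    # will run paged attention; plain inference keeps the live max_query_len.
    if is_graph_capturing and using_paged_attention(num_tokens):
        return SEQ_LEN_WITH_MAX_PA_WORKSPACE
    return max_query_len

print(pick_seq_lens(num_tokens=256, max_query_len=128, is_graph_capturing=True))   # 6144
print(pick_seq_lens(num_tokens=256, max_query_len=128, is_graph_capturing=False))  # 128
```

This way only graph capture pays the fixed max-workspace cost, while every
replay reuses the captured buffer instead of allocating a new one per call.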