[BugFix] Convert pa get_workspace back to capturing (#5833)
### What this PR does / why we need it?
This fixes a bug in pa get_workspace. In the earlier implementation, we
called `_npu_paged_attention_get_workspace` in `_update_attn_pa_params`.
However, this could cause memory problems, because the API dynamically
allocates new memory for the workspace whenever it is called. Therefore,
we move the call back to capturing and use a fixed
`SEQ_LEN_WITH_MAX_PA_WORKSPACE` to get the max workspace.
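
The idea can be illustrated with a minimal, hardware-free sketch. It is not the code from this PR: `GraphParams`, `capture`, `update`, and `fake_get_workspace` are hypothetical stand-ins, and the value given to `SEQ_LEN_WITH_MAX_PA_WORKSPACE` is a placeholder because the real value is not stated here. It only shows the pattern of allocating the workspace once per runtime shape at capture time and doing a plain lookup during the update step.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

# Placeholder value (assumption); the real constant's value is not given in this PR text.
SEQ_LEN_WITH_MAX_PA_WORKSPACE = 4096


@dataclass
class GraphParams:
    """Toy stand-in for the graph_params object that holds per-shape workspaces."""
    workspaces: Dict[int, bytearray] = field(default_factory=dict)


# Records the size of every allocation so the behaviour can be inspected.
allocations: List[int] = []


def fake_get_workspace(batch_size: int, seq_len: int) -> bytearray:
    """Hypothetical allocator standing in for the real get_workspace call."""
    buf = bytearray(batch_size * seq_len)
    allocations.append(len(buf))
    return buf


def capture(graph_params: GraphParams, runtime_shape: int,
            get_workspace: Callable[[int, int], bytearray]) -> bytearray:
    """Capture time: allocate once per runtime shape, sized with the fixed seq len."""
    if runtime_shape not in graph_params.workspaces:
        graph_params.workspaces[runtime_shape] = get_workspace(
            runtime_shape, SEQ_LEN_WITH_MAX_PA_WORKSPACE)
    return graph_params.workspaces[runtime_shape]


def update(graph_params: GraphParams, runtime_shape: int) -> Optional[bytearray]:
    """Update time: look up the captured buffer; never allocate here."""
    return graph_params.workspaces.get(runtime_shape)
```

With this split, repeated update calls reuse the same buffer, so memory use stays fixed after capture. A shape that was never captured yields `None` from the lookup, which is also how `dict.get` behaves in the changed line of the diff.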
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
Signed-off-by: Angazenn <supperccell@163.com>
@@ -229,25 +229,6 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
 ) = param
 seq_lens = forward_context.attn_metadata[key].seq_lens
 
-# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
-# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
-# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
-# might encounter a bigger workspace, while currently we use max_model_len to
-# calculate max workspace in capturing. So additional get_workspace is added
-# here to avoid such bugs.
-# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
-# replaced by npu_fused_infer_attention_score which does not contain such bugs.
-workspace = torch_npu._npu_paged_attention_get_workspace(
-    query=query,
-    key_cache=key_cache,
-    value_cache=value_cache,
-    num_kv_heads=num_kv_heads,
-    num_heads=num_heads,
-    scale_value=scale,
-    block_table=block_table,
-    context_lens=seq_lens,
-    out=output,
-)
 torch.npu.graph_task_update_begin(update_stream, handle)
 torch_npu._npu_paged_attention(
     query=query,
@@ -259,7 +240,7 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
     block_table=block_table,
     context_lens=seq_lens,
     out=output,
-    workspace=workspace,
+    workspace=graph_params.workspaces.get(runtime_shape),
 )
 torch.npu.graph_task_update_end(update_stream)