diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index f2cab32..d3e779e 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -212,19 +212,37 @@ def update_attn_params(update_stream, forward_context, runtime_shape): ) = param seq_lens = forward_context.attn_metadata[key].seq_lens + # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY + # mode with GQA. This is triggered by getting workspace for _npu_paged_attention + # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens + # might encounter a bigger workspace, while currently we use max_model_len to + # calculate max workspace in capturing. So additional get_workspace is added + # here to avoid such bugs. + # TODO(Angazenn): we will remove this once _npu_paged_attention is fully + # replaced by npu_fused_infer_attention_score which does not contain such bugs. + workspace = torch_npu._npu_paged_attention_get_workspace( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) + with torch.npu.stream(update_stream): torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - workspace=graph_params.workspaces.get(runtime_shape)) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace) torch.npu.graph_task_update_end(update_stream) event.record(update_stream)