[BugFix] convert PA get_workspace back to capturing (#5833)
### What this PR does / why we need it?
This fixes a bug in PA get_workspace. In the earlier implementation, we called `_npu_paged_attention_get_workspace` inside `_update_attn_pa_params`. However, this can cause potential memory problems, because the API dynamically allocates new workspace memory on every call. Therefore, we move the workspace query back to graph capturing and use a fixed `SEQ_LEN_WITH_MAX_PA_WORKSPACE` to obtain the maximum workspace.
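
Conceptually, the capture path now pads `seq_lens` so that the workspace queried once at capture time is already the largest one paged attention may need, and every replay reuses that cached buffer. A minimal sketch of the selection logic; the helper name `select_capture_seq_len` is illustrative, not code from this PR:

```python
SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144  # fixed value introduced by this PR


def select_capture_seq_len(is_graph_capturing: bool,
                           uses_paged_attention: bool,
                           max_query_len: int) -> int:
    # While capturing an ACL graph for a paged-attention decode shape, pad
    # seq_lens to the value that makes _npu_paged_attention_get_workspace
    # return its maximum workspace; otherwise keep the real max_query_len.
    if is_graph_capturing and uses_paged_attention:
        return SEQ_LEN_WITH_MAX_PA_WORKSPACE
    return max_query_len
```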
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
Signed-off-by: Angazenn <supperccell@163.com>
@@ -229,25 +229,6 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
     ) = param
     seq_lens = forward_context.attn_metadata[key].seq_lens

-    # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
-    # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
-    # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
-    # might encounter a bigger workspace, while currently we use max_model_len to
-    # calculate max workspace in capturing. So additional get_workspace is added
-    # here to avoid such bugs.
-    # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
-    # replaced by npu_fused_infer_attention_score which does not contain such bugs.
-    workspace = torch_npu._npu_paged_attention_get_workspace(
-        query=query,
-        key_cache=key_cache,
-        value_cache=value_cache,
-        num_kv_heads=num_kv_heads,
-        num_heads=num_heads,
-        scale_value=scale,
-        block_table=block_table,
-        context_lens=seq_lens,
-        out=output,
-    )
     torch.npu.graph_task_update_begin(update_stream, handle)
     torch_npu._npu_paged_attention(
         query=query,
@@ -259,7 +240,7 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
         block_table=block_table,
         context_lens=seq_lens,
         out=output,
-        workspace=workspace,
+        workspace=graph_params.workspaces.get(runtime_shape),
     )
     torch.npu.graph_task_update_end(update_stream)

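The replay path above only looks up a buffer that was allocated at capture time; the capture-side code is not part of this excerpt. A rough sketch of how such a buffer could be obtained and cached per runtime shape, assuming `graph_params.workspaces` is a dict keyed by runtime shape and reusing the same `torch_npu` call as the removed block (the helper name is hypothetical):

```python
import torch_npu


def capture_pa_workspace(graph_params, runtime_shape, query, key_cache, value_cache,
                         num_kv_heads, num_heads, scale, block_table, seq_lens, output):
    # Hypothetical capture-time helper: seq_lens here is built from
    # SEQ_LEN_WITH_MAX_PA_WORKSPACE, so the returned workspace is an upper
    # bound for every seq_lens seen later at replay time.
    workspace = torch_npu._npu_paged_attention_get_workspace(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        num_kv_heads=num_kv_heads,
        num_heads=num_heads,
        scale_value=scale,
        block_table=block_table,
        context_lens=seq_lens,
        out=output,
    )
    graph_params.workspaces[runtime_shape] = workspace
    return workspace
```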
@@ -76,7 +76,7 @@ from vllm.v1.worker.utils import AttentionGroup

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
@@ -133,6 +133,9 @@ if get_ascend_device_type() == AscendDeviceType._310P:
     torch_npu.npu.set_compile_mode(jit_compile=False)


+SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144
+
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.npu.Stream
@@ -1919,6 +1922,7 @@ class NPUModelRunner(GPUModelRunner):
         num_scheduled_tokens: np.ndarray,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
+        is_graph_capturing: bool = False,
     ) -> Optional[dict[str, Any]]:
         attn_metadata: Optional[dict[str, Any]] = None

@@ -1928,7 +1932,12 @@ class NPUModelRunner(GPUModelRunner):

         attn_metadata = {}

-        seq_lens = max_query_len
+        # The reason why we use a fixed seq_len rather than max_query_len is that
+        # _npu_paged_attention_get_workspace only returns max workspace with specific
+        # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len
+        # in inference. This will be removed once npu_fused_infer_attention_score
+        # outperforms _npu_paged_attention on all cases.
+        seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) else max_query_len
         self.seq_lens.np[:num_reqs] = seq_lens
         self.seq_lens.np[num_reqs:] = 0
         self.seq_lens.copy_to_gpu()
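For intuition, here is a standalone illustration (not project code) of how the padded value propagates into the `seq_lens` buffer during capture, mirroring the `self.seq_lens.np` updates above:

```python
import numpy as np

SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144

# Toy sizes for illustration only.
max_num_reqs, num_reqs = 8, 3

seq_lens_np = np.zeros(max_num_reqs, dtype=np.int32)
seq_lens_np[:num_reqs] = SEQ_LEN_WITH_MAX_PA_WORKSPACE  # active slots padded during capture
seq_lens_np[num_reqs:] = 0                              # unused slots stay zero
print(seq_lens_np)  # [6144 6144 6144    0    0    0    0    0]
```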
@@ -2177,6 +2186,7 @@ class NPUModelRunner(GPUModelRunner):
             max_query_len=max_query_len,
             aclgraph_runtime_mode=cudagraph_runtime_mode,
             force_attention=force_attention,
+            is_graph_capturing=is_graph_capturing,
             num_scheduled_tokens=num_scheduled_tokens,
         )