add pagedattention to support FULL_DECODE_ONLY. (#3102)
### What this PR does / why we need it? Calculate in advance the workspace memory size needed for the PagedAttention operator to avoid deadlocks during resource cleanup. This PR requires torch_npu version 0920 or newer. ### How was this patch tested? - vLLM version: v0.11.0 --------- Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -173,6 +173,7 @@ jobs:
|
||||
if: ${{ inputs.type == 'full' }}
|
||||
run: |
|
||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
|
||||
pytest -sv tests/e2e/multicard/test_expert_parallel.py
|
||||
# external_launcher test is not stable enough. Fix it later
|
||||
# pytest -sv tests/e2e/multicard/test_external_launcher.py
|
||||
|
||||
103
tests/e2e/multicard/test_full_graph_mode.py
Normal file
103
tests/e2e/multicard/test_full_graph_mode.py
Normal file
@@ -0,0 +1,103 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
||||
#
|
||||
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
from tests.e2e.model_utils import check_outputs_equal
|
||||
|
||||
|
||||
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
|
||||
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||
prompts = [
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$'
|
||||
'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,'
|
||||
'$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.'
|
||||
'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,'
|
||||
'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.'
|
||||
),
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen'
|
||||
'independently and uniformly at random on the perimeter of $ABCD$.'
|
||||
'If the expected value of the area of triangle $\\triangle AXY$'
|
||||
'can be expressed as $\\frac{m}{n}$, for relatively prime positive'
|
||||
'integers $m$ and $n$, compute $m+n$.'),
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$'
|
||||
'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$'
|
||||
'and $x^2 + cx + b = 0$ also have a common real root.'
|
||||
'Compute the sum $a + b + c$.')
|
||||
]
|
||||
model = "Qwen/Qwen3-30B-A3B"
|
||||
sampling_params = SamplingParams(max_tokens=5,
|
||||
n=1,
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
top_k=1)
|
||||
with VllmRunner(model,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=False,
|
||||
gpu_memory_utilization=0.95,
|
||||
compilation_config={
|
||||
"cudagraph_capture_sizes":
|
||||
[4, 8, 12, 16, 24, 32, 40, 48],
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY"
|
||||
}) as runner:
|
||||
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
||||
sampling_params)
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.95,
|
||||
) as runner:
|
||||
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||
|
||||
vllm_fullgraph_outputs_list = []
|
||||
for output in vllm_fullgraph_outputs:
|
||||
vllm_fullgraph_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
vllm_eager_outputs_list = []
|
||||
for output in vllm_eager_outputs:
|
||||
vllm_eager_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=vllm_eager_outputs_list,
|
||||
outputs_1_lst=vllm_fullgraph_outputs_list,
|
||||
name_0="vllm_eager_outputs",
|
||||
name_1="vllm_fullgraph_outputs",
|
||||
)
|
||||
@@ -405,6 +405,109 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch.npu.graph_task_group_end')
|
||||
@patch('torch.npu.graph_task_group_begin')
|
||||
@patch('torch.npu.ExternalEvent')
|
||||
@patch('torch_npu.npu.current_stream')
|
||||
def test_paged_attention_with_existing_workspace(
|
||||
self,
|
||||
mock_get_forward_context,
|
||||
mock_get_graph_params,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_paged_attention,
|
||||
mock_graph_begin,
|
||||
mock_graph_end,
|
||||
mock_external_event_class,
|
||||
mock_current_stream,
|
||||
):
|
||||
graph_params = MagicMock()
|
||||
attn_metadata = MagicMock()
|
||||
num_tokens = 10
|
||||
|
||||
graph_params.workspaces = {num_tokens: 10}
|
||||
graph_params.events = {num_tokens: []}
|
||||
graph_params.attn_params = {num_tokens: []}
|
||||
graph_params.handles = {num_tokens: []}
|
||||
|
||||
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
|
||||
key_cache = MagicMock()
|
||||
value_cache = MagicMock()
|
||||
num_kv_heads = 4
|
||||
num_heads = 8
|
||||
scale = 0.1
|
||||
output = torch.randn(2, 5, 8)
|
||||
|
||||
self_obj = MagicMock()
|
||||
self_obj.key_cache = key_cache
|
||||
self_obj.value_cache = value_cache
|
||||
self_obj.num_kv_heads = num_kv_heads
|
||||
self_obj.num_heads = num_heads
|
||||
self_obj.scale = scale
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_current_stream.return_value = mock_stream
|
||||
mock_event_instance = MagicMock()
|
||||
mock_external_event_class.return_value = mock_event_instance
|
||||
|
||||
mock_handle = MagicMock()
|
||||
mock_graph_end.return_value = mock_handle
|
||||
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
self.assertEqual(workspace, 10)
|
||||
|
||||
# 2. Handle graph capturing mode
|
||||
stream = mock_current_stream()
|
||||
event = mock_external_event_class()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
query,
|
||||
self_obj.key_cache,
|
||||
self_obj.value_cache,
|
||||
self_obj.num_kv_heads,
|
||||
self_obj.num_heads,
|
||||
self_obj.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
output,
|
||||
))
|
||||
|
||||
mock_event_instance.wait.assert_called_once_with(mock_stream)
|
||||
mock_event_instance.reset.assert_called_once_with(mock_stream)
|
||||
self.assertEqual(len(graph_params.events[num_tokens]), 1)
|
||||
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
|
||||
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
||||
mock_get_graph_params.return_value = graph_params
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
||||
|
||||
@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
@@ -393,13 +394,28 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
query,
|
||||
self.key_cache,
|
||||
@@ -413,6 +429,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
@@ -422,7 +439,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
|
||||
@@ -215,15 +215,17 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=graph_params.workspaces.get(runtime_shape))
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -256,5 +258,11 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
)
|
||||
|
||||
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
_graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
Reference in New Issue
Block a user