NOTE(review): line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy. Context-line indentation, blank context
lines and the blob ids on the `index` lines are best-effort and should be
confirmed against the target tree (e.g. `git apply --check`) before applying.

diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index 1254f3a..c50894d 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -173,6 +173,7 @@ jobs:
         if: ${{ inputs.type == 'full' }}
         run: |
           pytest -sv tests/e2e/multicard/test_data_parallel.py
+          pytest -sv tests/e2e/multicard/test_full_graph_mode.py
           pytest -sv tests/e2e/multicard/test_expert_parallel.py
           # external_launcher test is not stable enough. Fix it later
           # pytest -sv tests/e2e/multicard/test_external_launcher.py
diff --git a/tests/e2e/multicard/test_full_graph_mode.py b/tests/e2e/multicard/test_full_graph_mode.py
new file mode 100644
index 0000000..6105ef7
--- /dev/null
+++ b/tests/e2e/multicard/test_full_graph_mode.py
@@ -0,0 +1,100 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
+#
+"""Compare short greedy outputs of vLLM in eager mode and full-graph mode.
+
+Run `pytest tests/e2e/multicard/test_full_graph_mode.py`.
+"""
+
+from vllm import SamplingParams
+
+from tests.e2e.conftest import VllmRunner
+from tests.e2e.model_utils import check_outputs_equal
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH(monkeypatch):
+    """Check that FULL_DECODE_ONLY graph mode matches eager mode on TP=2."""
+    # Drop HCCL_OP_EXPANSION_MODE for this test only; monkeypatch restores it.
+    monkeypatch.delenv('HCCL_OP_EXPANSION_MODE', raising=False)
+    prompts = [
+        ('Solve the following math problem step by step.'
+         'The last line of your response should be of the form Answer: '
+         '$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
+         'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$'
+         'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,'
+         '$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.'
+         'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,'
+         'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.'
+         ),
+        ('Solve the following math problem step by step.'
+         'The last line of your response should be of the form Answer: '
+         '$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
+         'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen'
+         'independently and uniformly at random on the perimeter of $ABCD$.'
+         'If the expected value of the area of triangle $\\triangle AXY$'
+         'can be expressed as $\\frac{m}{n}$, for relatively prime positive'
+         'integers $m$ and $n$, compute $m+n$.'),
+        ('Solve the following math problem step by step.'
+         'The last line of your response should be of the form Answer: '
+         '$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
+         'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$'
+         'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$'
+         'and $x^2 + cx + b = 0$ also have a common real root.'
+         'Compute the sum $a + b + c$.')
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=5,
+                                     n=1,
+                                     temperature=0.0,
+                                     top_p=1.0,
+                                     top_k=1)
+    with VllmRunner(model,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    enforce_eager=False,
+                    gpu_memory_utilization=0.95,
+                    compilation_config={
+                        "cudagraph_capture_sizes":
+                        [4, 8, 12, 16, 24, 32, 40, 48],
+                        "cudagraph_mode": "FULL_DECODE_ONLY"
+                    }) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=True,
+            gpu_memory_utilization=0.95,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    # Reduce each request to (index, text) of its first completion.
+    vllm_fullgraph_outputs_list = [(output.outputs[0].index,
+                                    output.outputs[0].text)
+                                   for output in vllm_fullgraph_outputs]
+    vllm_eager_outputs_list = [(output.outputs[0].index,
+                                output.outputs[0].text)
+                               for output in vllm_eager_outputs]
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_fullgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_fullgraph_outputs",
+    )
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index d553637..678b0bb 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -405,6 +405,86 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
 
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.attention.attention_v1.get_graph_params')
+    @patch('torch_npu._npu_reshape_and_cache')
+    @patch('torch_npu._npu_paged_attention')
+    @patch('torch.npu.graph_task_group_end')
+    @patch('torch.npu.graph_task_group_begin')
+    @patch('torch.npu.ExternalEvent')
+    @patch('torch_npu.npu.current_stream')
+    def test_paged_attention_with_existing_workspace(
+        self,
+        mock_current_stream,
+        mock_external_event_class,
+        mock_graph_begin,
+        mock_graph_end,
+        mock_paged_attention,
+        mock_npu_reshape_and_cache,
+        mock_get_graph_params,
+        mock_get_forward_context,
+    ):
+        """Capture-mode decode reuses the workspace cached for num_tokens.
+
+        ``@patch`` decorators are applied bottom-up, so the mock arguments
+        are listed in the reverse order of the decorator stack.
+        """
+        num_tokens = 10
+        cached_workspace = 10
+
+        graph_params = MagicMock()
+        graph_params.workspaces = {num_tokens: cached_workspace}
+        graph_params.events = {num_tokens: []}
+        graph_params.attn_params = {num_tokens: []}
+        graph_params.handles = {num_tokens: []}
+        mock_get_graph_params.return_value = graph_params
+        mock_get_forward_context.return_value = MagicMock(capturing=True)
+
+        mock_stream = MagicMock()
+        mock_current_stream.return_value = mock_stream
+        mock_event_instance = MagicMock()
+        mock_external_event_class.return_value = mock_event_instance
+        mock_handle = MagicMock()
+        mock_graph_end.return_value = mock_handle
+
+        query = torch.randn(num_tokens, 8 * 64)
+        key = torch.randn(num_tokens, 8 * 64)
+        value = torch.randn(num_tokens, 8 * 64)
+        kv_cache = torch.empty(2, 5, 128, 8, 64)
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.DecodeOnly
+        metadata.seq_lens = torch.tensor([10])
+        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
+        metadata.num_actual_tokens = num_tokens
+        metadata.slot_mapping = torch.zeros(num_tokens, dtype=torch.long)
+        layer = self.layer_no_quant
+
+        self.impl.forward(layer,
+                          query,
+                          key,
+                          value,
+                          kv_cache,
+                          metadata,
+                          trace_flag=False)
+
+        # A fresh event is waited on and reset on the current stream.
+        mock_event_instance.wait.assert_called_once_with(mock_stream)
+        mock_event_instance.reset.assert_called_once_with(mock_stream)
+        self.assertEqual(graph_params.events[num_tokens],
+                         [mock_event_instance])
+        self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
+
+        # The cached workspace is handed to the kernel unchanged.
+        mock_paged_attention.assert_called_once()
+        self.assertEqual(
+            mock_paged_attention.call_args.kwargs["workspace"],
+            cached_workspace)
+
+        # The handle returned by graph_task_group_end is recorded.
+        mock_graph_begin.assert_called_once_with(mock_stream)
+        mock_graph_end.assert_called_once_with(mock_stream)
+        self.assertEqual(graph_params.handles[num_tokens], [mock_handle])
+
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d289bb4..331e5fa 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
@@ -393,13 +394,28 @@ class AscendAttentionBackendImpl(AttentionImpl):
             forward_context: ForwardContext = get_forward_context()
             num_tokens = query.shape[0]
             if forward_context.capturing:
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_paged_attention_get_workspace(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
+                    update_graph_params_workspaces(num_tokens, workspace)
+
+                # Handle graph capturing mode
                 stream = torch_npu.npu.current_stream()
 
                 event = torch.npu.ExternalEvent()
                 event.wait(stream)
                 event.reset(stream)
                 graph_params.events[num_tokens].append(event)
-
                 graph_params.attn_params[num_tokens].append((
                     query,
                     self.key_cache,
@@ -413,6 +429,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 ))
 
                 torch.npu.graph_task_group_begin(stream)
+
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -422,7 +439,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     scale_value=self.scale,
                     block_table=attn_metadata.block_tables,
                     context_lens=attn_metadata.seq_lens,
-                    out=output)
+                    out=output,
+                    workspace=workspace)
                 handle = torch.npu.graph_task_group_end(stream)
                 graph_params.handles[num_tokens].append(handle)
             else:
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 8a41807..116d382 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -215,15 +215,17 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
 
         with torch.npu.stream(update_stream):
             torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output)
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                num_kv_heads=num_kv_heads,
+                num_heads=num_heads,
+                scale_value=scale,
+                block_table=block_table,
+                context_lens=seq_lens,
+                out=output,
+                workspace=graph_params.workspaces.get(runtime_shape))
             torch.npu.graph_task_update_end(update_stream)
 
             event.record(update_stream)
@@ -256,5 +258,17 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
     )
 
 
+def update_graph_params_workspaces(num_tokens: int, workspace: int) -> None:
+    """Cache ``workspace`` under ``num_tokens`` in the global graph params.
+
+    Does nothing while ``_graph_params`` is still ``None``.
+    """
+    # NOTE(review): ``workspace`` is annotated as ``int`` but stores whatever
+    # ``torch_npu._npu_paged_attention_get_workspace`` returns - confirm type.
+    # Only the mapping is mutated, so no ``global`` statement is needed.
+    if _graph_params is not None:
+        _graph_params.workspaces[num_tokens] = workspace
+
+
 def get_graph_params():
     return _graph_params