From e38ef2c4344e11130a8fae5260b9dad6962d1fd1 Mon Sep 17 00:00:00 2001 From: XiaoxinWang <963372609@qq.com> Date: Mon, 17 Nov 2025 10:50:35 +0800 Subject: [PATCH] support FULL graph mode for GQA (#3970) ### What this PR does / why we need it? The current library only supports the FullDecodeOnly graph mode, which enables full graph execution during the decode. This PR extends support to allow full graph execution in both the prefill and decode, referred to as FULL graph mode. - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 Signed-off-by: wangxiaoxin-sherie Co-authored-by: wangxiaoxin-sherie --- .github/workflows/_e2e_test.yaml | 1 + tests/e2e/multicard/test_full_graph_mode.py | 57 +++- .../spec_decode_v1/test_v1_mtp_correctness.py | 6 +- tests/ut/attention/test_attention_v1.py | 179 +++--------- vllm_ascend/attention/attention_v1.py | 274 +++++++++++------- vllm_ascend/attention/mla_v1.py | 6 +- vllm_ascend/attention/utils.py | 2 + vllm_ascend/compilation/acl_graph.py | 67 ++--- vllm_ascend/platform.py | 6 +- vllm_ascend/torchair/torchair_model_runner.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 24 +- 11 files changed, 328 insertions(+), 296 deletions(-) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 98f91533..be5b43e6 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -180,6 +180,7 @@ jobs: if: ${{ inputs.type == 'full' }} run: | pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py + pytest -sv tests/e2e/multicard/test_full_graph_mode.py pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/test_expert_parallel.py pytest -sv tests/e2e/multicard/test_external_launcher.py diff --git a/tests/e2e/multicard/test_full_graph_mode.py b/tests/e2e/multicard/test_full_graph_mode.py index 3b9f2932..3ccbf823 100644 --- a/tests/e2e/multicard/test_full_graph_mode.py +++ b/tests/e2e/multicard/test_full_graph_mode.py @@ -29,7 +29,7 @@ from tests.e2e.conftest import VllmRunner from tests.e2e.model_utils import check_outputs_equal -def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH(): +def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY(): if 'HCCL_OP_EXPANSION_MODE' in os.environ: del os.environ['HCCL_OP_EXPANSION_MODE'] prompts = [ @@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH(): max_model_len=1024, tensor_parallel_size=2, enforce_eager=False, - compilation_config={"cudagraph_mode": - "FULL_DECODE_ONLY"}) as runner: + compilation_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + "cudagraph_capture_sizes": [4, 8, 24, 48, 60] + }) as runner: vllm_fullgraph_outputs = runner.model.generate(prompts, sampling_params) with VllmRunner( model, max_model_len=1024, - enforce_eager=True, + tensor_parallel_size=2, + enforce_eager=False, + ) as runner: + vllm_eager_outputs = runner.model.generate(prompts, sampling_params) + + vllm_fullgraph_outputs_list = [] + for output in vllm_fullgraph_outputs: + vllm_fullgraph_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + vllm_eager_outputs_list = [] + for output in vllm_eager_outputs: + vllm_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=vllm_fullgraph_outputs_list, + name_0="vllm_eager_outputs", + name_1="vllm_fullgraph_outputs", + ) + + +def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL(): + if 
'HCCL_OP_EXPANSION_MODE' in os.environ: + del os.environ['HCCL_OP_EXPANSION_MODE'] + prompts = [ + "Hello, my name is", "The president of the United States is", + "The capital of France is", "The future of AI is" + ] + model = "Qwen/Qwen3-30B-A3B" + sampling_params = SamplingParams(max_tokens=32, temperature=0.0) + with VllmRunner(model, + max_model_len=1024, + tensor_parallel_size=2, + enforce_eager=False, + compilation_config={ + "cudagraph_mode": "FULL", + "cudagraph_capture_sizes": [4, 8, 24, 48, 60] + }) as runner: + vllm_fullgraph_outputs = runner.model.generate(prompts, + sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + enforce_eager=False, ) as runner: vllm_eager_outputs = runner.model.generate(prompts, sampling_params) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 4dbdefb6..2f56d9d2 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams, graph_mode_str = "PIECEWISE" if graph_mode == CUDAGraphMode.FULL: - graph_mode_str = "FULL" + graph_mode_str = "FULL_DECODE_ONLY" with VllmRunner( model_name, @@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams, enforce_eager=enforce_eager, max_model_len=2000, compilation_config=CompilationConfig( - cudagraph_mode=graph_mode_str), + cudagraph_mode=graph_mode_str, + cudagraph_capture_sizes=[12], + ), additional_config={"ascend_scheduler_config": { "enabled": False }}) as spec_llm: diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index e8fff182..e86c1332 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -286,10 +286,12 @@ class TestAscendAttentionBackendImpl(TestBase): assert output.shape == (10, 8 * 64) + @patch('vllm_ascend.attention.attention_v1.get_forward_context') @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_flash_attention') def test_forward_prefill_no_cache(self, mock_flash_attention, - mock_reshape_cache): + mock_reshape_cache, + mock_get_forward_context): """Test forward pass in PrefillNoCache state""" query = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64) @@ -297,6 +299,8 @@ class TestAscendAttentionBackendImpl(TestBase): kv_cache = torch.empty(2, 5, 128, 8, 64) output = torch.empty_like(query) + mock_get_forward_context.return_value = MagicMock(capturing=False) + metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.PrefillNoCache metadata.attn_mask = torch.randn(1, 1, 10, 10) @@ -316,7 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase): @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu.npu_fused_infer_attention_score') - def test_forward_prefill_cache_hit(self, + @patch('vllm_ascend.attention.attention_v1.get_forward_context') + def test_forward_prefill_cache_hit(self, mock_get_forward_context, mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache): """Test forward pass in PrefillCacheHit state""" @@ -326,8 +331,6 @@ class TestAscendAttentionBackendImpl(TestBase): kv_cache = torch.empty(2, 5, 128, 8, 64) output = torch.empty_like(query) - mock_npu_fused_infer_attention_score.return_value = (output, 1) - metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.PrefillCacheHit metadata.attn_mask = torch.randn(1, 1, 10, 10) @@ 
-340,18 +343,23 @@ class TestAscendAttentionBackendImpl(TestBase): metadata.num_prefills = 10 layer = self.layer_no_quant + mock_get_forward_context.return_value = MagicMock(capturing=False) + mock_npu_fused_infer_attention_score.return_value = (output, + torch.ones( + 10, 8, 64)) + output = self.impl.forward(layer, query, key, value, kv_cache, metadata, output) mock_npu_fused_infer_attention_score.assert_called_once() assert output.shape == (10, 8 * 64) - @patch('vllm_ascend.attention.attention_v1.get_forward_context') - @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_paged_attention') - def test_forward_decode_only(self, mock_paged_attention, + @patch('torch_npu._npu_reshape_and_cache') + @patch('vllm_ascend.attention.attention_v1.get_forward_context') + def test_forward_decode_only(self, mock_get_forward_context, mock_npu_reshape_and_cache, - mock_get_forward_context): + mock_paged_attention): """Test forward pass in DecodeOnly state""" query = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64) @@ -378,115 +386,11 @@ class TestAscendAttentionBackendImpl(TestBase): assert output.shape == (10, 8 * 64) @patch('vllm_ascend.attention.attention_v1.get_forward_context') - @patch('vllm_ascend.attention.attention_v1.get_graph_params') - @patch('torch_npu._npu_reshape_and_cache') - @patch('torch_npu._npu_paged_attention') - @patch('torch.npu.graph_task_group_end') - @patch('torch.npu.graph_task_group_begin') - @patch('torch.npu.ExternalEvent') - @patch('torch_npu.npu.current_stream') - @patch('vllm_ascend.attention.attention_v1.weak_ref_tensors') - def test_paged_attention_with_existing_workspace( - self, - mock_get_forward_context, - mock_get_graph_params, - mock_npu_reshape_and_cache, - mock_paged_attention, - mock_graph_begin, - mock_graph_end, - mock_external_event_class, - mock_current_stream, - mock_weak_ref_tensors, - ): - graph_params = MagicMock() - attn_metadata = MagicMock() - num_tokens = 10 - - graph_params.workspaces = {num_tokens: 10} - graph_params.events = {num_tokens: []} - graph_params.attn_params = {num_tokens: []} - graph_params.handles = {num_tokens: []} - - query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size] - key_cache = MagicMock() - value_cache = MagicMock() - num_kv_heads = 4 - num_heads = 8 - scale = 0.1 - output = torch.randn(2, 5, 8) - - self_obj = MagicMock() - self_obj.key_cache = key_cache - self_obj.value_cache = value_cache - self_obj.num_kv_heads = num_kv_heads - self_obj.num_heads = num_heads - self_obj.scale = scale - - mock_stream = MagicMock() - mock_current_stream.return_value = mock_stream - mock_event_instance = MagicMock() - mock_external_event_class.return_value = mock_event_instance - - mock_handle = MagicMock() - mock_graph_end.return_value = mock_handle - - workspace = graph_params.workspaces.get(num_tokens) - self.assertEqual(workspace, 10) - - weak_ref_tensors = MagicMock(side_effect=lambda x: x) - - # 2. 
Handle graph capturing mode - stream = mock_current_stream() - event = mock_external_event_class() - event.wait(stream) - event.reset(stream) - graph_params.events[num_tokens].append(event) - graph_params.attn_params[num_tokens].append(( - weak_ref_tensors(query), - weak_ref_tensors(self_obj.key_cache), - weak_ref_tensors(self_obj.value_cache), - self_obj.num_kv_heads, - self_obj.num_heads, - self_obj.scale, - weak_ref_tensors(attn_metadata.block_tables), - attn_metadata.seq_lens, - output, - )) - - mock_event_instance.wait.assert_called_once_with(mock_stream) - mock_event_instance.reset.assert_called_once_with(mock_stream) - self.assertEqual(len(graph_params.events[num_tokens]), 1) - self.assertEqual(len(graph_params.attn_params[num_tokens]), 1) - - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - output = torch.empty_like(query) - - metadata = self.attn_metadata - metadata.attn_state = AscendAttentionState.DecodeOnly - metadata.seq_lens = torch.tensor([10]) - metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - metadata.num_decodes = 0 - metadata.num_prefills = 10 - layer = self.layer_no_quant - - mock_get_forward_context.return_value = MagicMock(capturing=True) - mock_get_graph_params.return_value = graph_params - - output = self.impl.forward(layer, query, key, value, kv_cache, - metadata, output) - - mock_paged_attention.assert_called_once() - self.assertEqual(len(graph_params.handles[num_tokens]), 0) - - @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu.npu_fused_infer_attention_score') - def test_forward_decode_only_swa(self, mock_fused_infer_attention_score, - mock_npu_reshape_and_cache): + @patch('torch_npu._npu_reshape_and_cache') + def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache, + mock_fused_infer_attention_score, + mock_get_forward_context): """Test forward pass in DecodeOnly state""" query = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64) @@ -494,6 +398,8 @@ class TestAscendAttentionBackendImpl(TestBase): kv_cache = torch.empty(2, 5, 128, 8, 64) output = torch.empty(10, 8, 64) + mock_get_forward_context.return_value = MagicMock(capturing=False) + metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.DecodeOnly metadata.seq_lens = torch.tensor([10] * 10) @@ -512,12 +418,12 @@ class TestAscendAttentionBackendImpl(TestBase): assert output.shape == (10, 8, 64) @patch('vllm_ascend.attention.attention_v1.get_forward_context') - @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_paged_attention') @patch('torch_npu.npu_fused_infer_attention_score') + @patch('torch_npu._npu_reshape_and_cache') def test_forward_decode_only_swa_seq_len_mismatch( - self, mock_fused_infer_attention_score, mock_paged_attention, - mock_npu_reshape_and_cache, mock_get_forward_context): + self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score, + mock_paged_attention, mock_get_forward_context): """Test forward pass in DecodeOnly state when seq)len_mismatch""" query = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64) @@ -535,11 +441,11 @@ class TestAscendAttentionBackendImpl(TestBase): metadata.num_decodes = 10 metadata.num_prefills = 0 - mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, - 64), 1) - mock_get_forward_context.return_value = MagicMock(capturing=False) + mock_fused_infer_attention_score.return_value = 
(torch.ones(10, 8, 64), + torch.ones(10, 8, 64)) + output = self.impl_swa.forward(layer, query, key, value, kv_cache, metadata, output) @@ -548,11 +454,13 @@ class TestAscendAttentionBackendImpl(TestBase): assert output.shape == (10, 8 * 64) + @patch('vllm_ascend.attention.attention_v1.get_forward_context') @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) @patch('torch_npu._npu_reshape_and_cache') @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill') def test_forward_head_size_192(self, mock_vanilla_prefill, - mock_npu_reshape_and_cache, mock_is_310p): + mock_npu_reshape_and_cache, mock_is_310p, + mock_get_forward_context): """Test forward pass when head_size is 192""" self.impl.head_size = 192 @@ -562,6 +470,8 @@ class TestAscendAttentionBackendImpl(TestBase): kv_cache = torch.empty(2, 5, 128, 8, 192) output = torch.empty_like(query) + mock_get_forward_context.return_value = MagicMock(capturing=False) + metadata = self.attn_metadata metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) @@ -580,11 +490,12 @@ class TestAscendAttentionBackendImpl(TestBase): mock_vanilla_prefill.assert_called_once() assert output.shape == (10, 8 * 192) - @patch('torch_npu._npu_reshape_and_cache') + @patch('vllm_ascend.attention.attention_v1.get_forward_context') @patch('torch_npu.npu_fused_infer_attention_score') - def test_forward_normal_v1_situation(self, + @patch('torch_npu._npu_reshape_and_cache') + def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache, mock_npu_fused_infer_attention_score, - mock_npu_reshape_and_cache): + mock_get_forward_context): """Test forward pass in normal V1 situation""" query = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64) @@ -592,8 +503,6 @@ class TestAscendAttentionBackendImpl(TestBase): kv_cache = torch.empty(2, 5, 128, 8, 64) output = torch.empty_like(query) - mock_npu_fused_infer_attention_score.return_value = (output, 1) - metadata = self.attn_metadata metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) @@ -604,6 +513,10 @@ class TestAscendAttentionBackendImpl(TestBase): metadata.num_decodes = 0 metadata.num_prefills = 10 layer = self.layer_no_quant + mock_get_forward_context.return_value = MagicMock(capturing=False) + mock_npu_fused_infer_attention_score.return_value = (output, + torch.ones( + 10, 8, 64)) output = self.impl.forward(layer, query, key, value, kv_cache, metadata, output) @@ -615,7 +528,8 @@ class TestAscendAttentionBackendImpl(TestBase): @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu.npu_fused_infer_attention_score') @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) - def test_forward_310p_device(self, mock_is_310p, + @patch('vllm_ascend.attention.attention_v1.get_forward_context') + def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p, mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache, mock_npu_format_cast): @@ -626,8 +540,6 @@ class TestAscendAttentionBackendImpl(TestBase): kv_cache = torch.empty(2, 5, 128, 8, 64) output = torch.empty_like(query) - mock_npu_fused_infer_attention_score.return_value = (output, 1) - metadata = self.attn_metadata metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) @@ -641,6 +553,11 @@ class TestAscendAttentionBackendImpl(TestBase): mock_npu_format_cast.return_value = metadata.attn_mask + mock_get_forward_context.return_value = MagicMock(capturing=False) + 
mock_npu_fused_infer_attention_score.return_value = (output, + torch.ones( + 10, 8, 64)) + output = self.impl.forward(layer, query, key, value, kv_cache, metadata, output) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 50f23a7b..32c2dc03 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -195,6 +195,7 @@ class AscendMetadataForDecode: class AscendMetadata: # **************************** Basic Properties ************************** # attn_mask: Optional[torch.Tensor] = None + fia_attn_mask: Optional[torch.Tensor] = None # Current state of this attention run. attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill @@ -215,6 +216,7 @@ class AscendMetadata: seq_lens: torch.Tensor = None seq_lens_list: List[int] = None # type: ignore actual_seq_lengths_q: List[int] = None # type: ignore + query_start_loc_list: List[int] = None # type: ignore query_start_loc: torch.Tensor = None query_lens: torch.Tensor = None @@ -241,7 +243,8 @@ class AscendMetadata: class AscendAttentionMetadataBuilder: # Does this backend/builder support ACL Graphs for attention (default: no). aclgraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + AttentionCGSupport.ALWAYS + # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. @@ -321,6 +324,7 @@ class AscendAttentionMetadataBuilder: num_actual_tokens_pcp_padded] # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] attn_mask = common_attn_metadata.attn_mask + fia_attn_mask = common_attn_metadata.fia_attn_mask attn_state = common_attn_metadata.attn_state query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs @@ -471,6 +475,7 @@ class AscendAttentionMetadataBuilder: num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, block_tables=block_table, query_start_loc=query_start_loc, + query_start_loc_list=query_start_loc_cpu[1:].tolist(), query_lens=query_lens, seq_lens=seq_lens, seq_lens_list=seq_lens.tolist(), @@ -478,6 +483,7 @@ class AscendAttentionMetadataBuilder: actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(), slot_mapping=slot_mapping, attn_mask=attn_mask, + fia_attn_mask=fia_attn_mask, attn_state=attn_state, num_prefills=num_prefills, num_decodes=num_decodes, @@ -565,6 +571,113 @@ class AscendAttentionBackendImpl(AttentionImpl): self.dcp_group = get_dcp_group( ).device_group if self.dcp_size > 1 else None + def full_graph_attention(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + num_tokens=0): + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + block_size = 128 + block_table = None + actual_seq_lengths_kv = attn_metadata.query_start_loc_list + elif attn_metadata.attn_state == \ + AscendAttentionState.PrefillCacheHit: + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] + num_block, block_size, _, _ = self.key_cache.shape # type: ignore + key = self.key_cache.view( # type: ignore + num_block, block_size, -1) + value = self.value_cache.view( # type: ignore + num_block, block_size, -1) + actual_seq_lengths_kv = attn_metadata.seq_lens_list + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + num_block, block_size, _, _ = self.key_cache.shape # 
type: ignore + key = self.key_cache.view( # type: ignore + num_block, block_size, -1) + value = self.value_cache.view( # type: ignore + num_block, block_size, -1) + block_table = attn_metadata.block_tables + actual_seq_lengths_kv = attn_metadata.seq_lens_list + # Normal V1 situation. + else: + num_block, block_size, _, _ = self.key_cache.shape # type: ignore + key = self.key_cache.view( # type: ignore + num_block, block_size, -1) + value = self.value_cache.view( # type: ignore + num_block, block_size, -1) + block_table = attn_metadata.block_tables + actual_seq_lengths_kv = attn_metadata.seq_lens_list + + num_tokens = attn_metadata.query_start_loc_list[-1] + query = query[:num_tokens] + graph_params = get_graph_params() + query_start_loc = attn_metadata.query_start_loc_list + # Prepare tensors for attention output + # TODO: Refactor this to step-level instead of layer-level + + # Get workspace from cache or calculate it if not present. + workspace = graph_params.workspaces.get(num_tokens) + softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device) + if workspace is None: + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + query=query, + key=key, + value=value, + atten_mask=attn_metadata.fia_attn_mask, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=query_start_loc, + actual_seq_lengths_kv=actual_seq_lengths_kv, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + sparse_mode=3, + scale=self.scale, + ) + update_graph_params_workspaces(num_tokens, workspace) + + # Handle graph capturing mode + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + graph_params.attn_params[num_tokens].append( + (weak_ref_tensors(query), weak_ref_tensors(key), + weak_ref_tensors(value), weak_ref_tensors(block_table), + weak_ref_tensors(attn_metadata.fia_attn_mask), block_size, + actual_seq_lengths_kv, query_start_loc, self.num_kv_heads, + self.num_heads, self.scale, weak_ref_tensors(output), + weak_ref_tensors(softmax_lse))) + + torch.npu.graph_task_group_begin(stream) + torch_npu.npu_fused_infer_attention_score.out( + query=query, + key=key, + value=value, + atten_mask=attn_metadata.fia_attn_mask, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=query_start_loc, + actual_seq_lengths_kv=actual_seq_lengths_kv, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + workspace=workspace, + out=[output, softmax_lse], + ) + + output = output.view(num_tokens, self.num_heads, self.head_size) + + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + return output, num_tokens + def _forward_prefill_no_cache( self, query: torch.Tensor, @@ -692,70 +805,16 @@ class AscendAttentionBackendImpl(AttentionImpl): output = output.view(batch_size, self.num_heads, self.head_size) else: - graph_params = get_graph_params() - forward_context: ForwardContext = get_forward_context() - num_tokens = query.shape[0] - if forward_context.capturing: - # Get workspace from cache or calculate it if not present. 
- workspace = graph_params.workspaces.get(num_tokens) - if workspace is None: - workspace = torch_npu._npu_paged_attention_get_workspace( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output) - update_graph_params_workspaces(num_tokens, - weak_ref_tensors(workspace)) - - # Handle graph capturing mode - stream = torch_npu.npu.current_stream() - - event = torch.npu.ExternalEvent() - event.wait(stream) - event.reset(stream) - graph_params.events[num_tokens].append(event) - graph_params.attn_params[num_tokens].append(( - weak_ref_tensors(query), - weak_ref_tensors(self.key_cache), - weak_ref_tensors(self.value_cache), - self.num_kv_heads, - self.num_heads, - self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens, - weak_ref_tensors(output), - )) - - torch.npu.graph_task_group_begin(stream) - torch_npu._npu_paged_attention( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output, - workspace=workspace) - handle = torch.npu.graph_task_group_end(stream) - graph_params.handles[num_tokens].append(handle) - else: - torch_npu._npu_paged_attention( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output) + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output) return output def _forward_v1_style( @@ -819,7 +878,6 @@ class AscendAttentionBackendImpl(AttentionImpl): scale=self.scale, sparse_mode=3, ) - return output def _attention_with_nomask_and_mask(self, q: torch.Tensor, @@ -1481,47 +1539,51 @@ class AscendAttentionBackendImpl(AttentionImpl): num_decode_tokens:attn_metadata. num_actual_tokens_pcp_padded]) - if self.pcp_size * self.dcp_size > 1: - intermediate_output = self._forward_pcp_dcp( - query, key, value, kv_cache, attn_metadata, output) - elif attn_type == AttentionType.ENCODER_ONLY: - # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly. - cum_seq_len = attn_metadata.query_start_loc[1:].tolist() - intermediate_output = torch_npu.npu_fusion_attention( - query, - key, - value, - head_num=self.num_heads, - input_layout="TND", - scale=self.scale, - sparse_mode=4, - atten_mask=attn_metadata.attn_mask, - pre_tockens=attn_metadata.max_query_len, - next_tockens=attn_metadata.max_query_len, - actual_seq_qlen=cum_seq_len, - actual_seq_kvlen=cum_seq_len, - )[0] - # V0-Style scheduler situation. 
- elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - intermediate_output = self._forward_prefill_no_cache( - query, key, value, attn_metadata, output, num_tokens) - elif attn_metadata.attn_state == \ - AscendAttentionState.PrefillCacheHit: - intermediate_output = self._forward_prefill_cache_hit( - query, attn_metadata, output) - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - intermediate_output = self._forward_decode_only( - query, attn_metadata, output) - # Normal V1 situation. + forward_context: ForwardContext = get_forward_context() + if not forward_context.capturing: + if self.pcp_size * self.dcp_size > 1: + intermediate_output = self._forward_pcp_dcp( + query, key, value, kv_cache, attn_metadata, output) + elif attn_type == AttentionType.ENCODER_ONLY: + # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly. + cum_seq_len = attn_metadata.query_start_loc[1:].tolist() + intermediate_output = torch_npu.npu_fusion_attention( + query, + key, + value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=4, + atten_mask=attn_metadata.attn_mask, + pre_tockens=attn_metadata.max_query_len, + next_tockens=attn_metadata.max_query_len, + actual_seq_qlen=cum_seq_len, + actual_seq_kvlen=cum_seq_len, + )[0] + # V0-Style scheduler situation. + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + intermediate_output = self._forward_prefill_no_cache( + query, key, value, attn_metadata, output, num_tokens) + elif attn_metadata.attn_state == \ + AscendAttentionState.PrefillCacheHit: + intermediate_output = self._forward_prefill_cache_hit( + query, attn_metadata, output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + intermediate_output = self._forward_decode_only( + query, attn_metadata, output) + # Normal V1 situation. + else: + # npu_fused_infer_attention_score does not support cases + # where query.shape[0] != attn_metadata.query_start_loc[-1]. + # Thus we need unpad it here. + num_tokens = attn_metadata.query_start_loc[-1] + query = query[:num_tokens] + intermediate_output = self._forward_v1_style( + query, attn_metadata, output) else: - # npu_fused_infer_attention_score does not support cases - # where query.shape[0] != attn_metadata.query_start_loc[-1]. - # Thus we need unpad it here. 
- num_tokens = attn_metadata.query_start_loc[-1] - query = query[:num_tokens] - intermediate_output = self._forward_v1_style( - query, attn_metadata, output) - + intermediate_output, num_tokens = self.full_graph_attention( + query, key, value, attn_metadata, output) output[:num_tokens] = intermediate_output[:num_tokens] return output diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 0650f3e3..62ed95c5 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1278,8 +1278,7 @@ class AscendMLAImpl(MLAAttentionImpl): if workspace is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( q_nope, k_nope, k_nope, **common_kwargs) - update_graph_params_workspaces(num_tokens, - weak_ref_tensors(workspace)) + update_graph_params_workspaces(num_tokens, workspace) attn_output = torch.empty_like(q_nope) softmax_lse = torch.empty(num_tokens, @@ -1779,8 +1778,7 @@ class AscendMLAImpl(MLAAttentionImpl): q_nope, q_pe, k_nope, k_pe, decode_meta.block_table, seq_len, num_heads, self.scale, self.num_kv_heads, **common_kwargs) - update_graph_params_workspaces(num_tokens, - weak_ref_tensors(workspace)) + update_graph_params_workspaces(num_tokens, workspace) attn_output = torch.empty_like(q_nope) softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index a2f71de7..e929dacc 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -88,6 +88,8 @@ class AscendCommonAttentionMetadata: attn_mask: torch.Tensor = None + fia_attn_mask: torch.Tensor = None + spec_attn_mask: torch.Tensor = None attn_state: Any = None diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 8e72ebf0..3cb0613f 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape): graph_params.handles[runtime_shape], graph_params.events[runtime_shape], ): - ( - query, - key_cache, - value_cache, - num_kv_heads, - num_heads, - scale, - block_table, - seq_lens, - output, - ) = param - seq_lens = forward_context.attn_metadata[key].seq_lens + (query, key_cache, value, block_tables, attn_mask, block_size, + seq_lens, query_start_loc, num_kv_heads, num_heads, scale, + attn_output, softmax_lse) = param - # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY - # mode with GQA. This is triggered by getting workspace for _npu_paged_attention - # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens - # might encounter a bigger workspace, while currently we use max_model_len to - # calculate max workspace in capturing. So additional get_workspace is added - # here to avoid such bugs. - # TODO(Angazenn): we will remove this once _npu_paged_attention is fully - # replaced by npu_fused_infer_attention_score which does not contain such bugs. 
- workspace = torch_npu._npu_paged_attention_get_workspace( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output) + seq_lens = forward_context.attn_metadata[key].seq_lens_list + query_start_loc = forward_context.attn_metadata[ + key].query_start_loc_list torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention(query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - workspace=workspace) + torch_npu.npu_fused_infer_attention_score.out( + query=query, + key=key_cache, + value=value, + block_table=block_tables, + atten_mask=attn_mask, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=query_start_loc, + actual_seq_lengths_kv=seq_lens, + num_key_value_heads=num_kv_heads, + num_heads=num_heads, + scale=scale, + sparse_mode=3, + workspace=graph_params.workspaces.get(runtime_shape), + out=[attn_output, softmax_lse], + ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) @@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]): ) -def update_graph_params_workspaces(num_tokens: int, workspace: Any): +def update_graph_params_workspaces(num_tokens: int, workspace: int): global _graph_params if _graph_params is not None: - _graph_params.workspaces[num_tokens] = workspace + _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace) def get_graph_params(): diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 1c2a3391..faed5aea 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -233,7 +233,8 @@ class NPUPlatform(Platform): "vllm.mla_forward" ]) update_aclgraph_sizes(vllm_config) - elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ + compilation_config.cudagraph_mode == CUDAGraphMode.FULL: logger.info( "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " "using only ACL Graph mode") @@ -270,7 +271,8 @@ class NPUPlatform(Platform): compilation_config.use_inductor = False compilation_config.splitting_ops.extend(["vllm::mla_forward"]) update_aclgraph_sizes(vllm_config) - elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ + compilation_config.cudagraph_mode == CUDAGraphMode.FULL: logger.info( "FULL_DECODE_ONLY compilation enabled on NPU. 
use_inductor not supported - " "using only ACL Graph mode") diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 992ea103..792972f0 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -21,6 +21,7 @@ import math import types from typing import Any, Optional +import numpy as np import torch import torch.distributed as dist import torch.nn as nn @@ -31,7 +32,6 @@ from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.logger import logger -import numpy as np import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.platform import NPUPlatform diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 29720464..3c9fc126 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.attn_groups: list[list[AttentionGroup]] = [] self.encoder_cache: Dict[str, torch.Tensor] = {} self.attn_mask = None + self.fia_attn_mask = None self.attn_state = None self.requests: Dict[str, CachedRequestState] = {} self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: return None + def _make_fia_attention_mask(self) -> torch.Tensor: + if self.attn_mask_builder is None: + raise ValueError("Attn mask builder is None") + return self.attn_mask_builder.get_splitfuse_attn_mask() + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 for index, req_id in enumerate(self.input_batch.req_ids): @@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state) + self.fia_attn_mask = self._make_fia_attention_mask() self.attn_state = attn_state # type: ignore self.with_prefill = with_prefill @@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions, attn_mask=self.attn_mask, + fia_attn_mask=self.fia_attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, is_only_prefill=bool(np.all(num_valid_tokens != 1)), @@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin): cu_num_tokens, arange = self._get_cumsum_and_arange( num_scheduled_tokens) - query_start_loc_tensor = torch.Tensor(cu_num_tokens).to( - self.device).to(torch.int32) - self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor + + self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens) self.query_start_loc_cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens) self.query_lens = torch.from_numpy(num_scheduled_tokens) + assigned_mask_dim = 2048 + self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim, + assigned_mask_dim), + diagonal=1).to(torch.int8).to( + self.device) + num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) @@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions, attn_mask=self.attn_mask, + fia_attn_mask=self.fia_attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, max_query_len=max_query_len, @@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): graph_support = None if 
hasattr(builder, 'aclgraph_support'):
                 graph_support = builder.aclgraph_support.value
+                builder_aclgraph = builder.aclgraph_support
             else:
                 graph_support = builder.cudagraph_support.value
+                builder_aclgraph = builder.cudagraph_support
             if graph_support < min_ag_support.value:
-                min_ag_support = builder.aclgraph_support
+                min_ag_support = builder_aclgraph
                 min_ag_builder_name = builder.__class__.__name__
 
         # This is an imitation of compilation_config.splitting_ops_contain_attention()
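
Usage note (not part of the patch): the FULL graph mode introduced here is enabled through vLLM's `compilation_config`, exactly as the new e2e test `test_models_distributed_Qwen3_MOE_TP2_WITH_FULL` does. The sketch below assumes the standard vLLM v0.11 `LLM` entry point behaves like the test's `VllmRunner` wrapper; the model name, tensor-parallel size, and capture sizes simply mirror the illustrative values used in the test.

```python
# Minimal sketch: enable FULL graph mode (full-graph capture for both prefill
# and decode) on Ascend NPU. Values below are illustrative, taken from the
# e2e test added in this patch.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=2,
    max_model_len=1024,
    enforce_eager=False,  # graph capture requires eager-only mode to be off
    compilation_config={
        "cudagraph_mode": "FULL",                 # was FULL_DECODE_ONLY before this PR
        "cudagraph_capture_sizes": [4, 8, 24, 48, 60],
    },
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```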