diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index 98f91533..be5b43e6 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -180,6 +180,7 @@ jobs:
         if: ${{ inputs.type == 'full' }}
         run: |
           pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
+          pytest -sv tests/e2e/multicard/test_full_graph_mode.py
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/test_expert_parallel.py
           pytest -sv tests/e2e/multicard/test_external_launcher.py
diff --git a/tests/e2e/multicard/test_full_graph_mode.py b/tests/e2e/multicard/test_full_graph_mode.py
index 3b9f2932..3ccbf823 100644
--- a/tests/e2e/multicard/test_full_graph_mode.py
+++ b/tests/e2e/multicard/test_full_graph_mode.py
@@ -29,7 +29,7 @@ from tests.e2e.conftest import VllmRunner
 from tests.e2e.model_utils import check_outputs_equal
 
 
-def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
     if 'HCCL_OP_EXPANSION_MODE' in os.environ:
         del os.environ['HCCL_OP_EXPANSION_MODE']
     prompts = [
@@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
                     max_model_len=1024,
                     tensor_parallel_size=2,
                     enforce_eager=False,
-                    compilation_config={"cudagraph_mode":
-                                        "FULL_DECODE_ONLY"}) as runner:
+                    compilation_config={
+                        "cudagraph_mode": "FULL_DECODE_ONLY",
+                        "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+                    }) as runner:
         vllm_fullgraph_outputs = runner.model.generate(prompts,
                                                        sampling_params)
 
     with VllmRunner(
             model,
             max_model_len=1024,
-            enforce_eager=True,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    vllm_fullgraph_outputs_list = []
+    for output in vllm_fullgraph_outputs:
+        vllm_fullgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_fullgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_fullgraph_outputs",
+    )
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+    with VllmRunner(model,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    enforce_eager=False,
+                    compilation_config={
+                        "cudagraph_mode": "FULL",
+                        "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+                    }) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=False,
     ) as runner:
         vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
index 4dbdefb6..2f56d9d2 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
@@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams,
 
     graph_mode_str = "PIECEWISE"
     if graph_mode == CUDAGraphMode.FULL:
-        graph_mode_str = "FULL"
+        graph_mode_str = "FULL_DECODE_ONLY"
 
     with VllmRunner(
             model_name,
@@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams,
                     enforce_eager=enforce_eager,
                     max_model_len=2000,
                     compilation_config=CompilationConfig(
-                        cudagraph_mode=graph_mode_str),
+                        cudagraph_mode=graph_mode_str,
+                        cudagraph_capture_sizes=[12],
+                    ),
                     additional_config={"ascend_scheduler_config": {
                         "enabled": False
                     }}) as spec_llm:
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index e8fff182..e86c1332 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -286,10 +286,12 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_flash_attention')
     def test_forward_prefill_no_cache(self, mock_flash_attention,
-                                      mock_reshape_cache):
+                                      mock_reshape_cache,
+                                      mock_get_forward_context):
         """Test forward pass in PrefillNoCache state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -297,6 +299,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillNoCache
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -316,7 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
 
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
-    def test_forward_prefill_cache_hit(self,
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    def test_forward_prefill_cache_hit(self, mock_get_forward_context,
                                        mock_npu_fused_infer_attention_score,
                                        mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
@@ -326,8 +331,6 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
-        mock_npu_fused_infer_attention_score.return_value = (output, 1)
-
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -340,18 +343,23 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_prefills = 10
         layer = self.layer_no_quant
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_fused_infer_attention_score.return_value = (output,
+                                                             torch.ones(
+                                                                 10, 8, 64))
+
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
-    def test_forward_decode_only(self, mock_paged_attention,
+    @patch('torch_npu._npu_reshape_and_cache')
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    def test_forward_decode_only(self, mock_get_forward_context,
                                  mock_npu_reshape_and_cache,
-                                 mock_get_forward_context):
+                                 mock_paged_attention):
         """Test forward pass in DecodeOnly state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -378,115 +386,11 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.attention.attention_v1.get_graph_params')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu._npu_paged_attention')
-    @patch('torch.npu.graph_task_group_end')
-    @patch('torch.npu.graph_task_group_begin')
-    @patch('torch.npu.ExternalEvent')
-    @patch('torch_npu.npu.current_stream')
-    @patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
-    def test_paged_attention_with_existing_workspace(
-        self,
-        mock_get_forward_context,
-        mock_get_graph_params,
-        mock_npu_reshape_and_cache,
-        mock_paged_attention,
-        mock_graph_begin,
-        mock_graph_end,
-        mock_external_event_class,
-        mock_current_stream,
-        mock_weak_ref_tensors,
-    ):
-        graph_params = MagicMock()
-        attn_metadata = MagicMock()
-        num_tokens = 10
-
-        graph_params.workspaces = {num_tokens: 10}
-        graph_params.events = {num_tokens: []}
-        graph_params.attn_params = {num_tokens: []}
-        graph_params.handles = {num_tokens: []}
-
-        query = torch.randn(2, 5, 8)  # [batch_size, seq_len, hidden_size]
-        key_cache = MagicMock()
-        value_cache = MagicMock()
-        num_kv_heads = 4
-        num_heads = 8
-        scale = 0.1
-        output = torch.randn(2, 5, 8)
-
-        self_obj = MagicMock()
-        self_obj.key_cache = key_cache
-        self_obj.value_cache = value_cache
-        self_obj.num_kv_heads = num_kv_heads
-        self_obj.num_heads = num_heads
-        self_obj.scale = scale
-
-        mock_stream = MagicMock()
-        mock_current_stream.return_value = mock_stream
-        mock_event_instance = MagicMock()
-        mock_external_event_class.return_value = mock_event_instance
-
-        mock_handle = MagicMock()
-        mock_graph_end.return_value = mock_handle
-
-        workspace = graph_params.workspaces.get(num_tokens)
-        self.assertEqual(workspace, 10)
-
-        weak_ref_tensors = MagicMock(side_effect=lambda x: x)
-
-        # 2. Handle graph capturing mode
-        stream = mock_current_stream()
-        event = mock_external_event_class()
-        event.wait(stream)
-        event.reset(stream)
-        graph_params.events[num_tokens].append(event)
-        graph_params.attn_params[num_tokens].append((
-            weak_ref_tensors(query),
-            weak_ref_tensors(self_obj.key_cache),
-            weak_ref_tensors(self_obj.value_cache),
-            self_obj.num_kv_heads,
-            self_obj.num_heads,
-            self_obj.scale,
-            weak_ref_tensors(attn_metadata.block_tables),
-            attn_metadata.seq_lens,
-            output,
-        ))
-
-        mock_event_instance.wait.assert_called_once_with(mock_stream)
-        mock_event_instance.reset.assert_called_once_with(mock_stream)
-        self.assertEqual(len(graph_params.events[num_tokens]), 1)
-        self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
-
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_state = AscendAttentionState.DecodeOnly
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_get_forward_context.return_value = MagicMock(capturing=True)
-        mock_get_graph_params.return_value = graph_params
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_paged_attention.assert_called_once()
-        self.assertEqual(len(graph_params.handles[num_tokens]), 0)
-
-    @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
-    def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
-                                     mock_npu_reshape_and_cache):
+    @patch('torch_npu._npu_reshape_and_cache')
+    def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
+                                     mock_fused_infer_attention_score,
+                                     mock_get_forward_context):
         """Test forward pass in DecodeOnly state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -494,6 +398,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty(10, 8, 64)
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
         metadata.seq_lens = torch.tensor([10] * 10)
@@ -512,12 +418,12 @@ class TestAscendAttentionBackendImpl(TestBase):
         assert output.shape == (10, 8, 64)
 
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu.npu_fused_infer_attention_score')
+    @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_decode_only_swa_seq_len_mismatch(
-            self, mock_fused_infer_attention_score, mock_paged_attention,
-            mock_npu_reshape_and_cache, mock_get_forward_context):
+            self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
+            mock_paged_attention, mock_get_forward_context):
         """Test forward pass in DecodeOnly state when seq)len_mismatch"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -535,11 +441,11 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_decodes = 10
         metadata.num_prefills = 0
 
-        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
-                                                                     64), 1)
-
         mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64),
+                                                         torch.ones(10, 8, 64))
+
         output = self.impl_swa.forward(layer, query, key, value, kv_cache,
                                        metadata, output)
@@ -548,11 +454,13 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
     def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache, mock_is_310p):
+                                   mock_npu_reshape_and_cache, mock_is_310p,
+                                   mock_get_forward_context):
         """Test forward pass when head_size is 192"""
         self.impl.head_size = 192
 
@@ -562,6 +470,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 192)
         output = torch.empty_like(query)
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -580,11 +490,12 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)
 
-    @patch('torch_npu._npu_reshape_and_cache')
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
-    def test_forward_normal_v1_situation(self,
+    @patch('torch_npu._npu_reshape_and_cache')
+    def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
                                          mock_npu_fused_infer_attention_score,
-                                         mock_npu_reshape_and_cache):
+                                         mock_get_forward_context):
         """Test forward pass in normal V1 situation"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -592,8 +503,6 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
-        mock_npu_fused_infer_attention_score.return_value = (output, 1)
-
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -604,6 +513,10 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_decodes = 0
         metadata.num_prefills = 10
         layer = self.layer_no_quant
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_fused_infer_attention_score.return_value = (output,
+                                                             torch.ones(
+                                                                 10, 8, 64))
 
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
@@ -615,7 +528,8 @@ class TestAscendAttentionBackendImpl(TestBase):
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
-    def test_forward_310p_device(self, mock_is_310p,
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
                                  mock_npu_fused_infer_attention_score,
                                  mock_npu_reshape_and_cache,
                                  mock_npu_format_cast):
@@ -626,8 +540,6 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
-        mock_npu_fused_infer_attention_score.return_value = (output, 1)
-
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -641,6 +553,11 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         mock_npu_format_cast.return_value = metadata.attn_mask
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_fused_infer_attention_score.return_value = (output,
+                                                             torch.ones(
+                                                                 10, 8, 64))
+
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 50f23a7b..32c2dc03 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -195,6 +195,7 @@ class AscendMetadataForDecode:
 class AscendMetadata:
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
+    fia_attn_mask: Optional[torch.Tensor] = None
 
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -215,6 +216,7 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
     seq_lens_list: List[int] = None  # type: ignore
     actual_seq_lengths_q: List[int] = None  # type: ignore
+    query_start_loc_list: List[int] = None  # type: ignore
 
     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
@@ -241,7 +243,8 @@ class AscendAttentionMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+        AttentionCGSupport.ALWAYS
+    # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
@@ -321,6 +324,7 @@ class AscendAttentionMetadataBuilder:
                                                            num_actual_tokens_pcp_padded]
         # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
+        fia_attn_mask = common_attn_metadata.fia_attn_mask
         attn_state = common_attn_metadata.attn_state
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
@@ -471,6 +475,7 @@ class AscendAttentionMetadataBuilder:
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
+            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
@@ -478,6 +483,7 @@ class AscendAttentionMetadataBuilder:
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
+            fia_attn_mask=fia_attn_mask,
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
@@ -565,6 +571,113 @@ class AscendAttentionBackendImpl(AttentionImpl):
             self.dcp_group = get_dcp_group(
             ).device_group if self.dcp_size > 1 else None
 
+    def full_graph_attention(self,
+                             query: torch.Tensor,
+                             key: torch.Tensor,
+                             value: torch.Tensor,
+                             attn_metadata: AscendMetadata,
+                             output: torch.Tensor,
+                             num_tokens=0):
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            block_size = 128
+            block_table = None
+            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        # Normal V1 situation.
+        else:
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+
+        num_tokens = attn_metadata.query_start_loc_list[-1]
+        query = query[:num_tokens]
+        graph_params = get_graph_params()
+        query_start_loc = attn_metadata.query_start_loc_list
+        # Prepare tensors for attention output
+        # TODO: Refactor this to step-level instead of layer-level
+
+        # Get workspace from cache or calculate it if not present.
+        workspace = graph_params.workspaces.get(num_tokens)
+        softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
+        if workspace is None:
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.fia_attn_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                sparse_mode=3,
+                scale=self.scale,
+            )
+            update_graph_params_workspaces(num_tokens, workspace)
+
+        # Handle graph capturing mode
+        stream = torch_npu.npu.current_stream()
+
+        event = torch.npu.ExternalEvent()
+        event.wait(stream)
+        event.reset(stream)
+        graph_params.events[num_tokens].append(event)
+        graph_params.attn_params[num_tokens].append(
+            (weak_ref_tensors(query), weak_ref_tensors(key),
+             weak_ref_tensors(value), weak_ref_tensors(block_table),
+             weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
+             actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
+             self.num_heads, self.scale, weak_ref_tensors(output),
+             weak_ref_tensors(softmax_lse)))
+
+        torch.npu.graph_task_group_begin(stream)
+        torch_npu.npu_fused_infer_attention_score.out(
+            query=query,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.fia_attn_mask,
+            block_table=block_table,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+            sparse_mode=3,
+            workspace=workspace,
+            out=[output, softmax_lse],
+        )
+
+        output = output.view(num_tokens, self.num_heads, self.head_size)
+
+        handle = torch.npu.graph_task_group_end(stream)
+        graph_params.handles[num_tokens].append(handle)
+        return output, num_tokens
+
     def _forward_prefill_no_cache(
         self,
         query: torch.Tensor,
@@ -692,70 +805,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
             output = output.view(batch_size, self.num_heads, self.head_size)
         else:
-            graph_params = get_graph_params()
-
-            forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
-            if forward_context.capturing:
-                # Get workspace from cache or calculate it if not present.
-                workspace = graph_params.workspaces.get(num_tokens)
-                if workspace is None:
-                    workspace = torch_npu._npu_paged_attention_get_workspace(
-                        query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
-                        block_table=attn_metadata.block_tables,
-                        context_lens=attn_metadata.seq_lens,
-                        out=output)
-                    update_graph_params_workspaces(num_tokens,
-                                                   weak_ref_tensors(workspace))
-
-                # Handle graph capturing mode
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    attn_metadata.block_tables,
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output,
-                    workspace=workspace)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
-            else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=attn_metadata.block_tables,
+                context_lens=attn_metadata.seq_lens,
+                out=output)
         return output
 
     def _forward_v1_style(
         self,
@@ -819,7 +878,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
             scale=self.scale,
             sparse_mode=3,
         )
-
         return output
 
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
@@ -1481,47 +1539,51 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                          num_decode_tokens:attn_metadata.
                                          num_actual_tokens_pcp_padded])
 
-        if self.pcp_size * self.dcp_size > 1:
-            intermediate_output = self._forward_pcp_dcp(
-                query, key, value, kv_cache, attn_metadata, output)
-        elif attn_type == AttentionType.ENCODER_ONLY:
-            # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
-            cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
-            intermediate_output = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                head_num=self.num_heads,
-                input_layout="TND",
-                scale=self.scale,
-                sparse_mode=4,
-                atten_mask=attn_metadata.attn_mask,
-                pre_tockens=attn_metadata.max_query_len,
-                next_tockens=attn_metadata.max_query_len,
-                actual_seq_qlen=cum_seq_len,
-                actual_seq_kvlen=cum_seq_len,
-            )[0]
-        # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            intermediate_output = self._forward_prefill_no_cache(
-                query, key, value, attn_metadata, output, num_tokens)
-        elif attn_metadata.attn_state == \
-                AscendAttentionState.PrefillCacheHit:
-            intermediate_output = self._forward_prefill_cache_hit(
-                query, attn_metadata, output)
-        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            intermediate_output = self._forward_decode_only(
-                query, attn_metadata, output)
-        # Normal V1 situation.
+        forward_context: ForwardContext = get_forward_context()
+        if not forward_context.capturing:
+            if self.pcp_size * self.dcp_size > 1:
+                intermediate_output = self._forward_pcp_dcp(
+                    query, key, value, kv_cache, attn_metadata, output)
+            elif attn_type == AttentionType.ENCODER_ONLY:
+                # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
+                cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
+                intermediate_output = torch_npu.npu_fusion_attention(
+                    query,
+                    key,
+                    value,
+                    head_num=self.num_heads,
+                    input_layout="TND",
+                    scale=self.scale,
+                    sparse_mode=4,
+                    atten_mask=attn_metadata.attn_mask,
+                    pre_tockens=attn_metadata.max_query_len,
+                    next_tockens=attn_metadata.max_query_len,
+                    actual_seq_qlen=cum_seq_len,
+                    actual_seq_kvlen=cum_seq_len,
+                )[0]
+            # V0-Style scheduler situation.
+            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                intermediate_output = self._forward_prefill_no_cache(
+                    query, key, value, attn_metadata, output, num_tokens)
+            elif attn_metadata.attn_state == \
+                    AscendAttentionState.PrefillCacheHit:
+                intermediate_output = self._forward_prefill_cache_hit(
+                    query, attn_metadata, output)
+            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                intermediate_output = self._forward_decode_only(
+                    query, attn_metadata, output)
+            # Normal V1 situation.
+            else:
+                # npu_fused_infer_attention_score does not support cases
+                # where query.shape[0] != attn_metadata.query_start_loc[-1].
+                # Thus we need unpad it here.
+                num_tokens = attn_metadata.query_start_loc[-1]
+                query = query[:num_tokens]
+                intermediate_output = self._forward_v1_style(
+                    query, attn_metadata, output)
         else:
-            # npu_fused_infer_attention_score does not support cases
-            # where query.shape[0] != attn_metadata.query_start_loc[-1].
-            # Thus we need unpad it here.
-            num_tokens = attn_metadata.query_start_loc[-1]
-            query = query[:num_tokens]
-            intermediate_output = self._forward_v1_style(
-                query, attn_metadata, output)
-
+            intermediate_output, num_tokens = self.full_graph_attention(
+                query, key, value, attn_metadata, output)
         output[:num_tokens] = intermediate_output[:num_tokens]
 
         return output
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 0650f3e3..62ed95c5 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1278,8 +1278,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     q_nope, k_nope, k_nope, **common_kwargs)
-                update_graph_params_workspaces(num_tokens,
-                                               weak_ref_tensors(workspace))
+                update_graph_params_workspaces(num_tokens, workspace)
 
             attn_output = torch.empty_like(q_nope)
             softmax_lse = torch.empty(num_tokens,
@@ -1779,8 +1778,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                     q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
                     seq_len, num_heads, self.scale, self.num_kv_heads,
                     **common_kwargs)
-                update_graph_params_workspaces(num_tokens,
-                                               weak_ref_tensors(workspace))
+                update_graph_params_workspaces(num_tokens, workspace)
 
             attn_output = torch.empty_like(q_nope)
             softmax_lse = torch.empty((num_tokens, num_heads, 1),
                                       dtype=q_nope.dtype,
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index a2f71de7..e929dacc 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -88,6 +88,8 @@ class AscendCommonAttentionMetadata:
 
     attn_mask: torch.Tensor = None
 
+    fia_attn_mask: torch.Tensor = None
+
     spec_attn_mask: torch.Tensor = None
 
     attn_state: Any = None
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 8e72ebf0..3cb0613f 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             graph_params.handles[runtime_shape],
             graph_params.events[runtime_shape],
     ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
+        (query, key_cache, value, block_tables, attn_mask, block_size,
+         seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
+         attn_output, softmax_lse) = param
 
-        # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
-        # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
-        # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
-        # might encounter a bigger workspace, while currently we use max_model_len to
-        # calculate max workspace in capturing. So additional get_workspace is added
-        # here to avoid such bugs.
-        # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
-        # replaced by npu_fused_infer_attention_score which does not contain such bugs.
-        workspace = torch_npu._npu_paged_attention_get_workspace(
-            query=query,
-            key_cache=key_cache,
-            value_cache=value_cache,
-            num_kv_heads=num_kv_heads,
-            num_heads=num_heads,
-            scale_value=scale,
-            block_table=block_table,
-            context_lens=seq_lens,
-            out=output)
+        seq_lens = forward_context.attn_metadata[key].seq_lens_list
+        query_start_loc = forward_context.attn_metadata[
+            key].query_start_loc_list
 
         torch.npu.graph_task_update_begin(update_stream, handle)
-        torch_npu._npu_paged_attention(query=query,
-                                       key_cache=key_cache,
-                                       value_cache=value_cache,
-                                       num_kv_heads=num_kv_heads,
-                                       num_heads=num_heads,
-                                       scale_value=scale,
-                                       block_table=block_table,
-                                       context_lens=seq_lens,
-                                       out=output,
-                                       workspace=workspace)
+        torch_npu.npu_fused_infer_attention_score.out(
+            query=query,
+            key=key_cache,
+            value=value,
+            block_table=block_tables,
+            atten_mask=attn_mask,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths_kv=seq_lens,
+            num_key_value_heads=num_kv_heads,
+            num_heads=num_heads,
+            scale=scale,
+            sparse_mode=3,
+            workspace=graph_params.workspaces.get(runtime_shape),
+            out=[attn_output, softmax_lse],
+        )
         torch.npu.graph_task_update_end(update_stream)
 
         event.record(update_stream)
@@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
     )
 
 
-def update_graph_params_workspaces(num_tokens: int, workspace: Any):
+def update_graph_params_workspaces(num_tokens: int, workspace: int):
     global _graph_params
     if _graph_params is not None:
-        _graph_params.workspaces[num_tokens] = workspace
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
 
 
 def get_graph_params():
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 1c2a3391..faed5aea 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -233,7 +233,8 @@ class NPUPlatform(Platform):
                     "vllm.mla_forward"
                 ])
             update_aclgraph_sizes(vllm_config)
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
+             compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
             logger.info(
                 "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
@@ -270,7 +271,8 @@ class NPUPlatform(Platform):
             compilation_config.use_inductor = False
             compilation_config.splitting_ops.extend(["vllm::mla_forward"])
             update_aclgraph_sizes(vllm_config)
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
+             compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
             logger.info(
                 "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index 992ea103..792972f0 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -21,6 +21,7 @@ import math
 import types
 from typing import Any, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -31,7 +32,6 @@ from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 
-import numpy as np
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 29720464..3c9fc126 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_groups: list[list[AttentionGroup]] = []
         self.encoder_cache: Dict[str, torch.Tensor] = {}
         self.attn_mask = None
+        self.fia_attn_mask = None
         self.attn_state = None
         self.requests: Dict[str, CachedRequestState] = {}
         self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             return None
 
+    def _make_fia_attention_mask(self) -> torch.Tensor:
+        if self.attn_mask_builder is None:
+            raise ValueError("Attn mask builder is None")
+        return self.attn_mask_builder.get_splitfuse_attn_mask()
+
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
         for index, req_id in enumerate(self.input_batch.req_ids):
@@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
+        self.fia_attn_mask = self._make_fia_attention_mask()
         self.attn_state = attn_state  # type: ignore
 
         self.with_prefill = with_prefill
@@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
+            fia_attn_mask=self.fia_attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             is_only_prefill=bool(np.all(num_valid_tokens != 1)),
@@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         cu_num_tokens, arange = self._get_cumsum_and_arange(
             num_scheduled_tokens)
-        query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
-            self.device).to(torch.int32)
-        self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
+
+        self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
         self.query_start_loc_cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
 
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
+        assigned_mask_dim = 2048
+        self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
+                                                   assigned_mask_dim),
+                                        diagonal=1).to(torch.int8).to(
+                                            self.device)
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
 
@@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
+            fia_attn_mask=self.fia_attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             max_query_len=max_query_len,
@@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             graph_support = None
             if hasattr(builder, 'aclgraph_support'):
                 graph_support = builder.aclgraph_support.value
+                builder_aclgraph = builder.aclgraph_support
             else:
                 graph_support = builder.cudagraph_support.value
+                builder_aclgraph = builder.cudagraph_support
             if graph_support < min_ag_support.value:
-                min_ag_support = builder.aclgraph_support
+                min_ag_support = builder_aclgraph
                 min_ag_builder_name = builder.__class__.__name__
 
         # This is an imitation of compilation_config.splitting_ops_contain_attention()