From 0ca3f48c900b333673830e8307c259acc684c1a3 Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Thu, 21 Aug 2025 14:02:30 +0800 Subject: [PATCH] [2/N][refactor] torchair deepseek mla backend refactor (#2459) ### What this PR does / why we need it? This PR move current unified mla backend to torchair folder and remove torchair-related code in attention/mla_v1.py (1.3k -> 0.9k). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Running eager mode with mla backend, and torchair mode with code before [2445](https://github.com/vllm-project/vllm-ascend/pull/2445) - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/f571ff8eb6d9117c6a418f7f925921968dff8ac8 Signed-off-by: linfeng-yuan <1102311262@qq.com> --- tests/ut/attention/test_mla_v1.py | 260 +---- tests/ut/test_platform.py | 21 + tests/ut/torchair/test_torchair_mla.py | 753 ++++++++++++++ vllm_ascend/attention/mla_v1.py | 552 ++-------- vllm_ascend/platform.py | 18 +- vllm_ascend/torchair/torchair_mla.py | 1319 ++++++++++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 16 +- 7 files changed, 2192 insertions(+), 747 deletions(-) create mode 100644 tests/ut/torchair/test_torchair_mla.py create mode 100644 vllm_ascend/torchair/torchair_mla.py diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index be2a7d8..3ca7210 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -11,7 +11,6 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend, AscendMLAImpl, AscendMLAMetadata, AscendMLAMetadataBuilder, AscendMLAPrefillMetadata) -from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata class TestAscendMLABackend(TestBase): @@ -188,8 +187,6 @@ class TestAscendMLAMetadataBuilder(TestBase): mock_device = 'cpu' ascend_config = MagicMock() - ascend_config.torchair_graph_config = MagicMock() - ascend_config.torchair_graph_config.enabled = True with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) @@ -199,44 +196,9 @@ class TestAscendMLAMetadataBuilder(TestBase): self.assertEqual( builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.chunked_prefill_enabled) - self.assertEqual(builder.torchair_graph_enabled, True) - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_reorder_batch_with_torchair_graph(self, ascend_config): - mock_vllm_config = MagicMock() - mock_vllm_config.model_config.max_model_len = 1024 - mock_vllm_config.cache_config.block_size = 16 - mock_vllm_config.scheduler_config.max_num_seqs = 4 - mock_vllm_config.scheduler_config.chunked_prefill_enabled = False - mock_device = 'cpu' - ascend_config.torchair_graph_config = MagicMock() - ascend_config.torchair_graph_config.enabled = True - - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) - - input_batch = MagicMock() - input_batch.req_ids = [0, 1, 2, 3] - - scheduler_output = MagicMock() - scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1} - scheduler_output.scheduled_spec_decode_tokens = { - 0: [1], - 1: [], - 2: [1, 1], - 3: [] - } - - input_batch.swap_states = MagicMock() - - modified = builder.reorder_batch(input_batch, scheduler_output) - - self.assertFalse(modified) - input_batch.swap_states.assert_not_called() - - def test_reorder_batch_without_torchair_graph(self): + def test_reorder_batch(self): ascend_config = MagicMock() - ascend_config.torchair_graph_config = MagicMock() - ascend_config.torchair_graph_config.enabled = False mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 @@ -268,128 +230,6 @@ class TestAscendMLAMetadataBuilder(TestBase): self.assertTrue(modified) input_batch.swap_states.assert_called_once_with(1, 2) - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_get_graph_runner_block_tables_normal(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - mock_vllm_config = MagicMock() - mock_vllm_config.model_config.max_model_len = 1024 - mock_vllm_config.cache_config.block_size = 16 - mock_vllm_config.scheduler_config.chunked_prefill_enabled = False - mock_device = 'cpu' - - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 64) - self.assertTrue(torch.equal(result[:, :10], block_tables)) - - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - mock_vllm_config = MagicMock() - mock_vllm_config.model_config.max_model_len = 64 - mock_vllm_config.cache_config.block_size = 16 - mock_vllm_config.scheduler_config.chunked_prefill_enabled = False - mock_device = 'cpu' - - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 4) - self.assertTrue(torch.equal(result, block_tables[:, :4])) - - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_get_graph_runner_block_tables_from_numpy(self, - mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - mock_vllm_config = MagicMock() - mock_vllm_config.model_config.max_model_len = 1024 - mock_vllm_config.cache_config.block_size = 16 - mock_vllm_config.scheduler_config.chunked_prefill_enabled = False - mock_device = 'cpu' - - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) - - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 64) - self.assertTrue(torch.equal(result[:, :10], block_tables)) - - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_build_dummy(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config.max_model_len = 1024 - mock_vllm_config.cache_config.block_size = 16 - mock_vllm_config.scheduler_config.chunked_prefill_enabled = False - mock_vllm_config.get_head_size.return_value = 64 - mock_vllm_config.model_config.dtype = torch.float16 - mock_device = 'cpu' - - builder = AscendMLAMetadataBuilder(mock_vllm_config, - mock_device, - metadata_cls=AscendMLAMetadata) - builder.rope_dim = 64 - - with patch.object(builder, - "_get_graph_runner_block_tables", - side_effect=lambda x, y: y): - common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=3, - num_actual_tokens=3, - decode_token_per_req=1, - actual_seq_lengths_q=[0, 1, 2], - attn_mask=torch.zeros((1, 1), dtype=torch.bool), - spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool), - ) - metadata = builder.build_torchair_graph_dummy(common_attn_metadata) - - sin_golden = torch.ones(3, - 1, - 1, - 64, - dtype=torch.float16, - device=mock_device) - cos_golden = torch.ones(3, - 1, - 1, - 64, - dtype=torch.float16, - device=mock_device) - - self.assertIsInstance(metadata, AscendMLAMetadata) - self.assertEqual(metadata.num_input_tokens, 3) - self.assertEqual(metadata.num_actual_tokens, 3) - self.assertEqual(metadata.num_decodes, 1) - self.assertEqual(metadata.num_decode_tokens, 1) - self.assertEqual(metadata.num_prefills, 0) - self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly) - self.assertIsNone(metadata.prefill) - self.assertIsInstance(metadata.decode, AscendMLADecodeMetadata) - self.assertEqual(metadata.block_tables.shape[0], 3) - self.assertEqual(metadata.block_tables.shape[1], 64) - self.assertEqual(metadata.seq_lens.shape[0], 3) - self.assertEqual(metadata.slot_mapping.shape[0], 3) - self.assertEqual(metadata.query_start_loc.shape[0], 3) - assert torch.equal(sin_golden, metadata.decode.sin) - assert torch.equal(cos_golden, metadata.decode.cos) - class TestAscendMLAImpl(TestBase): @@ -401,8 +241,6 @@ class TestAscendMLAImpl(TestBase): @patch("vllm_ascend.attention.mla_v1.get_ascend_config") def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp): mock_tp.world_size = 2 - ascend_config.torchair_graph_config.enabled = True - ascend_config.torchair_graph_config.enable_kv_nz = False speculative_config = MagicMock() speculative_config.num_speculative_tokens = 4 vllm_config.speculative_config = speculative_config @@ -464,7 +302,6 @@ class TestAscendMLAImpl(TestBase): self.assertIsNotNone(self.impl.kv_a_layernorm) self.assertEqual(self.impl.num_queries_per_kv, 32) self.assertEqual(self.impl.tp_size, 2) - self.assertTrue(self.impl.torchair_graph_enabled) def test_v_up_proj_and_o_proj(self): batch_size = 4 @@ -580,102 +417,10 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(out.shape, prefix_out.shape) self.assertEqual(lse.shape, prefix_lse.shape) - @patch("torch_npu.npu_kv_rmsnorm_rope_cache") - def test_exec_kv(self, mock_kv_cache): - batch_size = 2 - hidden = torch.randn(batch_size, 128) - cos = torch.randn(batch_size, 32) - sin = torch.randn(batch_size, 32) - kv_cache = (torch.randn( - 4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), - torch.randn( - 4, 8, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)) - slots = torch.arange(batch_size, dtype=torch.long) - - proj_out = torch.randn( - batch_size, self.impl.num_kv_heads, 1, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim) - self.impl.kv_a_proj_with_mqa.return_value = (proj_out, ) - - mock_kv_cache.return_value = (torch.randn(batch_size, - self.impl.num_kv_heads, 1, - self.impl.qk_rope_head_dim), - torch.randn(batch_size, - self.impl.num_kv_heads, 1, - self.impl.kv_lora_rank), - None, None) - - k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots) - - self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden) - mock_kv_cache.assert_called_once() - self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1, - self.impl.qk_rope_head_dim)) - self.assertEqual( - k_nope.shape, - (batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank)) - self.assertEqual(kv.shape, - (batch_size, self.impl.num_kv_heads, 1, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)) - - @patch("torch_npu.npu_kv_rmsnorm_rope_cache") - def test_exec_kv_prefill(self, mock_kv): - B, N, S, H = 2, self.impl.num_kv_heads, 1, 128 - hidden_states = torch.randn(B, N, S, H) - cos = torch.randn(B, S, 32) - sin = torch.randn(B, S, 32) - kv_cache = ( - torch.randn(100, 8, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), - torch.randn(100, 8, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), - ) - - slots = torch.arange(B * S, dtype=torch.long) - - proj_out = torch.randn( - B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim) - self.impl.kv_a_proj_with_mqa.return_value = (proj_out, ) - - mock_kv.return_value = (None, None, - torch.randn(B, self.impl.num_kv_heads, S, - self.impl.qk_rope_head_dim), - torch.randn(B, self.impl.num_kv_heads, S, - self.impl.kv_lora_rank)) - - k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin, - kv_cache, slots) - - self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states) - mock_kv.assert_called_once() - - self.assertEqual( - k_pe.shape, - (B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim)) - self.assertEqual( - k_nope.shape, - (B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank)) - - @patch("torch_npu.npu_interleave_rope") - def test_rope_single(self, mock_rope): - B, N, D = 2, 16, 1024 - x = torch.randn(B, N, D) - cos = torch.randn(B, N, 1, D) - sin = torch.randn(B, N, 1, D) - mock_rope.return_value = x.view(B, N, 1, D) - result = self.impl.rope_single(x, cos, sin) - self.assertEqual(result.shape[0], B) - self.assertEqual(result.shape[1], N) - self.assertEqual(result.shape[2], D) - mock_rope.assert_called_once() - @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj") @patch("torch_npu._npu_paged_attention_mla") def test_forward_decode_without_graph(self, mock_page_attention_mla, mock_up_proj): - self.impl.running_in_graph = False - self.impl.running_chunkprefilll_with_torchair = False num_tokens = 100 num_blocks = 256 block_size = 4 @@ -706,9 +451,6 @@ class TestAscendMLAImpl(TestBase): @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill") @patch("torch_npu._npu_reshape_and_cache") def test_forward_without_graph(self, _, mock_forward_prefill): - self.impl.running_in_graph = False - self.impl.torchair_graph_enabled = False - num_tokens = 100 num_blocks = 256 block_size = 4 diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 67436a3..c44880e 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -425,6 +425,27 @@ class TestNPUPlatform(TestBase): self.assertEqual(result, "vllm_ascend.attention.mla_v1.AscendMLABackend") + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_v1_mla_and_torchair( + self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = True + + mock_get_ascend_config.return_value = mock_config + + result = self.platform.get_attn_backend_cls( + selected_backend="ascend", + head_size=64, + dtype="float16", + kv_cache_dtype="float16", + block_size=64, + use_v1=True, + use_mla=True, + ) + self.assertEqual( + result, + "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend") + @patch('vllm_ascend.platform.get_ascend_config') def test_get_attn_backend_cls_use_v1_and_torchair(self, mock_get_ascend_config): diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py new file mode 100644 index 0000000..8a6c14d --- /dev/null +++ b/tests/ut/torchair/test_torchair_mla.py @@ -0,0 +1,753 @@ +from unittest.mock import MagicMock, patch + +import torch +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.model_executor.layers.linear import LinearBase + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.torchair.torchair_mla import ( + AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata, + AscendMLATorchairImpl, AscendMLATorchairMetadata, + AscendMLATorchairMetadataBuilder, AscendMLATorchairPrefillMetadata) +from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata + + +class TestAscendMLATorchairBackend(TestBase): + + def test_get_name(self): + self.assertEqual(AscendMLATorchairBackend.get_name(), + "ASCEND_MLA_TORCHAIR") + + def test_get_metadata_cls(self): + self.assertEqual(AscendMLATorchairBackend.get_metadata_cls(), + AscendMLATorchairMetadata) + + def test_get_builder_cls(self): + self.assertEqual(AscendMLATorchairBackend.get_builder_cls(), + AscendMLATorchairMetadataBuilder) + + def test_get_kv_cache_shape(self): + result = AscendMLATorchairBackend.get_kv_cache_shape(2, 4, 8, 128) + self.assertEqual(result, (2, 4, 8, 128)) + + def test_get_impl_cls(self): + result = AscendMLATorchairBackend.get_impl_cls() + self.assertEqual(result, AscendMLATorchairImpl) + + +class TestAscendMLATorchairPrefillMetadata(TestBase): + + def test_ascend_mla_prefill_metadata_default(self): + attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool) + query_lens = [1, 2] + seq_lens = [2, 2] + context_lens = torch.tensor([1, 2]) + input_positions = torch.tensor([0, 1, 0, 1]) + query_start_loc = torch.tensor([0, 1, 3]) + block_table = torch.tensor([[0, 1], [2, 3]]) + max_query_len = 2 + max_seq_lens = 2 + + metadata = AscendMLATorchairPrefillMetadata( + attn_mask=attn_mask, + query_lens=query_lens, + seq_lens=seq_lens, + context_lens=context_lens, + input_positions=input_positions, + query_start_loc=query_start_loc, + block_table=block_table, + max_query_len=max_query_len, + max_seq_lens=max_seq_lens) + self.assertIs(metadata.attn_mask, attn_mask) + self.assertEqual(metadata.query_lens, query_lens) + self.assertEqual(metadata.seq_lens, seq_lens) + self.assertIs(metadata.context_lens, context_lens) + self.assertIs(metadata.input_positions, input_positions) + self.assertIs(metadata.query_start_loc, query_start_loc) + self.assertIs(metadata.block_table, block_table) + self.assertEqual(metadata.max_query_len, max_query_len) + self.assertEqual(metadata.max_seq_lens, max_seq_lens) + self.assertIsNone(metadata.chunked_context) + + def test_ascend_mla_prefill_metadata_with_chunked_context(self): + cu_seq_lens = torch.tensor([0, 2, 4]) + starts = torch.tensor([0, 2]) + seq_tot = [2, 2] + max_seq_lens = [2, 2] + workspace = torch.randn(2, 4) + chunk_seq_lens = torch.tensor([2, 2]) + + chunked_context = AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata( + cu_seq_lens=cu_seq_lens, + starts=starts, + seq_tot=seq_tot, + max_seq_lens=max_seq_lens, + workspace=workspace, + chunk_seq_lens=chunk_seq_lens) + + metadata = AscendMLATorchairPrefillMetadata( + attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), + query_lens=[1, 2], + seq_lens=[2, 2], + context_lens=torch.tensor([1, 2]), + input_positions=torch.tensor([0, 1, 0, 1]), + query_start_loc=torch.tensor([0, 1, 3]), + block_table=torch.tensor([[0, 1], [2, 3]]), + max_query_len=2, + max_seq_lens=2, + chunked_context=chunked_context) + + self.assertIsNotNone(metadata.chunked_context) + self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens) + self.assertIs(metadata.chunked_context.starts, starts) + self.assertEqual(metadata.chunked_context.seq_tot, seq_tot) + self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens) + self.assertIs(metadata.chunked_context.workspace, workspace) + self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) + + +class TestAscendMLATorchairDecodeMetadata(TestBase): + + def test_ascend_mla_decode_metadata_default(self): + input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) + block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]]) + seq_lens = torch.tensor([[2], [3]]) + max_seq_lens = 4 + seq_lens_list = [2, 3] + attn_mask = None + + metadata = AscendMLATorchairDecodeMetadata(input_positions, + block_table, seq_lens, + max_seq_lens, seq_lens_list, + attn_mask) + + self.assertIs(metadata.input_positions, input_positions) + self.assertIs(metadata.block_table, block_table) + self.assertIs(metadata.seq_lens, seq_lens) + self.assertEqual(metadata.max_seq_lens, max_seq_lens) + self.assertEqual(metadata.seq_lens_list, seq_lens_list) + self.assertIsNone(attn_mask) + + +class TestAscendMLATorchairMetadata(TestBase): + + def test_ascend_mla_metadata_default(self): + num_actual_tokens = 100 + slot_mapping = torch.randn(100, 4, 1024) + query_start_loc = torch.tensor([1, 2, 3, 4]) + seq_lens = [30, 50] + block_tables = torch.randint(0, 100, (100, 4)) + + num_decodes = 4 + num_decode_tokens = 8 + num_prefills = 8 + + num_input_tokens = 2 + + query_lens = None + head_dim = None + attn_mask = None + attn_state = AscendAttentionState.ChunkedPrefill + + decode = None + prefill = None + + metadata = AscendMLATorchairMetadata( + num_actual_tokens, slot_mapping, query_start_loc, seq_lens, + block_tables, num_decodes, num_decode_tokens, num_prefills, + num_input_tokens, query_lens, head_dim, attn_mask, attn_state, + decode, prefill) + + self.assertEqual(metadata.num_actual_tokens, num_actual_tokens) + self.assertIs(metadata.slot_mapping, slot_mapping) + self.assertIs(metadata.query_start_loc, query_start_loc) + self.assertEqual(metadata.seq_lens, seq_lens) + self.assertIs(metadata.block_tables, block_tables) + self.assertEqual(metadata.num_decodes, num_decodes) + self.assertEqual(metadata.num_decode_tokens, num_decode_tokens) + self.assertEqual(metadata.num_prefills, num_prefills) + self.assertEqual(metadata.num_input_tokens, num_input_tokens) + self.assertEqual(metadata.query_lens, query_lens) + self.assertEqual(metadata.head_dim, head_dim) + self.assertEqual(metadata.attn_mask, attn_mask) + self.assertEqual(metadata.attn_state, attn_state) + self.assertEqual(metadata.decode, decode) + self.assertEqual(metadata.prefill, prefill) + + +class TestAscendMLATorchairMetadataBuilder(TestBase): + + def test_ascend_mla_metadata_builder_default(self): + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + ascend_config = MagicMock() + ascend_config.torchair_graph_config = MagicMock() + ascend_config.torchair_graph_config.enabled = True + with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", + return_value=ascend_config): + builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_device) + + self.assertEqual(builder.block_size, + mock_vllm_config.cache_config.block_size) + self.assertEqual( + builder.chunked_prefill_enabled, + mock_vllm_config.scheduler_config.chunked_prefill_enabled) + self.assertEqual(builder.torchair_graph_enabled, True) + + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") + def test_reorder_batch_with_torchair_graph(self, ascend_config): + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + ascend_config.torchair_graph_config = MagicMock() + ascend_config.torchair_graph_config.enabled = True + + builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_device) + + input_batch = MagicMock() + input_batch.req_ids = [0, 1, 2, 3] + + scheduler_output = MagicMock() + scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1} + scheduler_output.scheduled_spec_decode_tokens = { + 0: [1], + 1: [], + 2: [1, 1], + 3: [] + } + + input_batch.swap_states = MagicMock() + + modified = builder.reorder_batch(input_batch, scheduler_output) + + self.assertFalse(modified) + input_batch.swap_states.assert_not_called() + + def test_reorder_batch_without_torchair_graph(self): + ascend_config = MagicMock() + ascend_config.torchair_graph_config = MagicMock() + ascend_config.torchair_graph_config.enabled = False + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", + return_value=ascend_config): + builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_device) + + input_batch = MagicMock() + input_batch.req_ids = [0, 1, 2, 3] + + scheduler_output = MagicMock() + scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2} + scheduler_output.scheduled_spec_decode_tokens = { + 0: [], + 1: [1], + 2: [], + 3: [] + } + + input_batch.swap_states = MagicMock() + + modified = builder.reorder_batch(input_batch, scheduler_output) + + self.assertTrue(modified) + input_batch.swap_states.assert_called_once_with(1, 2) + + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") + def test_get_graph_runner_block_tables_normal(self, mock_ascend_config): + ascend_config = MagicMock() + mock_ascend_config.return_value = ascend_config + ascend_config.torchair_graph_config.enabled = False + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_device) + block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) + + result = builder._get_graph_runner_block_tables(3, block_tables) + self.assertEqual(result.shape[0], 3) + self.assertEqual(result.shape[1], 64) + self.assertTrue(torch.equal(result[:, :10], block_tables)) + + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") + def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config): + ascend_config = MagicMock() + mock_ascend_config.return_value = ascend_config + ascend_config.torchair_graph_config.enabled = False + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 64 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_device) + block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) + + result = builder._get_graph_runner_block_tables(3, block_tables) + self.assertEqual(result.shape[0], 3) + self.assertEqual(result.shape[1], 4) + self.assertTrue(torch.equal(result, block_tables[:, :4])) + + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") + def test_get_graph_runner_block_tables_from_numpy(self, + mock_ascend_config): + ascend_config = MagicMock() + mock_ascend_config.return_value = ascend_config + ascend_config.torchair_graph_config.enabled = False + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_device) + + block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) + + result = builder._get_graph_runner_block_tables(3, block_tables) + + self.assertEqual(result.shape[0], 3) + self.assertEqual(result.shape[1], 64) + self.assertTrue(torch.equal(result[:, :10], block_tables)) + + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") + def test_build_dummy(self, mock_ascend_config): + ascend_config = MagicMock() + mock_ascend_config.return_value = ascend_config + ascend_config.torchair_graph_config.enabled = False + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_vllm_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_device = 'cpu' + + builder = AscendMLATorchairMetadataBuilder( + mock_vllm_config, + mock_device, + metadata_cls=AscendMLATorchairMetadata) + builder.rope_dim = 64 + + with patch.object(builder, + "_get_graph_runner_block_tables", + side_effect=lambda x, y: y): + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=3, + num_actual_tokens=3, + decode_token_per_req=1, + actual_seq_lengths_q=[0, 1, 2], + attn_mask=torch.zeros((1, 1), dtype=torch.bool), + spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool), + ) + metadata = builder.build_torchair_graph_dummy(common_attn_metadata) + + sin_golden = torch.ones(3, + 1, + 1, + 64, + dtype=torch.float16, + device=mock_device) + cos_golden = torch.ones(3, + 1, + 1, + 64, + dtype=torch.float16, + device=mock_device) + + self.assertIsInstance(metadata, AscendMLATorchairMetadata) + self.assertEqual(metadata.num_input_tokens, 3) + self.assertEqual(metadata.num_actual_tokens, 3) + self.assertEqual(metadata.num_decodes, 1) + self.assertEqual(metadata.num_decode_tokens, 1) + self.assertEqual(metadata.num_prefills, 0) + self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly) + self.assertIsNone(metadata.prefill) + self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata) + self.assertEqual(metadata.block_tables.shape[0], 3) + self.assertEqual(metadata.block_tables.shape[1], 64) + self.assertEqual(metadata.seq_lens.shape[0], 3) + self.assertEqual(metadata.slot_mapping.shape[0], 3) + self.assertEqual(metadata.query_start_loc.shape[0], 3) + assert torch.equal(sin_golden, metadata.decode.sin) + assert torch.equal(cos_golden, metadata.decode.cos) + + +class TestAscendMLATorchairImpl(TestBase): + + @patch('vllm.distributed.parallel_state._TP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch("vllm.distributed.get_tensor_model_parallel_world_size", + return_value=2) + @patch("vllm.config.get_current_vllm_config") + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") + def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp): + mock_tp.world_size = 2 + ascend_config.torchair_graph_config.enabled = True + ascend_config.torchair_graph_config.enable_kv_nz = False + speculative_config = MagicMock() + speculative_config.num_speculative_tokens = 4 + vllm_config.speculative_config = speculative_config + + num_heads = 256 + head_size = 1024 + scale = 0.1 + num_kv_heads = 8 + kv_cache_dtype = "auto" + + kv_a_layernorm = MagicMock() + kv_a_layernorm.weight = torch.randn(96) + kv_a_layernorm.variance_epsilon = 1e-6 + kwargs = { + "q_lora_rank": 64, + "kv_lora_rank": 32, + "qk_nope_head_dim": 64, + "qk_rope_head_dim": 32, + "qk_head_dim": 96, + "v_head_dim": 128, + "rotary_emb": MagicMock(), + "q_proj": MagicMock(), + "kv_b_proj": MagicMock(), + "o_proj": MagicMock(), + "kv_a_proj_with_mqa": MagicMock(), + "kv_a_layernorm": kv_a_layernorm, + } + + self.impl = AscendMLATorchairImpl(num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=kv_cache_dtype, + blocksparse_params=None, + logits_soft_cap=None, + attn_type=None, + kv_sharing_target_layer_name=None, + **kwargs) + + def test_init(self): + self.assertEqual(self.impl.num_heads, 256) + self.assertEqual(self.impl.head_size, 1024) + self.assertEqual(self.impl.scale, 0.1) + self.assertEqual(self.impl.num_kv_heads, 8) + self.assertEqual(self.impl.kv_cache_dtype, "auto") + self.assertEqual(self.impl.q_lora_rank, 64) + self.assertEqual(self.impl.kv_lora_rank, 32) + self.assertEqual(self.impl.qk_nope_head_dim, 64) + self.assertEqual(self.impl.qk_rope_head_dim, 32) + self.assertEqual(self.impl.qk_head_dim, 96) + self.assertEqual(self.impl.v_head_dim, 128) + self.assertIsNotNone(self.impl.rotary_emb) + self.assertIsNotNone(self.impl.q_proj) + self.assertIsNotNone(self.impl.kv_b_proj) + self.assertIsNotNone(self.impl.o_proj) + self.assertIsNotNone(self.impl.kv_a_proj_with_mqa) + self.assertIsNotNone(self.impl.kv_a_layernorm) + self.assertEqual(self.impl.num_queries_per_kv, 32) + self.assertEqual(self.impl.tp_size, 2) + self.assertTrue(self.impl.torchair_graph_enabled) + + def test_v_up_proj_and_o_proj(self): + batch_size = 4 + x = torch.randn(batch_size, self.impl.num_heads, + self.impl.kv_lora_rank) + + self.impl.o_proj.return_value = (torch.randn( + batch_size, self.impl.num_heads * self.impl.v_head_dim), ) + if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None: + self.impl.W_UV = torch.randn(self.impl.num_heads, + self.impl.kv_lora_rank, + self.impl.v_head_dim) + result = self.impl._v_up_proj_and_o_proj(x) + + self.assertEqual(result.shape[0], batch_size) + self.assertEqual(result.shape[1], + self.impl.num_heads * self.impl.v_head_dim) + + def test_q_proj_and_k_up_proj(self): + batch_size = 4 + x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim) + q_proj_output = torch.randn(batch_size, self.impl.num_heads, + self.impl.qk_head_dim) + self.impl.q_proj.return_value = (q_proj_output, ) + if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None: + self.impl.W_UK_T = torch.randn(self.impl.num_heads, + self.impl.qk_nope_head_dim, + self.impl.kv_lora_rank) + result = self.impl._q_proj_and_k_up_proj(x) + ql_nope, q_pe = result + self.assertEqual(ql_nope.shape[0], batch_size) + self.assertEqual(ql_nope.shape[1], self.impl.num_heads) + self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank) + self.assertEqual(q_pe.shape[0], batch_size) + self.assertEqual(q_pe.shape[1], self.impl.num_heads) + self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim) + + def test_process_weights_after_loading(self): + layer = MagicMock(spec=LinearBase) + layer.input_size_per_partition = 10 + quant_method = MagicMock() + apply = MagicMock() + quant_method.apply = apply + layer.quant_method = quant_method + shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim + + self.impl.v_head_dim) + shape_1 = self.impl.kv_lora_rank + layer.weight = torch.randn(shape_0, shape_1) + self.impl.kv_b_proj = layer + apply.return_value = layer.weight.T + self.impl.process_weights_after_loading(torch.bfloat16) + + self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads) + self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim) + self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank) + + self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads) + self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank) + self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim) + + def test_compute_prefill_context_none(self): + batch_size = 4 + kv_cache = torch.randn(10, 1, 1, 192) + query = torch.randn(batch_size, self.impl.num_heads, + self.impl.qk_head_dim) + metadata = MagicMock() + metadata.prefill = None + prefix_out = torch.randn(2, 16, 128) + prefix_lse = torch.randn(2, 16, 8) + out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, + metadata, prefix_out, + prefix_lse) + + self.assertTrue(torch.equal(prefix_out, out)) + self.assertTrue(torch.equal(prefix_lse, lse)) + + @patch("torch_npu.atb.npu_paged_cache_load") + @patch("torch_npu.atb.npu_ring_mla") + def test_compute_prefill_context(self, mock_ring, mock_load): + S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim + _, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim + latent_kv_dim = self.impl.kv_lora_rank + num_blocks, block_size = 100, 20 + query = torch.randn(S, N, D) + kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim) + kv_cache_1 = torch.randn(num_blocks, block_size, N, D) + kv_cache = [kv_cache_0, kv_cache_1] + prefix_out = torch.randn(S, N, 128) + prefix_lse = torch.randn(S, N) + + self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), ) + + chunk_ctx = MagicMock() + chunk_ctx.seq_tot = [8] + chunk_ctx.chunk_seq_lens = [torch.tensor([8])] + chunk_ctx.starts = [torch.tensor([0])] + + prefill_meta = MagicMock() + prefill_meta.chunked_context = chunk_ctx + prefill_meta.query_lens = [8] + prefill_meta.block_table = torch.randint(0, 100, (S, 4)) + + meta = MagicMock() + meta.prefill = prefill_meta + + out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, + meta, prefix_out, + prefix_lse) + + mock_load.assert_called_once() + mock_ring.assert_called_once() + + self.assertEqual(out.shape, prefix_out.shape) + self.assertEqual(lse.shape, prefix_lse.shape) + + @patch("torch_npu.npu_kv_rmsnorm_rope_cache") + def test_exec_kv(self, mock_kv_cache): + batch_size = 2 + hidden = torch.randn(batch_size, 128) + cos = torch.randn(batch_size, 32) + sin = torch.randn(batch_size, 32) + kv_cache = (torch.randn( + 4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), + torch.randn( + 4, 8, + self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)) + slots = torch.arange(batch_size, dtype=torch.long) + + proj_out = torch.randn( + batch_size, self.impl.num_kv_heads, 1, + self.impl.kv_lora_rank + self.impl.qk_rope_head_dim) + self.impl.kv_a_proj_with_mqa.return_value = (proj_out, ) + + mock_kv_cache.return_value = (torch.randn(batch_size, + self.impl.num_kv_heads, 1, + self.impl.qk_rope_head_dim), + torch.randn(batch_size, + self.impl.num_kv_heads, 1, + self.impl.kv_lora_rank), + None, None) + + k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots) + + self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden) + mock_kv_cache.assert_called_once() + self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1, + self.impl.qk_rope_head_dim)) + self.assertEqual( + k_nope.shape, + (batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank)) + self.assertEqual(kv.shape, + (batch_size, self.impl.num_kv_heads, 1, + self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)) + + @patch("torch_npu.npu_kv_rmsnorm_rope_cache") + def test_exec_kv_prefill(self, mock_kv): + B, N, S, H = 2, self.impl.num_kv_heads, 1, 128 + hidden_states = torch.randn(B, N, S, H) + cos = torch.randn(B, S, 32) + sin = torch.randn(B, S, 32) + kv_cache = ( + torch.randn(100, 8, + self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), + torch.randn(100, 8, + self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), + ) + + slots = torch.arange(B * S, dtype=torch.long) + + proj_out = torch.randn( + B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim) + self.impl.kv_a_proj_with_mqa.return_value = (proj_out, ) + + mock_kv.return_value = (None, None, + torch.randn(B, self.impl.num_kv_heads, S, + self.impl.qk_rope_head_dim), + torch.randn(B, self.impl.num_kv_heads, S, + self.impl.kv_lora_rank)) + + k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin, + kv_cache, slots) + + self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states) + mock_kv.assert_called_once() + + self.assertEqual( + k_pe.shape, + (B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim)) + self.assertEqual( + k_nope.shape, + (B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank)) + + @patch("torch_npu.npu_interleave_rope") + def test_rope_single(self, mock_rope): + B, N, D = 2, 16, 1024 + x = torch.randn(B, N, D) + cos = torch.randn(B, N, 1, D) + sin = torch.randn(B, N, 1, D) + mock_rope.return_value = x.view(B, N, 1, D) + result = self.impl.rope_single(x, cos, sin) + self.assertEqual(result.shape[0], B) + self.assertEqual(result.shape[1], N) + self.assertEqual(result.shape[2], D) + mock_rope.assert_called_once() + + @patch( + "vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._v_up_proj_and_o_proj" + ) + @patch("torch_npu._npu_paged_attention_mla") + def test_forward_decode_without_graph(self, mock_page_attention_mla, + mock_up_proj): + self.impl.running_in_graph = False + self.impl.running_chunkprefilll_with_torchair = False + num_tokens = 100 + num_blocks = 256 + block_size = 4 + q_nope = torch.randn(num_tokens, self.impl.num_heads, + self.impl.qk_nope_head_dim) + q_pe = torch.randn(num_tokens, self.impl.num_heads, + self.impl.qk_rope_head_dim) + kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size, + self.impl.num_heads, + self.impl.kv_lora_rank) + metadata = MagicMock() + metadata.decode = MagicMock() + metadata.decode.block_table = MagicMock() + metadata.decode.seq_lens = 10 + mock_page_attention_mla.return_value = torch.randn( + num_tokens, self.impl.num_heads, self.impl.kv_lora_rank) + mock_up_proj.return_value = torch.randn(num_tokens, + self.impl.num_heads, + self.impl.v_head_dim) + result = self.impl._forward_decode(q_nope, q_pe, None, None, + kv_c_and_k_pe_cache, metadata) + self.assertEqual(result.shape[0], num_tokens) + self.assertEqual(result.shape[1], self.impl.num_heads) + self.assertEqual(result.shape[2], self.impl.v_head_dim) + mock_up_proj.assert_called_once() + mock_page_attention_mla.assert_called_once() + + @patch( + "vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._forward_prefill" + ) + @patch("torch_npu._npu_reshape_and_cache") + def test_forward_without_graph(self, _, mock_forward_prefill): + self.impl.running_in_graph = False + self.impl.torchair_graph_enabled = False + + num_tokens = 100 + num_blocks = 256 + block_size = 4 + rotary_emb_return_value = (torch.randn(num_tokens, 16, + self.impl.kv_lora_rank), + torch.randn(0, 1, self.impl.kv_lora_rank)) + self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value + self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn( + 1, num_blocks, 128) + + hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank) + hidden_states_or_kv_c_normed = torch.randn(num_tokens, + self.impl.kv_lora_rank) + k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim) + kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads, + self.impl.kv_lora_rank), + torch.randn(num_blocks, block_size, self.impl.num_heads, + self.impl.qk_rope_head_dim)) + output = torch.randn(num_tokens, self.impl.num_heads, + self.impl.v_head_dim) + + metadata = MagicMock() + metadata.num_decodes = 0 + metadata.num_prefills = num_tokens + mock_forward_prefill.return_value = torch.randn( + 0, self.impl.num_heads * self.impl.v_head_dim) + result = self.impl.forward(None, hidden_states_or_q_c, + hidden_states_or_kv_c_normed, k_pe, + kv_cache, metadata, output, False) + self.assertEqual(result.shape[0], num_tokens) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 72a2d4f..fcad4c8 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1,14 +1,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar -import numpy as np import torch import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) -from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, @@ -24,9 +22,6 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, - npu_stream_switch, npu_wait_tensor) -from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch if TYPE_CHECKING: @@ -212,8 +207,6 @@ class AscendMLAMetadataBuilder: dtype=self.model_config.dtype, device=device, ) - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.cos_cache = None self.sin_cache = None @@ -231,20 +224,10 @@ class AscendMLAMetadataBuilder: for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - # For torch air graph mode we treat spec decoding as decode. - if self.torchair_graph_enabled: - if num_tokens - num_spec_tokens == 1: - decodes.append(i) - else: - prefills.append(i) - # For eager mode we treat spec decoding as chunked prefill. + if num_tokens == 1: + decodes.append(i) else: - if num_tokens == 1: - decodes.append(i) - else: - prefills.append(i) + prefills.append(i) # We hope that this is fairly minimal since decodes # should be around for a number of iterations so hopefully they are @@ -277,99 +260,6 @@ class AscendMLAMetadataBuilder: # better way of doing this return modified_batch - def _get_graph_runner_block_tables( - self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: - max_blocks = self.max_blocks - - graph_block_tables = torch.zeros((num_seqs, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - - num_blocks = block_tables.size(1) - if num_blocks <= max_blocks: - graph_block_tables[:num_seqs, : - num_blocks] = block_tables[:num_seqs, : - num_blocks] - else: - graph_block_tables[:num_seqs, : - max_blocks] = block_tables[:num_seqs, : - max_blocks] - - return graph_block_tables[:, :max_blocks] - - def build_torchair_graph_dummy( - self, - common_attn_metadata: TorchairCommonAttentionMetadata, - ) -> AscendMLAMetadata: - device = self.device - num_reqs = common_attn_metadata.num_reqs - block_table = torch.zeros((num_reqs, self.max_blocks), - dtype=torch.int32, - device=device) - block_table = self._get_graph_runner_block_tables( - num_reqs, block_table) - num_tokens = num_reqs * common_attn_metadata.decode_token_per_req - seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) - seq_lens_list = [0] * num_reqs - input_positions = torch.zeros(num_tokens, - dtype=torch.int32, - device=device).long() - slot_mapping = torch.full((num_tokens, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) - query_start_loc = torch.full((num_reqs, ), - -1, - dtype=torch.int32, - device=device) - sin = torch.ones(num_tokens, - 1, - 1, - self.rope_dim, - dtype=self.model_config.dtype, - device=device) - cos = torch.ones(num_tokens, - 1, - 1, - self.rope_dim, - dtype=self.model_config.dtype, - device=device) - if self.vllm_config.speculative_config is not None and\ - self.vllm_config.speculative_config.method == 'deepseek_mtp': - attn_state = AscendAttentionState.SpecDecoding - num_decode_tokens = 2 - else: - attn_state = AscendAttentionState.DecodeOnly - num_decode_tokens = 1 - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=1, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=common_attn_metadata. - actual_seq_lengths_q[:num_reqs], - sin=sin, - cos=cos, - ) - return self.metadata_cls( # type: ignore - num_input_tokens=common_attn_metadata.num_actual_tokens, - num_actual_tokens=common_attn_metadata.num_actual_tokens, - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - num_decodes=1, - num_decode_tokens=num_decode_tokens, - num_prefills=0, - attn_mask=common_attn_metadata.attn_mask, - attn_state=attn_state, - prefill=None, - decode=decode_metadata, - query_start_loc=query_start_loc, - seq_lens=seq_lens, - block_tables=block_table, - ) - def build( self, common_attn_metadata: AscendCommonAttentionMetadata, @@ -379,14 +269,8 @@ class AscendMLAMetadataBuilder: num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - ]: - decode_threshold = common_attn_metadata.decode_token_per_req - else: - # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding - decode_threshold = 1 + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) assert num_decodes + num_prefills == num_reqs @@ -489,57 +373,14 @@ class AscendMLAMetadataBuilder: ) decode_metadata = None - graph_pad_size = common_attn_metadata.graph_pad_size - use_torchair_graph = graph_pad_size != -1 if num_decodes > 0: actual_seq_lengths_q = query_start_loc[1:].tolist() max_seq_lens = seq_lens[:num_decodes].max().item() seq_lens = seq_lens[:num_decode_tokens] input_positions = input_positions[:num_decode_tokens] block_table = block_table[:num_decode_tokens, ...] - if use_torchair_graph and common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - ]: - num_reqs_pad_size = 0 - num_token_pad_size = 0 - if graph_pad_size != 0: - pad_value = 0 - num_token_pad_size = graph_pad_size - num_decode_tokens - num_reqs_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - num_reqs) - padded_seq_lens = seq_lens.tolist( - ) + [pad_value] * num_reqs_pad_size - else: - padded_seq_lens = seq_lens.tolist() - - seq_lens = torch.from_numpy( - np.array(padded_seq_lens).astype(np.int32)) - seq_lens_list = padded_seq_lens - slot_padding = torch.full((num_token_pad_size, ), - PAD_SLOT_ID, - dtype=slot_mapping.dtype, - device=slot_mapping.device) - slot_mapping = torch.cat([slot_mapping, slot_padding]) - block_table_padding = torch.zeros( - (num_reqs_pad_size, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat([block_table, block_table_padding], - dim=0) - block_table = self._get_graph_runner_block_tables( - num_reqs + num_reqs_pad_size, block_table) - position_padding = torch.zeros(num_token_pad_size, - dtype=input_positions.dtype, - device=input_positions.device) - input_positions = torch.cat( - [input_positions, position_padding]) - actual_seq_lengths_q = query_start_loc[1:].tolist( - ) + common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] - else: - seq_lens_list = seq_lens.tolist() + seq_lens_list = seq_lens.tolist() + # TODO(xyx): whether this block is necessary without torchair # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) batch_size = slot_mapping.size(0) if actual_seq_lengths_q[-1] != batch_size \ @@ -624,8 +465,6 @@ class AscendMLAImpl(MLAAttentionImpl): self.tp_size = get_tensor_model_parallel_world_size() ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp # Adapt torch air graph mode with spec decoding. @@ -634,21 +473,14 @@ class AscendMLAImpl(MLAAttentionImpl): self.spec_token_num = speculative_config.num_speculative_tokens assert self.spec_token_num > 0 - def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False): + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) x = torch.bmm(x, self.W_UV) # Convert from (N, B, V) to (B, N * V) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - if hasattr(self, "running_in_graph") and not self.running_in_graph: - return x - MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB - npu_prefetch(self.o_proj.weight, - x, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=enable_multistream_mla) - return self.o_proj(x, is_prefill=False)[0] + return x # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): @@ -915,77 +747,6 @@ class AscendMLAImpl(MLAAttentionImpl): return attn_output - def exec_kv( - self, - hidden_states: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - kv_cache: Tuple, - slots: torch.Tensor, - ): - - B = hidden_states.shape[0] - N = self.num_kv_heads - S = 1 - kv = self.kv_a_proj_with_mqa(hidden_states)[0] - # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" - k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - kv, - self.kv_a_layernorm.weight, - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode=cache_mode, - ) - return k_pe, k_nope, kv - - def exec_kv_prefill( - self, - hidden_states: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - kv_cache: Tuple, - slots: torch.Tensor, - ): - - B = hidden_states.shape[0] - N = self.num_kv_heads - S = 1 - kv = self.kv_a_proj_with_mqa(hidden_states)[0] - # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA" - _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( - kv, - self.kv_a_layernorm.weight, - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode=cache_mode, - is_output_kv=True, - ) - return k_pe, k_nope - - def rope_single( - self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - B, N, D = x.shape - S = 1 - x = x.view(B, N, S, D) - x = torch_npu.npu_interleave_rope(x, cos, sin) - return x.view(B, N, D) - def _forward_decode( self, q_nope: torch.Tensor, @@ -994,100 +755,41 @@ class AscendMLAImpl(MLAAttentionImpl): k_pe: torch.Tensor, kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, - enable_multistream_mla: bool = False, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None num_tokens = q_nope.size(0) - if self.running_in_graph or self.running_chunkprefilll_with_torchair: - # shape of knope/k_pe for npu graph mode should be: - # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] - block_size = kv_c_and_k_pe_cache[0].shape[1] - actual_seq_lengths = None - if self.enable_kv_nz: - k_nope = k_nope.view(-1, self.num_kv_heads, - self.kv_lora_rank // 16, block_size, 16) - k_pe = k_pe.view(-1, self.num_kv_heads, - self.qk_rope_head_dim // 16, block_size, 16) - input_layout = "BSND" - else: - k_nope = k_nope.view(-1, self.num_kv_heads, block_size, - self.kv_lora_rank) - k_pe = k_pe.view(-1, self.num_kv_heads, block_size, - self.qk_rope_head_dim) - input_layout = "BNSD" - - if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - assert num_tokens % self.spec_token_num == 0 - input_layout = "TND" - # [bs * q_seq_len, num_heads_per_rank, dim] - q_nope = q_nope.view(num_tokens, self.num_heads, -1) - q_pe = q_pe.view(num_tokens, self.num_heads, -1) - sparse_mode = 3 - spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore - actual_seq_lengths = decode_meta.actual_seq_lengths_q - else: - if self.enable_kv_nz: - q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) - q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) - else: - q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) - sparse_mode = 0 - spec_attn_mask = None - - attn_output, _ = torch_npu.npu_fused_infer_attention_score( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout=input_layout, - atten_mask=spec_attn_mask, - sparse_mode=sparse_mode, - scale=self.scale, - antiquant_mode=0, - antiquant_scale=None, - block_table=decode_meta.block_table, - block_size=block_size, - actual_seq_lengths_kv=decode_meta.seq_lens_list, - actual_seq_lengths=actual_seq_lengths) + # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will + # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become + # public available + assert len(kv_c_and_k_pe_cache) > 1 + if envs_ascend.VLLM_ASCEND_MLA_PA: + attn_output = torch_npu.atb.npu_multi_head_latent_attention( + q_nope, q_pe, kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1], + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, self.num_heads, self.scale, + self.num_kv_heads) else: - # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will - # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become - # public available - assert len(kv_c_and_k_pe_cache) > 1 - if envs_ascend.VLLM_ASCEND_MLA_PA: - attn_output = torch_npu.atb.npu_multi_head_latent_attention( - q_nope, q_pe, kv_c_and_k_pe_cache[0], - kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, - attn_metadata.decode.seq_lens, self.num_heads, self.scale, - self.num_kv_heads) - else: - q = torch.cat([q_nope, q_pe], dim=-1) - attn_output = torch.empty( - [num_tokens, self.num_heads, self.kv_lora_rank], - dtype=q.dtype, - device=q.device) - k_cache = torch.cat( - [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) - torch_npu._npu_paged_attention_mla( - query=q, - key_cache=k_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.decode. - block_table, # type:ignore - context_lens=attn_metadata.decode.seq_lens, # type:ignore - mla_vheadsize=self.kv_lora_rank, - out=attn_output) + q = torch.cat([q_nope, q_pe], dim=-1) + attn_output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device) + k_cache = torch.cat( + [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) + torch_npu._npu_paged_attention_mla( + query=q, + key_cache=k_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.decode.block_table, # type:ignore + context_lens=attn_metadata.decode.seq_lens, # type:ignore + mla_vheadsize=self.kv_lora_rank, + out=attn_output) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: - return self._v_up_proj_and_o_proj(attn_output, - enable_multistream_mla) + return self._v_up_proj_and_o_proj(attn_output) else: current_ms_metadata.before_comm_event.record() with torch.npu.stream(current_ms_metadata.comm_stream): @@ -1103,19 +805,14 @@ class AscendMLAImpl(MLAAttentionImpl): kv_cache: Tuple[torch.Tensor], attn_metadata: M, output: Optional[torch.Tensor] = None, - enable_multistream_mla: bool = False, ckq: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." if attn_metadata is None: # Profiling run. return output - self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill num_actual_toks = attn_metadata.num_actual_tokens - if k_pe is None and not self.running_in_graph: + if k_pe is None: kv_c, k_pe = self.kv_a_proj_with_mqa( hidden_states_or_kv_c_normed)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) @@ -1128,134 +825,55 @@ class AscendMLAImpl(MLAAttentionImpl): has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - if not self.running_in_graph: - # Inputs and outputs may be padded for CUDA graphs - output_padded = output - output = output[:num_actual_toks, ...] - if not self.torchair_graph_enabled: - kv_c_normed = kv_c_normed[:num_actual_toks, ...] - prefill_k_c_normed = kv_c_normed[num_decode_tokens:] - if not self.running_in_graph: - hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] - prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] - prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:] - # if not self.torchair_graph_enabled: - k_pe = k_pe[:num_actual_toks, ...] - k_pe = k_pe.unsqueeze(1) - decode_k_pe = k_pe[:num_decode_tokens] - prefill_k_pe = k_pe[num_decode_tokens:] - else: - decode_hs_or_q_c = hidden_states_or_q_c + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + kv_c_normed = kv_c_normed[:num_actual_toks, ...] + prefill_k_c_normed = kv_c_normed[num_decode_tokens:] + hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + k_pe = k_pe[:num_actual_toks, ...] + k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] if has_decode: decode_k_nope = None assert attn_metadata.decode is not None - if self.running_in_graph or self.running_chunkprefilll_with_torchair: - cos = attn_metadata.decode.cos - sin = attn_metadata.decode.sin - if self.running_chunkprefilll_with_torchair: - decode_hs = ( - hidden_states_or_kv_c_normed[:num_decode_tokens]) - slots = attn_metadata.slot_mapping[:num_decode_tokens] - decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( - decode_hs, cos, sin, kv_cache, slots) - else: - with npu_stream_switch("mla_secondary", - 0, - enabled=enable_multistream_mla): - npu_wait_tensor(hidden_states_or_kv_c_normed, - ckq, - enabled=enable_multistream_mla) - decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( - hidden_states_or_kv_c_normed, cos, sin, kv_cache, - attn_metadata.slot_mapping) - # Without explicitly controlling the order, IndexByTensor operations - # would be placed after `matmul W_KV_T` hindering the overlapping of - # KvRmsNormRopeCache and SingleRope. - npu_wait_tensor(decode_hs_or_q_c, - cos, - enabled=enable_multistream_mla) - npu_wait_tensor(decode_hs_or_q_c, - sin, - enabled=enable_multistream_mla) - npu_wait_tensor(decode_hs_or_q_c, - decode_kv, - enabled=enable_multistream_mla) - decode_ql_nope, decode_q_pe = \ self._q_proj_and_k_up_proj(decode_hs_or_q_c) - if self.running_in_graph: - with npu_stream_switch("mla_secondary", - 0, - enabled=enable_multistream_mla): - npu_wait_tensor(decode_q_pe, - decode_k_pe, - enabled=enable_multistream_mla) - decode_q_pe = self.rope_single(decode_q_pe, cos, sin) - elif self.running_chunkprefilll_with_torchair: - decode_q_pe = self.rope_single(decode_q_pe, cos, sin) - else: - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, - decode_q_pe.contiguous(), - decode_k_pe, - max_seq_len=attn_metadata.decode.max_seq_lens) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + attn_metadata.decode.input_positions, + decode_q_pe.contiguous(), + decode_k_pe, + max_seq_len=attn_metadata.decode.max_seq_lens) if has_prefill: assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] - if self.torchair_graph_enabled: - num_tokens = prefill_hs_or_q_c.shape[0] - cos = attn_metadata.prefill.cos - sin = attn_metadata.prefill.sin - - prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) - prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( - prefill_hs, cos, sin, kv_cache, - attn_metadata.slot_mapping[num_decode_tokens:]) - - kv_c_normed = prefill_k_nope[:num_actual_toks, ...] - prefill_k_c_normed = prefill_k_nope - prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, - -1) - prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) - else: - prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), - prefill_k_pe, - max_seq_len=attn_metadata.prefill.max_seq_lens) + prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( + attn_metadata.prefill.input_positions, + prefill_q_pe.contiguous(), + prefill_k_pe, + max_seq_len=attn_metadata.prefill.max_seq_lens) assert len( kv_cache ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" - if self.torchair_graph_enabled: - if kv_cache[0].numel() > 0 and has_prefill: - slots = attn_metadata.slot_mapping - # NOTE: Separate the kv cache in advance to avoid OOM or other issues - torch_npu._npu_reshape_and_cache( - key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1), - value=prefill_k_pe, - key_cache=kv_cache[0], - value_cache=kv_cache[1], - slot_indices=slots[num_decode_tokens:]) - else: - kv_c_normed = kv_c_normed.view( - [num_actual_toks, self.num_kv_heads, -1]) - torch_npu._npu_reshape_and_cache( - key=kv_c_normed, - value=k_pe, - key_cache=kv_cache[0], - value_cache=kv_cache[1], - slot_indices=attn_metadata.slot_mapping) - if not self.running_in_graph: - o_proj_input_shape = (num_actual_toks, - self.num_heads * self.v_head_dim) - o_proj_input = torch.empty(o_proj_input_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) + kv_c_normed = kv_c_normed.view( + [num_actual_toks, self.num_kv_heads, -1]) + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=attn_metadata.slot_mapping) + o_proj_input_shape = (num_actual_toks, + self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) if has_prefill: # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy @@ -1274,17 +892,9 @@ class AscendMLAImpl(MLAAttentionImpl): o_proj_input[num_decode_tokens:] = output_prefill if has_decode: - if self.running_in_graph: - return self._forward_decode(decode_ql_nope, decode_q_pe, - decode_k_nope, decode_k_pe, - kv_cache, attn_metadata, - enable_multistream_mla) - else: - output_decode = self._forward_decode(decode_ql_nope, - decode_q_pe, - decode_k_nope, - decode_k_pe, kv_cache, - attn_metadata) + output_decode = self._forward_decode(decode_ql_nope, decode_q_pe, + decode_k_nope, decode_k_pe, + kv_cache, attn_metadata) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is not None: with torch.npu.stream(current_ms_metadata.comm_stream): @@ -1293,23 +903,13 @@ class AscendMLAImpl(MLAAttentionImpl): o_proj_input[:num_decode_tokens] = output_decode current_ms_metadata = get_multistream_comm_context() - MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB if current_ms_metadata is None: - npu_prefetch(self.o_proj.weight, - o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=enable_multistream_mla) - output[...] = self.o_proj( o_proj_input, is_prefill=True, is_force_scatter=self.enable_shared_expert_dp)[0] else: with torch.npu.stream(current_ms_metadata.comm_stream): - npu_prefetch(self.o_proj.weight, - o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=enable_multistream_mla) output[...] = self.o_proj( o_proj_input, is_prefill=True, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 27b922b..6075a79 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -235,12 +235,18 @@ class NPUPlatform(Platform): raise ValueError("vLLM Ascend does not support V0 engine.") use_torchair = get_ascend_config().torchair_graph_config.enabled - if use_mla: - return "vllm_ascend.attention.mla_v1.AscendMLABackend" - elif use_torchair: - return "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend" - else: - return "vllm_ascend.attention.attention_v1.AscendAttentionBackend" + # choose attention backend based on use_mla and use_torchair + backend_map = { + (True, True): + "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend", + (True, False): + "vllm_ascend.attention.mla_v1.AscendMLABackend", + (False, True): + "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend", + (False, False): + "vllm_ascend.attention.attention_v1.AscendAttentionBackend" + } + return backend_map[(use_mla, use_torchair)] @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py new file mode 100644 index 0000000..10718b7 --- /dev/null +++ b/vllm_ascend/torchair/torchair_mla.py @@ -0,0 +1,1319 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn as nn +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.context import get_multistream_comm_context +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, + npu_stream_switch, npu_wait_tensor) +from vllm_ascend.utils import npu_prefetch +from vllm_ascend.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendMLATorchairBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_MLA_TORCHAIR" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendMLATorchairMetadata + + @staticmethod + def get_builder_cls(): + return AscendMLATorchairMetadataBuilder + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["MLAAttentionImpl"]: + return AscendMLATorchairImpl + + +@dataclass +class AscendMLATorchairPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class TorchairChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + + attn_mask: torch.Tensor + query_lens: list[int] + seq_lens: list[int] + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + chunked_context: Optional[TorchairChunkedContextMetadata] = None + sin: torch.Tensor = None + cos: torch.Tensor = None + + +@dataclass +class AscendMLATorchairDecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: Optional[list[int]] = None + attn_mask: Optional[torch.Tensor] = None + sin: torch.Tensor = None + cos: torch.Tensor = None + + +@dataclass +class AscendMLATorchairMetadata: + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendMLATorchairDecodeMetadata] = None + prefill: Optional[AscendMLATorchairPrefillMetadata] = None + enable_dbo_across_dp: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMLATorchairMetadata"]: + """Split metadata for multi-stream with AscendMLATorchairMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLATorchairMetadata, + ) + + +M = TypeVar("M", bound=AscendMLATorchairMetadata) + + +class AscendMLATorchairMetadataBuilder: + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendMLATorchairMetadata] = None): + self.metadata_cls: Optional[AscendMLATorchairMetadata] = metadata_cls \ + if metadata_cls is not None else AscendMLATorchairMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + # For torch air graph mode we treat spec decoding as decode. + if self.torchair_graph_enabled: + if num_tokens - num_spec_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + # For eager mode we treat spec decoding as chunked prefill. + else: + if num_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] + + def build_torchair_graph_dummy( + self, + common_attn_metadata: TorchairCommonAttentionMetadata, + ) -> AscendMLATorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req + seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) + seq_lens_list = [0] * num_reqs + input_positions = torch.zeros(num_tokens, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_tokens, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) + sin = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + cos = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + if self.vllm_config.speculative_config is not None and\ + self.vllm_config.speculative_config.method == 'deepseek_mtp': + attn_state = AscendAttentionState.SpecDecoding + num_decode_tokens = 2 + else: + attn_state = AscendAttentionState.DecodeOnly + num_decode_tokens = 1 + decode_metadata = AscendMLATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=1, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=common_attn_metadata. + actual_seq_lengths_q[:num_reqs], + sin=sin, + cos=cos, + ) + return self.metadata_cls( # type: ignore + num_input_tokens=common_attn_metadata.num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=1, + num_decode_tokens=num_decode_tokens, + num_prefills=0, + attn_mask=common_attn_metadata.attn_mask, + attn_state=attn_state, + prefill=None, + decode=decode_metadata, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_tables=block_table, + ) + + def build( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendMLATorchairMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + decode_threshold = common_attn_metadata.decode_token_per_req + else: + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. + device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping_cpu[: + num_actual_tokens].to( + device, + non_blocking= + True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[tokens_start:].max().item() + max_seq_lens = seq_lens[tokens_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + prefill_metadata = AscendMLATorchairPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens[tokens_start:], + seq_lens=seq_lens, + context_lens=seq_lens[tokens_start:], + input_positions=prefill_input_positions, + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + + decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size + use_torchair_graph = graph_pad_size != -1 + if num_decodes > 0: + actual_seq_lengths_q = query_start_loc[1:].tolist() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decode_tokens] + input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decode_tokens, ...] + if use_torchair_graph and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + num_reqs_pad_size = 0 + num_token_pad_size = 0 + if graph_pad_size != 0: + pad_value = 0 + num_token_pad_size = graph_pad_size - num_decode_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + padded_seq_lens = seq_lens.tolist( + ) + [pad_value] * num_reqs_pad_size + else: + padded_seq_lens = seq_lens.tolist() + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)) + seq_lens_list = padded_seq_lens + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_reqs_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_reqs + num_reqs_pad_size, block_table) + position_padding = torch.zeros(num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + actual_seq_lengths_q = query_start_loc[1:].tolist( + ) + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + else: + seq_lens_list = seq_lens.tolist() + # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) + batch_size = slot_mapping.size(0) + if actual_seq_lengths_q[-1] != batch_size \ + and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + actual_seq_lengths_q[-1] = batch_size + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + decode_metadata = AscendMLATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + + return self.metadata_cls( # type: ignore + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + ) + + +class AscendMLATorchairImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + # Adapt torch air graph mode with spec decoding. + speculative_config = get_current_vllm_config().speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 + + def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + if hasattr(self, "running_in_graph") and not self.running_in_graph: + return x + MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB + npu_prefetch(self.o_proj.weight, + x, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=enable_multistream_mla) + return self.o_proj(x, is_prefill=False)[0] + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + + def _compute_prefill_context( + self, + query: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + rope_dim: int, + attn_metadata: AscendMLATorchairMetadata, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + ): + assert len(kv_c_and_k_pe_cache) > 1 + prefill_metadata = attn_metadata.prefill + if prefill_metadata is None or prefill_metadata.chunked_context is None: + return prefix_output, prefix_lse + + iters = len(prefill_metadata.chunked_context.seq_tot) + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + + seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) + cache_kv_c = kv_c_and_k_pe_cache[0] + cache_k_pe = kv_c_and_k_pe_cache[1] + num_heads = cache_k_pe.size(2) + latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + + seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] + seq_len = torch.stack([seq_len1, seq_len2]) + kv_c_normed = torch.empty(toks, + num_heads, + latent_kv_dim, + dtype=query.dtype, + device=query.device) + k_pe = torch.empty(toks, + num_heads, + rope_dim, + dtype=query.dtype, + device=query.device) + + torch_npu.atb.npu_paged_cache_load( + cache_kv_c, + cache_k_pe, + prefill_metadata.block_table, + seq_len2.to(query.device), + seq_starts=prefill_metadata.chunked_context.starts[i], + key=kv_c_normed, + value=k_pe, + ) + + kv_c_normed = kv_c_normed.squeeze() + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + mask = torch.triu( + torch.ones(512, 512, device=query.device, dtype=query.dtype), + 1) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=mask, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse) + return prefix_output, prefix_lse + + def _forward_prefill( + self, + query: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + attn_metadata: AscendMLATorchairMetadata, + ) -> torch.Tensor: + assert attn_metadata.prefill is not None + assert len(kv_c_and_k_pe_cache) > 1 + + num_tokens = query.size(0) + attn_output = torch.empty(num_tokens, + self.num_heads, + self.v_head_dim, + dtype=query.dtype, + device=query.device) + k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache + ascend_config = get_ascend_config() + + if attn_metadata.attn_state in [ + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit + ] and not ascend_config.chunked_prefill_for_mla: + attn_output_torch = torch.empty(num_tokens, + self.num_heads * self.v_head_dim, + dtype=query.dtype, + device=query.device) + # current requests is chunked in prefill, disable flash attention with chunked prefill + vanilla_chunked_prefill_mla( + output=attn_output_torch, + query=query, + kv_cache=kv_c_and_k_pe_cache, + block_tables=attn_metadata.prefill.block_table, + query_lens=attn_metadata.prefill.query_lens, + context_lens=attn_metadata.prefill.context_lens, + kv_b_proj=self.kv_b_proj, + max_query_len=attn_metadata.prefill.max_query_len, + max_context_len=attn_metadata.prefill.max_seq_lens, + nope_dim=self.qk_nope_head_dim, + rope_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + scale=self.scale, + alibi_slopes=None, + causal=True) + elif attn_metadata.attn_state in [ + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit + ]: + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=query.device) + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + mask = torch.triu( + torch.ones(512, 512, device=query.device, dtype=query.dtype), + 1) # 512: mask only support 512 + if attn_metadata.num_prefills > 1: + mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1, + 1) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=mask, + seqlen=torch.tensor(attn_metadata.prefill.query_lens, + dtype=torch.int32), + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) + + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + key = torch.cat((k_nope, k_pe), dim=-1) + torch_npu._npu_flash_attention( + query=query, + key=key, + value=value, + mask=attn_metadata.attn_mask, + seq_len=attn_metadata.prefill.context_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_heads, + out=attn_output) + attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) + else: + raise RuntimeError( + "Unexpected path reached, AscendMLATorchairImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" + ) + attn_output = attn_output.reshape( + [num_tokens, self.num_heads * self.v_head_dim]) + if attn_metadata.attn_state in [ + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit + ] and not ascend_config.chunked_prefill_for_mla: + attn_output = attn_output_torch + + return attn_output + + def exec_kv( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + + B = hidden_states.shape[0] + N = self.num_kv_heads + S = 1 + kv = self.kv_a_proj_with_mqa(hidden_states)[0] + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + kv, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + ) + return k_pe, k_nope, kv + + def exec_kv_prefill( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + + B = hidden_states.shape[0] + N = self.num_kv_heads + S = 1 + kv = self.kv_a_proj_with_mqa(hidden_states)[0] + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA" + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + is_output_kv=True, + ) + return k_pe, k_nope + + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch_npu.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + attn_metadata: AscendMLATorchairMetadata, + enable_multistream_mla: bool = False, + ) -> torch.Tensor: + decode_meta = attn_metadata.decode + assert decode_meta is not None + num_tokens = q_nope.size(0) + if self.running_in_graph or self.running_chunkprefilll_with_torchair: + # shape of knope/k_pe for npu graph mode should be: + # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] + block_size = kv_c_and_k_pe_cache[0].shape[1] + actual_seq_lengths = None + if self.enable_kv_nz: + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // 16, block_size, 16) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // 16, block_size, 16) + input_layout = "BSND" + else: + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + input_layout = "BNSD" + + if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + assert num_tokens % self.spec_token_num == 0 + input_layout = "TND" + # [bs * q_seq_len, num_heads_per_rank, dim] + q_nope = q_nope.view(num_tokens, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, -1) + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + actual_seq_lengths = decode_meta.actual_seq_lengths_q + else: + if self.enable_kv_nz: + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None + + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=self.scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=decode_meta.block_table, + block_size=block_size, + actual_seq_lengths_kv=decode_meta.seq_lens_list, + actual_seq_lengths=actual_seq_lengths) + else: + # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will + # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become + # public available + assert len(kv_c_and_k_pe_cache) > 1 + if envs_ascend.VLLM_ASCEND_MLA_PA: + attn_output = torch_npu.atb.npu_multi_head_latent_attention( + q_nope, q_pe, kv_c_and_k_pe_cache[0], + kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, self.num_heads, self.scale, + self.num_kv_heads) + else: + q = torch.cat([q_nope, q_pe], dim=-1) + attn_output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device) + k_cache = torch.cat( + [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) + torch_npu._npu_paged_attention_mla( + query=q, + key_cache=k_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.decode. + block_table, # type:ignore + context_lens=attn_metadata.decode.seq_lens, # type:ignore + mla_vheadsize=self.kv_lora_rank, + out=attn_output) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self._v_up_proj_and_o_proj(attn_output, + enable_multistream_mla) + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self._v_up_proj_and_o_proj(attn_output) + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: Tuple[torch.Tensor], + attn_metadata: M, + output: Optional[torch.Tensor] = None, + enable_multistream_mla: bool = False, + ckq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output + self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill + num_actual_toks = attn_metadata.num_actual_tokens + if k_pe is None and not self.running_in_graph: + kv_c, k_pe = self.kv_a_proj_with_mqa( + hidden_states_or_kv_c_normed)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + else: + kv_c_normed = hidden_states_or_kv_c_normed + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + if not self.running_in_graph: + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + if not self.torchair_graph_enabled: + kv_c_normed = kv_c_normed[:num_actual_toks, ...] + prefill_k_c_normed = kv_c_normed[num_decode_tokens:] + if not self.running_in_graph: + hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:] + # if not self.torchair_graph_enabled: + k_pe = k_pe[:num_actual_toks, ...] + k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] + else: + decode_hs_or_q_c = hidden_states_or_q_c + if has_decode: + decode_k_nope = None + assert attn_metadata.decode is not None + if self.running_in_graph or self.running_chunkprefilll_with_torchair: + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + if self.running_chunkprefilll_with_torchair: + decode_hs = ( + hidden_states_or_kv_c_normed[:num_decode_tokens]) + slots = attn_metadata.slot_mapping[:num_decode_tokens] + decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( + decode_hs, cos, sin, kv_cache, slots) + else: + with npu_stream_switch("mla_secondary", + 0, + enabled=enable_multistream_mla): + npu_wait_tensor(hidden_states_or_kv_c_normed, + ckq, + enabled=enable_multistream_mla) + decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( + hidden_states_or_kv_c_normed, cos, sin, kv_cache, + attn_metadata.slot_mapping) + # Without explicitly controlling the order, IndexByTensor operations + # would be placed after `matmul W_KV_T` hindering the overlapping of + # KvRmsNormRopeCache and SingleRope. + npu_wait_tensor(decode_hs_or_q_c, + cos, + enabled=enable_multistream_mla) + npu_wait_tensor(decode_hs_or_q_c, + sin, + enabled=enable_multistream_mla) + npu_wait_tensor(decode_hs_or_q_c, + decode_kv, + enabled=enable_multistream_mla) + + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_hs_or_q_c) + if self.running_in_graph: + with npu_stream_switch("mla_secondary", + 0, + enabled=enable_multistream_mla): + npu_wait_tensor(decode_q_pe, + decode_k_pe, + enabled=enable_multistream_mla) + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + elif self.running_chunkprefilll_with_torchair: + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + else: + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + attn_metadata.decode.input_positions, + decode_q_pe.contiguous(), + decode_k_pe, + max_seq_len=attn_metadata.decode.max_seq_lens) + if has_prefill: + assert attn_metadata.prefill is not None + prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] + if self.torchair_graph_enabled: + num_tokens = prefill_hs_or_q_c.shape[0] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + + prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( + prefill_hs, cos, sin, kv_cache, + attn_metadata.slot_mapping[num_decode_tokens:]) + + kv_c_normed = prefill_k_nope[:num_actual_toks, ...] + prefill_k_c_normed = prefill_k_nope + prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, + -1) + prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) + else: + prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( + attn_metadata.prefill.input_positions, + prefill_q_pe.contiguous(), + prefill_k_pe, + max_seq_len=attn_metadata.prefill.max_seq_lens) + + assert len( + kv_cache + ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" + if self.torchair_graph_enabled: + if kv_cache[0].numel() > 0 and has_prefill: + slots = attn_metadata.slot_mapping + # NOTE: Separate the kv cache in advance to avoid OOM or other issues + torch_npu._npu_reshape_and_cache( + key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1), + value=prefill_k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=slots[num_decode_tokens:]) + else: + kv_c_normed = kv_c_normed.view( + [num_actual_toks, self.num_kv_heads, -1]) + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=attn_metadata.slot_mapping) + if not self.running_in_graph: + o_proj_input_shape = (num_actual_toks, + self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) + if has_prefill: + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + o_proj_input[num_decode_tokens:] = output_prefill + else: + o_proj_input[num_decode_tokens:] = output_prefill + + if has_decode: + if self.running_in_graph: + return self._forward_decode(decode_ql_nope, decode_q_pe, + decode_k_nope, decode_k_pe, + kv_cache, attn_metadata, + enable_multistream_mla) + else: + output_decode = self._forward_decode(decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + o_proj_input[:num_decode_tokens] = output_decode + else: + o_proj_input[:num_decode_tokens] = output_decode + + current_ms_metadata = get_multistream_comm_context() + MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB + if current_ms_metadata is None: + npu_prefetch(self.o_proj.weight, + o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=enable_multistream_mla) + + output[...] = self.o_proj( + o_proj_input, + is_prefill=True, + is_force_scatter=self.enable_shared_expert_dp)[0] + else: + with torch.npu.stream(current_ms_metadata.comm_stream): + npu_prefetch(self.o_proj.weight, + o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=enable_multistream_mla) + output[...] = self.o_proj( + o_proj_input, + is_prefill=True, + is_force_scatter=self.enable_shared_expert_dp)[0] + current_ms_metadata.after_comm_event.record() + del o_proj_input + return output_padded diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b55cc13..d250055 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -92,6 +92,7 @@ from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata +from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, ProfileExecuteDuration, is_310p, maybe_converting_weight_acl_format) @@ -624,7 +625,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self, scheduler_output: "SchedulerOutput", ) -> dict[str, Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata]]: + AscendTorchairMetadata, AscendMLATorchairMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -736,7 +737,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc[num_reqs + 1:].fill_(-1) attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata]] = {} + AscendTorchairMetadata, + AscendMLATorchairMetadata]] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1000,8 +1002,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata], torch.Tensor, np.ndarray, int, + ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata, + AscendMLATorchairMetadata], torch.Tensor, np.ndarray, int, torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -1466,7 +1468,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens: int, hidden_states: torch.Tensor, attn_metadata: Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata], + AscendTorchairMetadata, + AscendMLATorchairMetadata], aux_hidden_states: torch.Tensor = None, ) -> Optional[list[list[int]]]: if not self.use_spec_decode: @@ -2540,7 +2543,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens: int, hidden_states: torch.Tensor, attn_metadata: Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata], + AscendTorchairMetadata, + AscendMLATorchairMetadata], ): assert isinstance(self.drafter, MtpProposer) next_token_ids: list[int] = []