diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 06c5dc6d..88d5071d 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -289,6 +289,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.enable_chunked_prefill)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch('vllm.distributed.parallel_state._PCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -296,7 +297,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
     @patch('vllm.distributed.parallel_state._DCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     def test_ascend_mla_metadata_builder_build_full_graph(
-            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
+            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
+            mock_get_cos_and_sin_mla):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -330,7 +332,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                            mock_device)
         common_metadata = MagicMock()
-        model = MagicMock()
         common_metadata.graph_pad_size = 8
         common_metadata.num_reqs = 4
         common_metadata.num_actual_tokens = 5
@@ -343,7 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
         block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
         common_metadata.block_table_tensor = block_table
         common_metadata.prefill_context_parallel_metadata = None
-        metadata = builder.build(0, common_metadata, model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]),
+                                                 torch.Tensor([6, 6]))
+        metadata = builder.build(0, common_metadata)
 
         self.assertEqual(metadata.decode.actual_seq_lengths_q,
                          [1, 2, 4, 5, 6, 6, 7, 8])
@@ -526,6 +529,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.kv_cache_spec.head_size = 128
         self.kv_cache_spec.num_heads = 32
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -534,7 +538,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
     @patch("torch.npu.is_available")
     def test_build_prefix_no_cache_metadata(self, mock_npu_available,
                                             mock_zeros, mock_dcp_world_size,
-                                            mock_get_pcp_group):
+                                            mock_get_pcp_group,
+                                            mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -579,9 +584,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                           layer_names=["layer_0", "layer_1"],
                                           vllm_config=self.mock_vllm_config,
                                           device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -590,6 +595,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -598,7 +604,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
     @patch("torch.npu.is_available")
     def test_build_chunked_prefix_metadata(self, mock_npu_available,
                                            mock_zeros, mock_dcp_world_size,
-                                           mock_get_pcp_group):
+                                           mock_get_pcp_group,
+                                           mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -644,9 +651,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                           layer_names=["layer_0", "layer_1"],
                                           vllm_config=self.mock_vllm_config,
                                           device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -655,11 +662,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_decode_only_metadata(self, mock_dcp_world_size,
-                                        mock_get_pcp_group):
+                                        mock_get_pcp_group,
+                                        mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
 
@@ -697,9 +706,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                           layer_names=["layer_0", "layer_1"],
                                           vllm_config=self.mock_vllm_config,
                                           device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -708,11 +717,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
-                                                 mock_get_pcp_group):
+                                                 mock_get_pcp_group,
+                                                 mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
 
@@ -750,10 +761,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                           layer_names=["layer_0", "layer_1"],
                                           vllm_config=self.mock_vllm_config,
                                           device=self.mock_device)
-
-        mock_model = MagicMock()
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
         metadata = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
+            common_attn_metadata, AscendAttentionState.DecodeOnly)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -762,11 +773,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
-                                             mock_get_pcp_group):
+                                             mock_get_pcp_group,
+                                             mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
         pcp_group = MagicMock(spec=GroupCoordinator)
@@ -795,13 +808,11 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                           layer_names=["layer_0", "layer_1"],
                                           vllm_config=self.mock_vllm_config,
                                           device=self.mock_device)
-
-        mock_model = MagicMock()
-
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
         with self.assertRaises(NotImplementedError) as ctx:
             builder.build_for_graph_capture(
-                common_attn_metadata, AscendAttentionState.PrefillNoCache,
-                mock_model)
+                common_attn_metadata, AscendAttentionState.PrefillNoCache)
         self.assertIn(
             "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
             str(ctx.exception))
diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
index caa8cec6..dd4c2f5e 100644
--- a/tests/ut/attention/test_sfa_v1.py
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -1,5 +1,5 @@
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import torch
 from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -102,7 +102,8 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert builder.device == device
         assert builder.vllm_config == vllm_config
 
-    def test_ascend_sfa_metadata_builder_build(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -133,21 +134,21 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100
 
-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))
 
         metadata = builder.build(
             common_prefix_len=10,
             common_attn_metadata=common_attn_metadata,
-            model=model,
         )
 
         assert isinstance(metadata, AscendSFAMetadata)
         assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
         assert metadata.slot_mapping.shape == (100, 4, 1024)
 
-    def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build_for_graph_capture(
+            self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -178,14 +179,12 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100
 
-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))
 
         attn_metadata = builder.build_for_graph_capture(
             common_attn_metadata=common_attn_metadata,
             attn_state=AscendAttentionState.DecodeOnly,
-            model=model,
         )
 
         assert isinstance(attn_metadata, AscendSFAMetadata)
diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py
index 9a426ec2..d161c20d 100644
--- a/vllm_ascend/attention/attention_cp.py
+++ b/vllm_ascend/attention/attention_cp.py
@@ -20,7 +20,6 @@ from typing import ClassVar, List, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 import torch_npu
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
@@ -90,7 +89,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: Optional[nn.Module] = None,
+        fast_build: bool = False,
     ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 0aef154a..8054afbe 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -20,7 +20,6 @@ from enum import Enum
 from typing import ClassVar, List, Optional, Tuple, Type
 
 import torch
-import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -29,7 +28,8 @@ from vllm.attention.backends.registry import (AttentionBackendEnum,
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.utils import AttentionCGSupport
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -170,7 +170,7 @@ class AscendMetadata:
     model_runner_type: str = ""
 
 
-class AscendAttentionMetadataBuilder:
+class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.ALWAYS
@@ -217,8 +217,8 @@ class AscendAttentionMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: Optional[nn.Module] = None,
-    ):
+        fast_build: bool = False,
+    ) -> AscendMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -261,7 +261,6 @@ class AscendAttentionMetadataBuilder:
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(
diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py
index 0a3aed14..4ce90cb1 100644
--- a/vllm_ascend/attention/mla_cp.py
+++ b/vllm_ascend/attention/mla_cp.py
@@ -4,7 +4,6 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch_npu
-from torch import nn
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
@@ -50,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
     understand this class
     """
 
-    def __init__(self,
-                 kv_cache_spec: MLAAttentionSpec,
-                 layer_names: list[str],
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendMLAMetadata] = None):
+    def __init__(
+        self,
+        kv_cache_spec: MLAAttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendMLAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
-                         metadata_cls)
+                         metadata_cls, supports_dcp_with_varlen)
 
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group(
@@ -92,7 +94,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendPCPMetadata | None:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert common_long_seq_metadata is not None
@@ -121,10 +122,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ):
         chunked_context_metadata = super().build_chunked_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
 
         if chunked_context_metadata is None:
             return None
@@ -205,12 +205,11 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLAPrefillMetadata:
         prefill_metadata = super().build_prefill_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.pcp_metadata = self.build_cp_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.block_table = self.block_table[
             self.num_decodes_flatten:, ...]
         return prefill_metadata
@@ -219,10 +218,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLADecodeMetadata:
         decode_metadata = super().build_decode_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
 
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert long_seq_metadata is not None
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index e529f0bf..6454a294 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
 
 import numpy as np
 import torch
 import torch_npu
-from torch import nn
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -13,6 +12,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.utils.math_utils import cdiv, round_down
+from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import MLAAttentionSpec
@@ -177,7 +177,7 @@ class AscendMLAMetadata:
 
 M = TypeVar("M", bound=AscendMLAMetadata)
 
 
-class AscendMLAMetadataBuilder:
+class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_BATCH
@@ -186,14 +186,17 @@ class AscendMLAMetadataBuilder:
     understand this class
     """
 
-    def __init__(self,
-                 kv_cache_spec: MLAAttentionSpec,
-                 layer_names: list[str],
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendMLAMetadata] = None):
-        self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
-            if metadata_cls is not None else AscendMLAMetadata  # type: ignore
+    def __init__(
+        self,
+        kv_cache_spec: MLAAttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendMLAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
+        self.metadata_cls = (metadata_cls if metadata_cls is not None else
+                             AscendMLAMetadata)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
@@ -384,7 +387,7 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        fast_build: bool = False,
     ) -> AscendMLAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
@@ -400,17 +403,6 @@ class AscendMLAMetadataBuilder:
         self.slot_mapping = common_attn_metadata.slot_mapping[:self.
                                                               num_actual_tokens]
 
-        if self.cos_cache is None:
-            self.cos_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.sin_cached
-            if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
-                self.cos_cache = self.cos_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-                self.sin_cache = self.sin_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         self.query_lens = query_seq_lens_cpu[:num_reqs]
         self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
@@ -420,12 +412,12 @@ class AscendMLAMetadataBuilder:
         prefill_metadata = None
         if self.num_prefills > 0:
             prefill_metadata = self.build_prefill_metadata(
-                common_prefix_len, common_attn_metadata, model)
+                common_prefix_len, common_attn_metadata)
 
         decode_metadata = None
         if self.num_decodes > 0:
             decode_metadata = self.build_decode_metadata(
-                common_prefix_len, common_attn_metadata, model)
+                common_prefix_len, common_attn_metadata)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens_pcp_padded=self.num_actual_tokens,
@@ -450,7 +442,6 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ):
         if not self.chunked_prefill_enabled:
             return None
@@ -520,7 +511,6 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLAPrefillMetadata:
 
         query_start_loc = common_attn_metadata.query_start_loc
@@ -530,7 +520,7 @@ class AscendMLAMetadataBuilder:
         )
 
         chunked_context_metadata = self.build_chunked_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         reqs_start = self.num_decodes  # prefill_start
         tokens_start = self.num_decode_tokens
         max_query_len = self.query_lens[reqs_start:].max().item()
@@ -539,12 +529,7 @@ class AscendMLAMetadataBuilder:
             reqs_start:] - query_start_loc[reqs_start]
 
         prefill_input_positions = input_positions[tokens_start:]
-        cos = self.cos_cache[
-            prefill_input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-        sin = self.sin_cache[
-            prefill_input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
+        cos, sin = get_cos_and_sin_mla(prefill_input_positions)
         return AscendMLAPrefillMetadata(
             attn_mask=common_attn_metadata.attn_mask,
             query_lens=self.query_lens[reqs_start:].to(torch.int32),
@@ -564,7 +549,6 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLADecodeMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -573,7 +557,6 @@ class AscendMLAMetadataBuilder:
                                                           num_actual_tokens].long(
                                                           )
 
-        cos, sin = get_cos_and_sin_mla()
         # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
         actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes +
                                                    1].tolist()
@@ -640,54 +623,25 @@ class AscendMLAMetadataBuilder:
                 num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
                 common_attn_metadata)
 
-        # TODO: After the fullgraph supports MTP, the if branch needs to deleted
-        assert self.cos_cache is not None
-        assert self.sin_cache is not None
-        if cos is None and sin is None:
-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=self.block_table,
-                seq_lens=self.seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos,
-                cp_seq_len=cp_seq_len,
-                batch_seq_mask=batch_seq_mask)
-        else:
-            cos[:self.num_decode_tokens,
-                ...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(
-                    2)
-            sin[:self.num_decode_tokens,
-                ...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
-                    2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=self.block_table,
-                seq_lens=self.seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin[:self.num_decode_tokens, ...],
-                cos=cos[:self.num_decode_tokens, ...],
-                cp_seq_len=cp_seq_len,
-                batch_seq_mask=batch_seq_mask)
+        cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=self.block_table,
+            seq_lens=self.seq_lens,
+            seq_lens_list=seq_lens_list,
+            max_seq_lens=max_seq_lens,
+            attn_mask=common_attn_metadata.spec_attn_mask,
+            actual_seq_lengths_q=actual_seq_lengths_q,
+            sin=sin[:self.num_decode_tokens, ...],
+            cos=cos[:self.num_decode_tokens, ...],
+            cp_seq_len=cp_seq_len,
+            batch_seq_mask=batch_seq_mask)
 
         return decode_metadata
 
     def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state in {
                 AscendAttentionState.DecodeOnly,
@@ -696,7 +650,6 @@ class AscendMLAMetadataBuilder:
             attn_metadata = self.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
-                model=model,
             )
         else:
             raise NotImplementedError(
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 48aac26c..8c3a2226 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -12,6 +12,7 @@ from vllm.logger import logger
 from vllm.model_executor.layers.linear import (ReplicatedLinear,
                                                UnquantizedLinearMethod)
 from vllm.triton_utils import HAS_TRITON
+from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend import envs
@@ -107,7 +108,7 @@ class AscendSFAMetadata:
 
 M = TypeVar("M", bound=AscendSFAMetadata)
 
 
-class AscendSFAMetadataBuilder:
+class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
@@ -117,14 +118,17 @@ class AscendSFAMetadataBuilder:
     """
     # _attn_mask_builder = None
 
-    def __init__(self,
-                 kv_cache_spec,
-                 layer_names,
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendSFAMetadata] = None):
-        self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
-            if metadata_cls is not None else AscendSFAMetadata  # type: ignore
+    def __init__(
+        self,
+        kv_cache_spec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendSFAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
+        self.metadata_cls = (metadata_cls if metadata_cls is not None else
+                             AscendSFAMetadata)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
@@ -142,9 +146,6 @@ class AscendSFAMetadataBuilder:
             got {self.decode_threshold}"
 
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-        self.cos_cache = None
-        self.sin_cache = None
-
         self.enable_sfa_cp = enable_sp() and \
             hasattr(self.model_config.hf_config, "index_topk")
 
@@ -163,7 +164,7 @@ class AscendSFAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        fast_build: bool = False,
     ) -> AscendSFAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -178,33 +179,12 @@ class AscendSFAMetadataBuilder:
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         has_prefill = any(query_lens_cpu > self.decode_threshold)
 
-        if self.cos_cache is None:
-            self.cos_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.sin_cached
-            if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
-                self.cos_cache = self.cos_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-                self.sin_cache = self.sin_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-
         cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
-
-        cos, sin = get_cos_and_sin_mla()
-
-        assert self.cos_cache is not None and self.sin_cache is not None
-        new_cos = self.cos_cache[input_positions][:, None, None]
-        new_sin = self.sin_cache[input_positions][:, None, None]
-
-        if (cos is not None and sin is not None
-                and num_input_tokens <= cos.shape[0]
-                and num_input_tokens <= sin.shape[0]):
-            cos[:num_input_tokens] = new_cos
-            sin[:num_input_tokens] = new_sin
+        if has_prefill:
+            cos, sin = get_cos_and_sin_mla(input_positions)
         else:
-            cos, sin = new_cos, new_sin
+            cos, sin = get_cos_and_sin_mla(input_positions, True)
 
         sfa_cp_context = None
         if self.enable_sfa_cp:
@@ -299,7 +279,6 @@ class AscendSFAMetadataBuilder:
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state in {
                 AscendAttentionState.DecodeOnly,
@@ -308,7 +287,6 @@ class AscendSFAMetadataBuilder:
             attn_metadata = self.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
-                model=model,
             )
         else:
             raise NotImplementedError(
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 12995575..63aa3e28 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple
 
 import einops
 import torch
 import torch_npu
-from vllm.config import CUDAGraphMode
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
@@ -40,13 +39,15 @@ from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
 # AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
 # attn_metadata. This causes that rope in GQA models must pass cos && sin
 # by different approaches.
-_cos_mla: Optional[torch.Tensor] = None
-_sin_mla: Optional[torch.Tensor] = None
-_cos_sin_cache: Optional[torch.Tensor] = None
-_cos: Optional[torch.Tensor] = None
-_sin: Optional[torch.Tensor] = None
-_cos_slice: Optional[torch.Tensor] = None
-_sin_slice: Optional[torch.Tensor] = None
+_cos_mla: torch.Tensor = None
+_sin_mla: torch.Tensor = None
+_cos_cache: torch.Tensor = None
+_sin_cache: torch.Tensor = None
+_cos_sin_cache: torch.Tensor = None
+_cos: torch.Tensor = None
+_sin: torch.Tensor = None
+_cos_slice: torch.Tensor = None
+_sin_slice: torch.Tensor = None
 
 
 def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
@@ -62,25 +63,23 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
             _sin is not None:
         return
 
-    compilation_config = vllm_config.compilation_config
     model_config = vllm_config.model_config
     max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
     if model_config.use_mla:
-        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-            rope_dim = model_config.hf_text_config.qk_rope_head_dim
-            _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
-                                  1,
-                                  1,
-                                  rope_dim,
-                                  dtype=dtype,
-                                  device=device)
-            _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
-                                   1,
-                                   1,
-                                   rope_dim,
-                                   dtype=dtype,
-                                   device=device)
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+        _cos_mla = torch.ones(max_num_batched_tokens,
+                              1,
+                              1,
+                              rope_dim,
+                              dtype=dtype,
+                              device=device)
+        _sin_mla = torch.zeros(max_num_batched_tokens,
+                               1,
+                               1,
+                               rope_dim,
+                               dtype=dtype,
+                               device=device)
     elif not is_vl_model(vllm_config) and has_rope(vllm_config):
         rope_dim = model_config.get_head_size()
         # For models using partial rope like Qwen3-Next.
@@ -101,8 +100,19 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
                        device=device)
 
 
-def get_cos_and_sin_mla():
-    return _cos_mla, _sin_mla
+def get_cos_and_sin_mla(positions, use_cache=False):
+    global _cos_cache
+    global _sin_cache
+    cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
+    sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
+    if not use_cache:
+        return cos, sin
+    global _cos_mla
+    global _sin_mla
+    num_tokens = positions.size(0)
+    _cos_mla[:num_tokens, ...] = cos
+    _sin_mla[:num_tokens, ...] = sin
+    return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...]
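
For reference, the behaviour the reworked helper introduces can be summarised by the following standalone sketch. Shapes and names are illustrative stand-ins: cos_cached/sin_cached play the role of the tables recorded via _record_cos_and_sin_cache(), and cos_mla/sin_mla the persistent buffers allocated by set_cos_and_sin(); this is not code from the patch itself.

import torch

# Stand-ins for the module-level tables and buffers (assumed sizes).
max_pos, rope_dim, max_tokens = 1024, 64, 256
cos_cached = torch.randn(max_pos, rope_dim)
sin_cached = torch.randn(max_pos, rope_dim)
cos_mla = torch.ones(max_tokens, 1, 1, rope_dim)
sin_mla = torch.zeros(max_tokens, 1, 1, rope_dim)


def get_cos_and_sin_mla_sketch(positions, use_cache=False):
    # Gather per-token values, shaped [num_tokens, 1, 1, rope_dim].
    cos = cos_cached[positions].unsqueeze(1).unsqueeze(2)
    sin = sin_cached[positions].unsqueeze(1).unsqueeze(2)
    if not use_cache:
        # Prefill path: return freshly gathered tensors.
        return cos, sin
    # Decode path: copy into the persistent buffers and return views of them,
    # so the tensors placed in the attention metadata keep a stable storage.
    num_tokens = positions.size(0)
    cos_mla[:num_tokens, ...] = cos
    sin_mla[:num_tokens, ...] = sin
    return cos_mla[:num_tokens, ...], sin_mla[:num_tokens, ...]


positions = torch.tensor([3, 7, 11])
cos_p, sin_p = get_cos_and_sin_mla_sketch(positions)                  # prefill
cos_d, sin_d = get_cos_and_sin_mla_sketch(positions, use_cache=True)  # decode
assert cos_p.shape == cos_d.shape == (3, 1, 1, rope_dim)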
 
 
 def _record_cos_sin_cache(cos_sin_cache):
@@ -112,6 +122,13 @@ def _record_cos_sin_cache(cos_sin_cache):
     _cos_sin_cache = cos_sin_cache
 
 
+def _record_cos_and_sin_cache(cos_cache, sin_cache):
+    global _cos_cache
+    global _sin_cache
+    _cos_cache = cos_cache
+    _sin_cache = sin_cache
+
+
 def update_cos_sin(positions):
     global _cos
     global _sin
@@ -469,6 +486,8 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         self.register_buffer("cos_sin_cache", cache, persistent=False)
         self.register_buffer("cos_cached", cos_cached, persistent=False)
         self.register_buffer("sin_cached", sin_cached, persistent=False)
+        _record_cos_sin_cache(cache)
+        _record_cos_and_sin_cache(cos_cached, sin_cached)
 
     def forward(self,
                 positions: torch.Tensor,
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 66dd65bd..d577304e 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -34,6 +34,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                                shared_expert_dp_enabled)
@@ -279,8 +280,7 @@ class MtpProposer(Proposer):
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.SpecDecoding,
-            self.runner.get_model())
+            common_attn_metadata, AscendAttentionState.SpecDecoding)
 
         attn_metadata = {}
         for layer_name in self.attn_layer_name:
             attn_metadata[layer_name] = attn_metadata_mtp
@@ -945,10 +945,8 @@ class MtpProposer(Proposer):
                     graph_pad_size - batch_size, batch_size,
                     decode_metadata.actual_seq_lengths_q)
 
-                decode_metadata.cos = builder.cos_cache[
-                    positions[:batch_size]].unsqueeze(1).unsqueeze(2)
-                decode_metadata.sin = builder.sin_cache[
-                    positions[:batch_size]].unsqueeze(1).unsqueeze(2)
+                decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(
+                    positions[:batch_size])
 
         # NOTE(woosuk): We should handle the case where the draft model
         # generates tokens beyond the max model length. Since it is complex
         # to remove such requests from the batch, we keep them in the batch
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f0bf0ce4..053d4593 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1122,21 +1122,10 @@ class NPUModelRunner(GPUModelRunner):
                     num_decode_draft_tokens_cpu=self.
                     num_decode_draft_tokens.cpu[:num_reqs],
                 )
 
-                attn_metadata_i = builder.build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    **extra_attn_metadata_args)
-            elif self.model_config.runner_type == "pooling":
-                attn_metadata_i = builder.build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    **extra_attn_metadata_args)
-            else:
-                attn_metadata_i = builder.build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    model=self.get_model(),
-                    **extra_attn_metadata_args)
+            attn_metadata_i = builder.build(
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata,
+                **extra_attn_metadata_args)
 
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
@@ -1918,7 +1907,7 @@ class NPUModelRunner(GPUModelRunner):
                                                             common_metadata)
             else:
                 attn_metadata_full_attention = builder.build_for_graph_capture(
-                    common_attn_metadata, attn_state, self.get_model())
+                    common_attn_metadata, attn_state)
             for layer_name in kv_cache_group_spec.layer_names:
                 if "linear_attn" in layer_name:
                     attn_metadata[
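
Taken together, the builder-facing part of this diff is an interface cleanup: build() and build_for_graph_capture() no longer receive the model, and build() accepts fast_build for parity with vLLM's AttentionMetadataBuilder. A rough Protocol-style sketch of the resulting call surface follows; the Any placeholders and the None default for attn_state are illustrative only (the real default is AscendAttentionState.DecodeOnly), and nothing here is imported from the repository.

from typing import Any, Protocol


class AscendMetadataBuilderLike(Protocol):
    """Illustrative shape of the Ascend metadata-builder API after this patch."""

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: Any,
              fast_build: bool = False) -> Any:
        ...

    def build_for_graph_capture(self,
                                common_attn_metadata: Any,
                                attn_state: Any = None) -> Any:
        ...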