[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Cache cos/sin for MLA in module-level buffers instead of reading them from the model inside the builder (see the helper sketch below).
2. The attention metadata builders now inherit from vLLM's original base classes, so the extra `model` parameter is dropped from `build()` (illustrated at the end of this diff).
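
For orientation, this is roughly what the new helper in vllm_ascend/ops/rotary_embedding.py does. This is a self-contained sketch, not the verbatim implementation: the shapes below are placeholders, and in the real module _cos_cache/_sin_cache are recorded from the DeepSeek rotary embedding while _cos_mla/_sin_mla are preallocated to max_num_batched_tokens.

    import torch

    # Module-level caches; stand-ins for the buffers that set_cos_and_sin()
    # and _record_cos_and_sin_cache() manage in the real module.
    _cos_cache = torch.randn(4096, 64)      # [max_position, rope_dim], placeholder sizes
    _sin_cache = torch.randn(4096, 64)
    _cos_mla = torch.ones(1024, 1, 1, 64)   # [max_num_batched_tokens, 1, 1, rope_dim]
    _sin_mla = torch.zeros(1024, 1, 1, 64)

    def get_cos_and_sin_mla(positions: torch.Tensor, use_cache: bool = False):
        # Gather per-token cos/sin and add the two singleton dims MLA expects.
        cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
        sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
        if not use_cache:
            # Prefill path: return freshly gathered tensors.
            return cos, sin
        # Decode path: copy into graph-stable buffers so full-graph capture
        # always sees the same storage.
        num_tokens = positions.size(0)
        _cos_mla[:num_tokens, ...] = cos
        _sin_mla[:num_tokens, ...] = sin
        return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...]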



version: release/v0.13.0
vLLM main: ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Author: weijinqian0
Date: 2025-12-28 10:35:07 +08:00
Committed by: GitHub
Parent: 24328aaf00
Commit: dbe4c338f2
10 changed files with 167 additions and 224 deletions

View File

@@ -289,6 +289,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.enable_chunked_prefill)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch('vllm.distributed.parallel_state._PCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -296,7 +297,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
     @patch('vllm.distributed.parallel_state._DCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     def test_ascend_mla_metadata_builder_build_full_graph(
-            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
+            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
+            mock_get_cos_and_sin_mla):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -330,7 +332,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                            mock_device)
         common_metadata = MagicMock()
-        model = MagicMock()
         common_metadata.graph_pad_size = 8
         common_metadata.num_reqs = 4
         common_metadata.num_actual_tokens = 5
@@ -343,7 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
         block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
         common_metadata.block_table_tensor = block_table
         common_metadata.prefill_context_parallel_metadata = None
-        metadata = builder.build(0, common_metadata, model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]),
+                                                 torch.Tensor([6, 6]))
+        metadata = builder.build(0, common_metadata)
 
         self.assertEqual(metadata.decode.actual_seq_lengths_q,
                          [1, 2, 4, 5, 6, 6, 7, 8])
@@ -526,6 +529,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.kv_cache_spec.head_size = 128
         self.kv_cache_spec.num_heads = 32
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -534,7 +538,8 @@
     @patch("torch.npu.is_available")
     def test_build_prefix_no_cache_metadata(self, mock_npu_available,
                                             mock_zeros, mock_dcp_world_size,
-                                            mock_get_pcp_group):
+                                            mock_get_pcp_group,
+                                            mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -579,9 +584,9 @@
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -590,6 +595,7 @@
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -598,7 +604,8 @@
     @patch("torch.npu.is_available")
     def test_build_chunked_prefix_metadata(self, mock_npu_available,
                                            mock_zeros, mock_dcp_world_size,
-                                           mock_get_pcp_group):
+                                           mock_get_pcp_group,
+                                           mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -644,9 +651,9 @@
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -655,11 +662,13 @@
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_decode_only_metadata(self, mock_dcp_world_size,
-                                        mock_get_pcp_group):
+                                        mock_get_pcp_group,
+                                        mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -697,9 +706,9 @@
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -708,11 +717,13 @@
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
-                                                 mock_get_pcp_group):
+                                                 mock_get_pcp_group,
+                                                 mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -750,10 +761,10 @@
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
-
-        mock_model = MagicMock()
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
         metadata = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
+            common_attn_metadata, AscendAttentionState.DecodeOnly)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -762,11 +773,13 @@
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
-                                             mock_get_pcp_group):
+                                             mock_get_pcp_group,
+                                             mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
         pcp_group = MagicMock(spec=GroupCoordinator)
@@ -795,13 +808,11 @@
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
-
-        mock_model = MagicMock()
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
         with self.assertRaises(NotImplementedError) as ctx:
             builder.build_for_graph_capture(
-                common_attn_metadata, AscendAttentionState.PrefillNoCache,
-                mock_model)
+                common_attn_metadata, AscendAttentionState.PrefillNoCache)
 
         self.assertIn(
             "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
             str(ctx.exception))

View File

@@ -1,5 +1,5 @@
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import torch
 from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -102,7 +102,8 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert builder.device == device
         assert builder.vllm_config == vllm_config
 
-    def test_ascend_sfa_metadata_builder_build(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -133,21 +134,21 @@
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100
 
-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))
 
         metadata = builder.build(
             common_prefix_len=10,
             common_attn_metadata=common_attn_metadata,
-            model=model,
         )
 
         assert isinstance(metadata, AscendSFAMetadata)
         assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
         assert metadata.slot_mapping.shape == (100, 4, 1024)
 
-    def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build_for_graph_capture(
+            self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -178,14 +179,12 @@
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100
 
-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))
 
         attn_metadata = builder.build_for_graph_capture(
             common_attn_metadata=common_attn_metadata,
             attn_state=AscendAttentionState.DecodeOnly,
-            model=model,
         )
 
         assert isinstance(attn_metadata, AscendSFAMetadata)

View File

@@ -20,7 +20,6 @@ from typing import ClassVar, List, Optional, Tuple
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 import torch_npu
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
@@ -90,7 +89,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: Optional[nn.Module] = None,
+        fast_build: bool = False,
     ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens

View File

@@ -20,7 +20,6 @@ from enum import Enum
 from typing import ClassVar, List, Optional, Tuple, Type
 
 import torch
-import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -29,7 +28,8 @@ from vllm.attention.backends.registry import (AttentionBackendEnum,
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.utils import AttentionCGSupport
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -170,7 +170,7 @@ class AscendMetadata:
     model_runner_type: str = ""
 
 
-class AscendAttentionMetadataBuilder:
+class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.ALWAYS
@@ -217,8 +217,8 @@
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: Optional[nn.Module] = None,
-    ):
+        fast_build: bool = False,
+    ) -> AscendMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -261,7 +261,6 @@
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(

View File

@@ -4,7 +4,6 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch_npu
-from torch import nn
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
@@ -50,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
     understand this class
     """
 
-    def __init__(self,
-                 kv_cache_spec: MLAAttentionSpec,
-                 layer_names: list[str],
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendMLAMetadata] = None):
+    def __init__(
+        self,
+        kv_cache_spec: MLAAttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendMLAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
-                         metadata_cls)
+                         metadata_cls, supports_dcp_with_varlen)
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group(
@@ -92,7 +94,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendPCPMetadata | None:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert common_long_seq_metadata is not None
@@ -121,10 +122,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ):
         chunked_context_metadata = super().build_chunked_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         if chunked_context_metadata is None:
             return None
@@ -205,12 +205,11 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLAPrefillMetadata:
         prefill_metadata = super().build_prefill_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.pcp_metadata = self.build_cp_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.block_table = self.block_table[
             self.num_decodes_flatten:, ...]
         return prefill_metadata
@@ -219,10 +218,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLADecodeMetadata:
         decode_metadata = super().build_decode_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert long_seq_metadata is not None

View File

@@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
 import numpy as np
 import torch
 import torch_npu
-from torch import nn
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -13,6 +12,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.utils.math_utils import cdiv, round_down
+from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import MLAAttentionSpec
@@ -177,7 +177,7 @@
 M = TypeVar("M", bound=AscendMLAMetadata)
 
 
-class AscendMLAMetadataBuilder:
+class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_BATCH
@@ -186,14 +186,17 @@
     understand this class
     """
 
-    def __init__(self,
-                 kv_cache_spec: MLAAttentionSpec,
-                 layer_names: list[str],
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendMLAMetadata] = None):
-        self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
-            if metadata_cls is not None else AscendMLAMetadata  # type: ignore
+    def __init__(
+        self,
+        kv_cache_spec: MLAAttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendMLAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
+        self.metadata_cls = (metadata_cls if metadata_cls is not None else
+                             AscendMLAMetadata)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
@@ -384,7 +387,7 @@
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        fast_build: bool = False,
     ) -> AscendMLAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
@@ -400,17 +403,6 @@
         self.slot_mapping = common_attn_metadata.slot_mapping[:self.
                                                               num_actual_tokens]
 
-        if self.cos_cache is None:
-            self.cos_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.sin_cached
-        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
-            self.cos_cache = self.cos_cache.to(  # type: ignore
-                self.model_config.dtype)  # type: ignore
-            self.sin_cache = self.sin_cache.to(  # type: ignore
-                self.model_config.dtype)  # type: ignore
-
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         self.query_lens = query_seq_lens_cpu[:num_reqs]
         self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
@@ -420,12 +412,12 @@
         prefill_metadata = None
         if self.num_prefills > 0:
             prefill_metadata = self.build_prefill_metadata(
-                common_prefix_len, common_attn_metadata, model)
+                common_prefix_len, common_attn_metadata)
 
         decode_metadata = None
         if self.num_decodes > 0:
             decode_metadata = self.build_decode_metadata(
-                common_prefix_len, common_attn_metadata, model)
+                common_prefix_len, common_attn_metadata)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens_pcp_padded=self.num_actual_tokens,
@@ -450,7 +442,6 @@
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ):
         if not self.chunked_prefill_enabled:
             return None
@@ -520,7 +511,6 @@
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLAPrefillMetadata:
         query_start_loc = common_attn_metadata.query_start_loc
@@ -530,7 +520,7 @@
         )
         chunked_context_metadata = self.build_chunked_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         reqs_start = self.num_decodes  # prefill_start
         tokens_start = self.num_decode_tokens
         max_query_len = self.query_lens[reqs_start:].max().item()
@@ -539,12 +529,7 @@
             reqs_start:] - query_start_loc[reqs_start]
         prefill_input_positions = input_positions[tokens_start:]
-        cos = self.cos_cache[
-            prefill_input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-        sin = self.sin_cache[
-            prefill_input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
+        cos, sin = get_cos_and_sin_mla(prefill_input_positions)
         return AscendMLAPrefillMetadata(
             attn_mask=common_attn_metadata.attn_mask,
             query_lens=self.query_lens[reqs_start:].to(torch.int32),
@@ -564,7 +549,6 @@
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLADecodeMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -573,7 +557,6 @@
                                                               num_actual_tokens].long(
                                                               )
 
-        cos, sin = get_cos_and_sin_mla()
         # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
         actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes +
                                                    1].tolist()
@@ -640,54 +623,25 @@
             num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
             common_attn_metadata)
 
-        # TODO: After the fullgraph supports MTP, the if branch needs to deleted
-        assert self.cos_cache is not None
-        assert self.sin_cache is not None
-        if cos is None and sin is None:
-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=self.block_table,
-                seq_lens=self.seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos,
-                cp_seq_len=cp_seq_len,
-                batch_seq_mask=batch_seq_mask)
-        else:
-            cos[:self.num_decode_tokens,
-                ...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(
-                    2)
-            sin[:self.num_decode_tokens,
-                ...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
-                    2)
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=self.block_table,
-                seq_lens=self.seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin[:self.num_decode_tokens, ...],
-                cos=cos[:self.num_decode_tokens, ...],
-                cp_seq_len=cp_seq_len,
-                batch_seq_mask=batch_seq_mask)
+        cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=self.block_table,
+            seq_lens=self.seq_lens,
+            seq_lens_list=seq_lens_list,
+            max_seq_lens=max_seq_lens,
+            attn_mask=common_attn_metadata.spec_attn_mask,
+            actual_seq_lengths_q=actual_seq_lengths_q,
+            sin=sin[:self.num_decode_tokens, ...],
+            cos=cos[:self.num_decode_tokens, ...],
+            cp_seq_len=cp_seq_len,
+            batch_seq_mask=batch_seq_mask)
 
         return decode_metadata
 
     def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state in {
                 AscendAttentionState.DecodeOnly,
@@ -696,7 +650,6 @@
             attn_metadata = self.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
-                model=model,
             )
         else:
             raise NotImplementedError(

View File

@@ -12,6 +12,7 @@ from vllm.logger import logger
 from vllm.model_executor.layers.linear import (ReplicatedLinear,
                                                UnquantizedLinearMethod)
 from vllm.triton_utils import HAS_TRITON
+from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend import envs
@@ -107,7 +108,7 @@
 M = TypeVar("M", bound=AscendSFAMetadata)
 
 
-class AscendSFAMetadataBuilder:
+class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
@@ -117,14 +118,17 @@
     """
 
     # _attn_mask_builder = None
-    def __init__(self,
-                 kv_cache_spec,
-                 layer_names,
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendSFAMetadata] = None):
-        self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
-            if metadata_cls is not None else AscendSFAMetadata  # type: ignore
+    def __init__(
+        self,
+        kv_cache_spec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendSFAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
+        self.metadata_cls = (metadata_cls if metadata_cls is not None else
+                             AscendSFAMetadata)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
@@ -142,9 +146,6 @@
                                                got {self.decode_threshold}"
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
 
-        self.cos_cache = None
-        self.sin_cache = None
-
         self.enable_sfa_cp = enable_sp() and \
             hasattr(self.model_config.hf_config, "index_topk")
@@ -163,7 +164,7 @@
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        fast_build: bool = False,
     ) -> AscendSFAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -178,33 +179,12 @@
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         has_prefill = any(query_lens_cpu > self.decode_threshold)
 
-        if self.cos_cache is None:
-            self.cos_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.sin_cached
-        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
-            self.cos_cache = self.cos_cache.to(  # type: ignore
-                self.model_config.dtype)  # type: ignore
-            self.sin_cache = self.sin_cache.to(  # type: ignore
-                self.model_config.dtype)  # type: ignore
-
         cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
 
-        cos, sin = get_cos_and_sin_mla()
-        assert self.cos_cache is not None and self.sin_cache is not None
-        new_cos = self.cos_cache[input_positions][:, None, None]
-        new_sin = self.sin_cache[input_positions][:, None, None]
-        if (cos is not None and sin is not None
-                and num_input_tokens <= cos.shape[0]
-                and num_input_tokens <= sin.shape[0]):
-            cos[:num_input_tokens] = new_cos
-            sin[:num_input_tokens] = new_sin
+        if has_prefill:
+            cos, sin = get_cos_and_sin_mla(input_positions)
         else:
-            cos, sin = new_cos, new_sin
+            cos, sin = get_cos_and_sin_mla(input_positions, True)
 
         sfa_cp_context = None
         if self.enable_sfa_cp:
@@ -299,7 +279,6 @@
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state in {
                 AscendAttentionState.DecodeOnly,
@@ -308,7 +287,6 @@
             attn_metadata = self.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
-                model=model,
             )
         else:
             raise NotImplementedError(

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple
 import einops
 import torch
 import torch_npu
-from vllm.config import CUDAGraphMode
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
@@ -40,13 +39,15 @@ from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
 # AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
 # attn_metadata. This causes that rope in GQA models must pass cos && sin
 # by different approaches.
-_cos_mla: Optional[torch.Tensor] = None
-_sin_mla: Optional[torch.Tensor] = None
-_cos_sin_cache: Optional[torch.Tensor] = None
-_cos: Optional[torch.Tensor] = None
-_sin: Optional[torch.Tensor] = None
-_cos_slice: Optional[torch.Tensor] = None
-_sin_slice: Optional[torch.Tensor] = None
+_cos_mla: torch.Tensor = None
+_sin_mla: torch.Tensor = None
+_cos_cache: torch.Tensor = None
+_sin_cache: torch.Tensor = None
+_cos_sin_cache: torch.Tensor = None
+_cos: torch.Tensor = None
+_sin: torch.Tensor = None
+_cos_slice: torch.Tensor = None
+_sin_slice: torch.Tensor = None
 
 def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
@@ -62,25 +63,23 @@
             _sin is not None:
         return
 
-    compilation_config = vllm_config.compilation_config
     model_config = vllm_config.model_config
     max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
     if model_config.use_mla:
-        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-            rope_dim = model_config.hf_text_config.qk_rope_head_dim
-            _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
-                                  1,
-                                  1,
-                                  rope_dim,
-                                  dtype=dtype,
-                                  device=device)
-            _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
-                                   1,
-                                   1,
-                                   rope_dim,
-                                   dtype=dtype,
-                                   device=device)
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+        _cos_mla = torch.ones(max_num_batched_tokens,
+                              1,
+                              1,
+                              rope_dim,
+                              dtype=dtype,
+                              device=device)
+        _sin_mla = torch.zeros(max_num_batched_tokens,
+                               1,
+                               1,
+                               rope_dim,
+                               dtype=dtype,
+                               device=device)
     elif not is_vl_model(vllm_config) and has_rope(vllm_config):
         rope_dim = model_config.get_head_size()
         # For models using partial rope like Qwen3-Next.
@@ -101,8 +100,19 @@
                           device=device)
 
 
-def get_cos_and_sin_mla():
-    return _cos_mla, _sin_mla
+def get_cos_and_sin_mla(positions, use_cache=False):
+    global _cos_cache
+    global _sin_cache
+    cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
+    sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
+    if not use_cache:
+        return cos, sin
+    global _cos_mla
+    global _sin_mla
+    num_tokens = positions.size(0)
+    _cos_mla[:num_tokens, ...] = cos
+    _sin_mla[:num_tokens, ...] = sin
+    return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...]
 
 
 def _record_cos_sin_cache(cos_sin_cache):
@@ -112,6 +122,13 @@ def _record_cos_sin_cache(cos_sin_cache):
     _cos_sin_cache = cos_sin_cache
 
 
+def _record_cos_and_sin_cache(cos_cache, sin_cache):
+    global _cos_cache
+    global _sin_cache
+    _cos_cache = cos_cache
+    _sin_cache = sin_cache
+
+
 def update_cos_sin(positions):
     global _cos
     global _sin
@@ -469,6 +486,8 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         self.register_buffer("cos_sin_cache", cache, persistent=False)
         self.register_buffer("cos_cached", cos_cached, persistent=False)
         self.register_buffer("sin_cached", sin_cached, persistent=False)
+        _record_cos_sin_cache(cache)
+        _record_cos_and_sin_cache(cos_cached, sin_cached)
 
     def forward(self,
                 positions: torch.Tensor,

View File

@@ -34,6 +34,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                                shared_expert_dp_enabled)
@@ -279,8 +280,7 @@
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.SpecDecoding,
-            self.runner.get_model())
+            common_attn_metadata, AscendAttentionState.SpecDecoding)
 
         attn_metadata = {}
         for layer_name in self.attn_layer_name:
             attn_metadata[layer_name] = attn_metadata_mtp
@@ -945,10 +945,8 @@
                 graph_pad_size - batch_size,
                 batch_size,
                 decode_metadata.actual_seq_lengths_q)
-            decode_metadata.cos = builder.cos_cache[
-                positions[:batch_size]].unsqueeze(1).unsqueeze(2)
-            decode_metadata.sin = builder.sin_cache[
-                positions[:batch_size]].unsqueeze(1).unsqueeze(2)
+            decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(
+                positions[:batch_size])
 
         # NOTE(woosuk): We should handle the case where the draft model
         # generates tokens beyond the max model length. Since it is complex
         # to remove such requests from the batch, we keep them in the batch

View File

@@ -1122,21 +1122,10 @@
                     num_decode_draft_tokens_cpu=self.
                     num_decode_draft_tokens.cpu[:num_reqs],
                 )
 
-                attn_metadata_i = builder.build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    **extra_attn_metadata_args)
-            elif self.model_config.runner_type == "pooling":
-                attn_metadata_i = builder.build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    **extra_attn_metadata_args)
-            else:
-                attn_metadata_i = builder.build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    model=self.get_model(),
-                    **extra_attn_metadata_args)
+            attn_metadata_i = builder.build(
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata,
+                **extra_attn_metadata_args)
 
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
@@ -1918,7 +1907,7 @@
                     common_metadata)
             else:
                 attn_metadata_full_attention = builder.build_for_graph_capture(
-                    common_attn_metadata, attn_state, self.get_model())
+                    common_attn_metadata, attn_state)
             for layer_name in kv_cache_group_spec.layer_names:
                 if "linear_attn" in layer_name:
                     attn_metadata[
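
As a closing illustration of change (2), the builders now plug into vLLM's generic builder hierarchy and drop the model parameter. Below is a minimal sketch with stub types; only the base-class name, the constructor signature, and the fast_build parameter come from the diff above, everything else is simplified for illustration.

    from dataclasses import dataclass
    from typing import Generic, Optional, TypeVar

    M = TypeVar("M")

    class MLACommonMetadataBuilder(Generic[M]):
        """Stand-in for the base class in vllm.v1.attention.backends.mla.common."""

    @dataclass
    class AscendMLAMetadata:
        num_actual_tokens: int = 0

    class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
        def __init__(self, kv_cache_spec, layer_names, vllm_config, device,
                     metadata_cls: Optional[type] = None,
                     supports_dcp_with_varlen: bool = False):
            # metadata_cls defaults to AscendMLAMetadata, as in the diff.
            self.metadata_cls = (metadata_cls if metadata_cls is not None
                                 else AscendMLAMetadata)

        def build(self, common_prefix_len, common_attn_metadata,
                  fast_build: bool = False) -> AscendMLAMetadata:
            # cos/sin now come from get_cos_and_sin_mla(), so no model handle
            # is threaded through and the signature matches the base class.
            return self.metadata_cls()

    builder = AscendMLAMetadataBuilder(None, ["layer_0"], None, "npu")
    metadata = builder.build(0, None)   # no model argument anymore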