[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Cache cos/sin in MLA.
2. Make the attention metadata builders inherit from the original vLLM base classes.

version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
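
In short: instead of each metadata builder lazily pulling rotary_emb.cos_cached / sin_cached off the model's first layer (which is why build() used to take a model argument), the rope caches are now recorded once when the rotary embedding is constructed and served by a module-level helper. A simplified sketch of the new lookup, mirroring get_cos_and_sin_mla in the rotary_embedding.py hunks below (the dummy cache shapes and sizes here are illustrative only, not the real initialization):

    import torch

    # Stand-ins for the caches recorded at rope init via _record_cos_and_sin_cache.
    _cos_cache = torch.randn(1024, 64)      # [max_position, rope_dim]
    _sin_cache = torch.randn(1024, 64)
    # Stand-ins for the buffers preallocated in set_cos_and_sin.
    _cos_mla = torch.ones(256, 1, 1, 64)    # [max_num_batched_tokens, 1, 1, rope_dim]
    _sin_mla = torch.zeros(256, 1, 1, 64)

    def get_cos_and_sin_mla(positions: torch.Tensor, use_cache: bool = False):
        # Gather per-token cos/sin and add the two broadcast dims MLA expects.
        cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
        sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
        if not use_cache:
            return cos, sin
        # Decode path: copy into the preallocated buffers, which appears to be
        # what gives ACL graph capture stable tensor storage across steps.
        num_tokens = positions.size(0)
        _cos_mla[:num_tokens, ...] = cos
        _sin_mla[:num_tokens, ...] = sin
        return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...]
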
@@ -289,6 +289,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
                          builder.chunked_prefill_enabled,
                          mock_vllm_config.scheduler_config.enable_chunked_prefill)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch('vllm.distributed.parallel_state._PCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -296,7 +297,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
     @patch('vllm.distributed.parallel_state._DCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     def test_ascend_mla_metadata_builder_build_full_graph(
-            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
+            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
+            mock_get_cos_and_sin_mla):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -330,7 +332,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                            mock_device)
         common_metadata = MagicMock()
-        model = MagicMock()
         common_metadata.graph_pad_size = 8
         common_metadata.num_reqs = 4
         common_metadata.num_actual_tokens = 5
@@ -343,7 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
         block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
         common_metadata.block_table_tensor = block_table
         common_metadata.prefill_context_parallel_metadata = None
-        metadata = builder.build(0, common_metadata, model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]),
+                                                 torch.Tensor([6, 6]))
+        metadata = builder.build(0, common_metadata)
 
         self.assertEqual(metadata.decode.actual_seq_lengths_q,
                          [1, 2, 4, 5, 6, 6, 7, 8])
@@ -526,6 +529,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.kv_cache_spec.head_size = 128
         self.kv_cache_spec.num_heads = 32
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -534,7 +538,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
     @patch("torch.npu.is_available")
     def test_build_prefix_no_cache_metadata(self, mock_npu_available,
                                             mock_zeros, mock_dcp_world_size,
                                             mock_get_pcp_group):
-                                            mock_get_pcp_group):
+                                            mock_get_pcp_group,
+                                            mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -579,9 +584,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
 
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -590,6 +595,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -598,7 +604,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
     @patch("torch.npu.is_available")
     def test_build_chunked_prefix_metadata(self, mock_npu_available,
                                            mock_zeros, mock_dcp_world_size,
-                                           mock_get_pcp_group):
+                                           mock_get_pcp_group,
+                                           mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -644,9 +651,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
 
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -655,11 +662,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_decode_only_metadata(self, mock_dcp_world_size,
-                                        mock_get_pcp_group):
+                                        mock_get_pcp_group,
+                                        mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
 
@@ -697,9 +706,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
 
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
+        metadata = builder.build(1, common_attn_metadata)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -708,11 +717,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
-                                                 mock_get_pcp_group):
+                                                 mock_get_pcp_group,
+                                                 mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
 
@@ -750,10 +761,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
 
-        mock_model = MagicMock()
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
         metadata = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
+            common_attn_metadata, AscendAttentionState.DecodeOnly)
 
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -762,11 +773,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
-                                             mock_get_pcp_group):
+                                             mock_get_pcp_group,
+                                             mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
         pcp_group = MagicMock(spec=GroupCoordinator)
@@ -795,13 +808,11 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             layer_names=["layer_0", "layer_1"],
             vllm_config=self.mock_vllm_config,
             device=self.mock_device)
 
-        mock_model = MagicMock()
-
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
         with self.assertRaises(NotImplementedError) as ctx:
             builder.build_for_graph_capture(
-                common_attn_metadata, AscendAttentionState.PrefillNoCache,
-                mock_model)
+                common_attn_metadata, AscendAttentionState.PrefillNoCache)
         self.assertIn(
             "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
             str(ctx.exception))
@@ -1,5 +1,5 @@
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import torch
 from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -102,7 +102,8 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert builder.device == device
         assert builder.vllm_config == vllm_config
 
-    def test_ascend_sfa_metadata_builder_build(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -133,21 +134,21 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100
 
-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))
 
         metadata = builder.build(
             common_prefix_len=10,
             common_attn_metadata=common_attn_metadata,
-            model=model,
         )
 
         assert isinstance(metadata, AscendSFAMetadata)
         assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
         assert metadata.slot_mapping.shape == (100, 4, 1024)
 
-    def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build_for_graph_capture(
+            self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -178,14 +179,12 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100
 
-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))
 
         attn_metadata = builder.build_for_graph_capture(
             common_attn_metadata=common_attn_metadata,
             attn_state=AscendAttentionState.DecodeOnly,
-            model=model,
         )
 
         assert isinstance(attn_metadata, AscendSFAMetadata)
@@ -20,7 +20,6 @@ from typing import ClassVar, List, Optional, Tuple
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 import torch_npu
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
@@ -90,7 +89,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: Optional[nn.Module] = None,
+        fast_build: bool = False,
     ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -20,7 +20,6 @@ from enum import Enum
 from typing import ClassVar, List, Optional, Tuple, Type
 
 import torch
-import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -29,7 +28,8 @@ from vllm.attention.backends.registry import (AttentionBackendEnum,
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.utils import AttentionCGSupport
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
@@ -170,7 +170,7 @@ class AscendMetadata:
     model_runner_type: str = ""
 
 
-class AscendAttentionMetadataBuilder:
+class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.ALWAYS
@@ -217,8 +217,8 @@ class AscendAttentionMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: Optional[nn.Module] = None,
-    ):
+        fast_build: bool = False,
+    ) -> AscendMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -261,7 +261,6 @@ class AscendAttentionMetadataBuilder:
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(
@@ -4,7 +4,6 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch_npu
-from torch import nn
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
@@ -50,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
     understand this class
     """
 
-    def __init__(self,
-                 kv_cache_spec: MLAAttentionSpec,
-                 layer_names: list[str],
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendMLAMetadata] = None):
+    def __init__(
+        self,
+        kv_cache_spec: MLAAttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendMLAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
-                         metadata_cls)
+                         metadata_cls, supports_dcp_with_varlen)
 
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group(
@@ -92,7 +94,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendPCPMetadata | None:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert common_long_seq_metadata is not None
@@ -121,10 +122,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ):
         chunked_context_metadata = super().build_chunked_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         if chunked_context_metadata is None:
             return None
 
@@ -205,12 +205,11 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLAPrefillMetadata:
         prefill_metadata = super().build_prefill_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.pcp_metadata = self.build_cp_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.block_table = self.block_table[
             self.num_decodes_flatten:, ...]
         return prefill_metadata
@@ -219,10 +218,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLADecodeMetadata:
         decode_metadata = super().build_decode_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
 
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert long_seq_metadata is not None
@@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
 import numpy as np
 import torch
 import torch_npu
-from torch import nn
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -13,6 +12,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.utils.math_utils import cdiv, round_down
+from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import MLAAttentionSpec
 
@@ -177,7 +177,7 @@ class AscendMLAMetadata:
 M = TypeVar("M", bound=AscendMLAMetadata)
 
 
-class AscendMLAMetadataBuilder:
+class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_BATCH
@@ -186,14 +186,17 @@ class AscendMLAMetadataBuilder:
     understand this class
     """
 
-    def __init__(self,
-                 kv_cache_spec: MLAAttentionSpec,
-                 layer_names: list[str],
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendMLAMetadata] = None):
-        self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
-            if metadata_cls is not None else AscendMLAMetadata  # type: ignore
+    def __init__(
+        self,
+        kv_cache_spec: MLAAttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendMLAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
+        self.metadata_cls = (metadata_cls if metadata_cls is not None else
+                             AscendMLAMetadata)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
@@ -384,7 +387,7 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        fast_build: bool = False,
     ) -> AscendMLAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
@@ -400,17 +403,6 @@ class AscendMLAMetadataBuilder:
         self.slot_mapping = common_attn_metadata.slot_mapping[:self.
                                                               num_actual_tokens]
 
-        if self.cos_cache is None:
-            self.cos_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.sin_cached
-            if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
-                self.cos_cache = self.cos_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-                self.sin_cache = self.sin_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         self.query_lens = query_seq_lens_cpu[:num_reqs]
         self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
@@ -420,12 +412,12 @@ class AscendMLAMetadataBuilder:
         prefill_metadata = None
         if self.num_prefills > 0:
             prefill_metadata = self.build_prefill_metadata(
-                common_prefix_len, common_attn_metadata, model)
+                common_prefix_len, common_attn_metadata)
 
         decode_metadata = None
         if self.num_decodes > 0:
             decode_metadata = self.build_decode_metadata(
-                common_prefix_len, common_attn_metadata, model)
+                common_prefix_len, common_attn_metadata)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens_pcp_padded=self.num_actual_tokens,
@@ -450,7 +442,6 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ):
         if not self.chunked_prefill_enabled:
             return None
@@ -520,7 +511,6 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLAPrefillMetadata:
         query_start_loc = common_attn_metadata.query_start_loc
 
@@ -530,7 +520,7 @@ class AscendMLAMetadataBuilder:
         )
 
         chunked_context_metadata = self.build_chunked_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         reqs_start = self.num_decodes  # prefill_start
         tokens_start = self.num_decode_tokens
         max_query_len = self.query_lens[reqs_start:].max().item()
@@ -539,12 +529,7 @@ class AscendMLAMetadataBuilder:
             reqs_start:] - query_start_loc[reqs_start]
 
         prefill_input_positions = input_positions[tokens_start:]
-        cos = self.cos_cache[
-            prefill_input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-        sin = self.sin_cache[
-            prefill_input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
+        cos, sin = get_cos_and_sin_mla(prefill_input_positions)
         return AscendMLAPrefillMetadata(
             attn_mask=common_attn_metadata.attn_mask,
             query_lens=self.query_lens[reqs_start:].to(torch.int32),
@@ -564,7 +549,6 @@ class AscendMLAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLADecodeMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -573,7 +557,6 @@ class AscendMLAMetadataBuilder:
                                                            num_actual_tokens].long(
                                                            )
 
-        cos, sin = get_cos_and_sin_mla()
         # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
         actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes +
                                                    1].tolist()
@@ -640,54 +623,25 @@ class AscendMLAMetadataBuilder:
             num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
             common_attn_metadata)
 
-        # TODO: After the fullgraph supports MTP, the if branch needs to deleted
-        assert self.cos_cache is not None
-        assert self.sin_cache is not None
-        if cos is None and sin is None:
-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=self.block_table,
-                seq_lens=self.seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos,
-                cp_seq_len=cp_seq_len,
-                batch_seq_mask=batch_seq_mask)
-        else:
-            cos[:self.num_decode_tokens,
-                ...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(
-                    2)
-            sin[:self.num_decode_tokens,
-                ...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
-                    2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=self.block_table,
-                seq_lens=self.seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin[:self.num_decode_tokens, ...],
-                cos=cos[:self.num_decode_tokens, ...],
-                cp_seq_len=cp_seq_len,
-                batch_seq_mask=batch_seq_mask)
+        cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=self.block_table,
+            seq_lens=self.seq_lens,
+            seq_lens_list=seq_lens_list,
+            max_seq_lens=max_seq_lens,
+            attn_mask=common_attn_metadata.spec_attn_mask,
+            actual_seq_lengths_q=actual_seq_lengths_q,
+            sin=sin[:self.num_decode_tokens, ...],
+            cos=cos[:self.num_decode_tokens, ...],
+            cp_seq_len=cp_seq_len,
+            batch_seq_mask=batch_seq_mask)
         return decode_metadata
 
     def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state in {
                 AscendAttentionState.DecodeOnly,
@@ -696,7 +650,6 @@ class AscendMLAMetadataBuilder:
             attn_metadata = self.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
-                model=model,
             )
         else:
             raise NotImplementedError(
@@ -12,6 +12,7 @@ from vllm.logger import logger
 from vllm.model_executor.layers.linear import (ReplicatedLinear,
                                                UnquantizedLinearMethod)
 from vllm.triton_utils import HAS_TRITON
+from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend import envs
@@ -107,7 +108,7 @@ class AscendSFAMetadata:
 M = TypeVar("M", bound=AscendSFAMetadata)
 
 
-class AscendSFAMetadataBuilder:
+class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
@@ -117,14 +118,17 @@ class AscendSFAMetadataBuilder:
     """
 
     # _attn_mask_builder = None
-    def __init__(self,
-                 kv_cache_spec,
-                 layer_names,
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendSFAMetadata] = None):
-        self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
-            if metadata_cls is not None else AscendSFAMetadata  # type: ignore
+    def __init__(
+        self,
+        kv_cache_spec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendSFAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
+        self.metadata_cls = (metadata_cls if metadata_cls is not None else
+                             AscendSFAMetadata)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
@@ -142,9 +146,6 @@ class AscendSFAMetadataBuilder:
                                                     got {self.decode_threshold}"
 
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-        self.cos_cache = None
-        self.sin_cache = None
-
         self.enable_sfa_cp = enable_sp() and \
             hasattr(self.model_config.hf_config, "index_topk")
 
@@ -163,7 +164,7 @@ class AscendSFAMetadataBuilder:
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        fast_build: bool = False,
     ) -> AscendSFAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -178,33 +179,12 @@ class AscendSFAMetadataBuilder:
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         has_prefill = any(query_lens_cpu > self.decode_threshold)
 
-        if self.cos_cache is None:
-            self.cos_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[
-                model.model.start_layer].self_attn.rotary_emb.sin_cached
-            if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
-                self.cos_cache = self.cos_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-                self.sin_cache = self.sin_cache.to(  # type: ignore
-                    self.model_config.dtype)  # type: ignore
-
         cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
-
-        cos, sin = get_cos_and_sin_mla()
-
-        assert self.cos_cache is not None and self.sin_cache is not None
-        new_cos = self.cos_cache[input_positions][:, None, None]
-        new_sin = self.sin_cache[input_positions][:, None, None]
-
-        if (cos is not None and sin is not None
-                and num_input_tokens <= cos.shape[0]
-                and num_input_tokens <= sin.shape[0]):
-            cos[:num_input_tokens] = new_cos
-            sin[:num_input_tokens] = new_sin
+        if has_prefill:
+            cos, sin = get_cos_and_sin_mla(input_positions)
         else:
-            cos, sin = new_cos, new_sin
+            cos, sin = get_cos_and_sin_mla(input_positions, True)
 
         sfa_cp_context = None
         if self.enable_sfa_cp:
@@ -299,7 +279,6 @@ class AscendSFAMetadataBuilder:
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
-        model: Optional[nn.Module] = None,
     ):
         if attn_state in {
                 AscendAttentionState.DecodeOnly,
@@ -308,7 +287,6 @@ class AscendSFAMetadataBuilder:
             attn_metadata = self.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
-                model=model,
             )
         else:
             raise NotImplementedError(
@@ -21,7 +21,6 @@ from typing import Optional, Tuple
 import einops
 import torch
 import torch_npu
-from vllm.config import CUDAGraphMode
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
@@ -40,13 +39,15 @@ from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
 # AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
 # attn_metadata. This causes that rope in GQA models must pass cos && sin
 # by different approaches.
-_cos_mla: Optional[torch.Tensor] = None
-_sin_mla: Optional[torch.Tensor] = None
-_cos_sin_cache: Optional[torch.Tensor] = None
-_cos: Optional[torch.Tensor] = None
-_sin: Optional[torch.Tensor] = None
-_cos_slice: Optional[torch.Tensor] = None
-_sin_slice: Optional[torch.Tensor] = None
+_cos_mla: torch.Tensor = None
+_sin_mla: torch.Tensor = None
+_cos_cache: torch.Tensor = None
+_sin_cache: torch.Tensor = None
+_cos_sin_cache: torch.Tensor = None
+_cos: torch.Tensor = None
+_sin: torch.Tensor = None
+_cos_slice: torch.Tensor = None
+_sin_slice: torch.Tensor = None
 
 
 def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
@@ -62,25 +63,23 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
             _sin is not None:
         return
 
-    compilation_config = vllm_config.compilation_config
     model_config = vllm_config.model_config
     max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
 
    if model_config.use_mla:
-        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-            rope_dim = model_config.hf_text_config.qk_rope_head_dim
-            _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
-                                  1,
-                                  1,
-                                  rope_dim,
-                                  dtype=dtype,
-                                  device=device)
-            _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
-                                   1,
-                                   1,
-                                   rope_dim,
-                                   dtype=dtype,
-                                   device=device)
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+        _cos_mla = torch.ones(max_num_batched_tokens,
+                              1,
+                              1,
+                              rope_dim,
+                              dtype=dtype,
+                              device=device)
+        _sin_mla = torch.zeros(max_num_batched_tokens,
+                               1,
+                               1,
+                               rope_dim,
+                               dtype=dtype,
+                               device=device)
     elif not is_vl_model(vllm_config) and has_rope(vllm_config):
         rope_dim = model_config.get_head_size()
         # For models using partial rope like Qwen3-Next.
@@ -101,8 +100,19 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
                          device=device)
 
 
-def get_cos_and_sin_mla():
-    return _cos_mla, _sin_mla
+def get_cos_and_sin_mla(positions, use_cache=False):
+    global _cos_cache
+    global _sin_cache
+    cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
+    sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
+    if not use_cache:
+        return cos, sin
+    global _cos_mla
+    global _sin_mla
+    num_tokens = positions.size(0)
+    _cos_mla[:num_tokens, ...] = cos
+    _sin_mla[:num_tokens, ...] = sin
+    return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...]
 
 
 def _record_cos_sin_cache(cos_sin_cache):
@@ -112,6 +122,13 @@ def _record_cos_sin_cache(cos_sin_cache):
     _cos_sin_cache = cos_sin_cache
 
 
+def _record_cos_and_sin_cache(cos_cache, sin_cache):
+    global _cos_cache
+    global _sin_cache
+    _cos_cache = cos_cache
+    _sin_cache = sin_cache
+
+
 def update_cos_sin(positions):
     global _cos
     global _sin
@@ -469,6 +486,8 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         self.register_buffer("cos_sin_cache", cache, persistent=False)
         self.register_buffer("cos_cached", cos_cached, persistent=False)
         self.register_buffer("sin_cached", sin_cached, persistent=False)
         _record_cos_sin_cache(cache)
+        _record_cos_and_sin_cache(cos_cached, sin_cached)
 
     def forward(self,
                 positions: torch.Tensor,
@@ -34,6 +34,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                                shared_expert_dp_enabled)
@@ -279,8 +280,7 @@ class MtpProposer(Proposer):
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.SpecDecoding,
-            self.runner.get_model())
+            common_attn_metadata, AscendAttentionState.SpecDecoding)
         attn_metadata = {}
         for layer_name in self.attn_layer_name:
             attn_metadata[layer_name] = attn_metadata_mtp
@@ -945,10 +945,8 @@ class MtpProposer(Proposer):
                                  graph_pad_size - batch_size,
                                  batch_size,
                                  decode_metadata.actual_seq_lengths_q)
-            decode_metadata.cos = builder.cos_cache[
-                positions[:batch_size]].unsqueeze(1).unsqueeze(2)
-            decode_metadata.sin = builder.sin_cache[
-                positions[:batch_size]].unsqueeze(1).unsqueeze(2)
+            decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(
+                positions[:batch_size])
         # NOTE(woosuk): We should handle the case where the draft model
         # generates tokens beyond the max model length. Since it is complex
         # to remove such requests from the batch, we keep them in the batch
@@ -1122,21 +1122,10 @@ class NPUModelRunner(GPUModelRunner):
                 num_decode_draft_tokens_cpu=self.
                 num_decode_draft_tokens.cpu[:num_reqs],
             )
-            attn_metadata_i = builder.build(
-                common_prefix_len=common_prefix_len,
-                common_attn_metadata=common_attn_metadata,
-                **extra_attn_metadata_args)
-        elif self.model_config.runner_type == "pooling":
-            attn_metadata_i = builder.build(
-                common_prefix_len=common_prefix_len,
-                common_attn_metadata=common_attn_metadata,
-                **extra_attn_metadata_args)
-        else:
-            attn_metadata_i = builder.build(
-                common_prefix_len=common_prefix_len,
-                common_attn_metadata=common_attn_metadata,
-                model=self.get_model(),
-                **extra_attn_metadata_args)
+        attn_metadata_i = builder.build(
+            common_prefix_len=common_prefix_len,
+            common_attn_metadata=common_attn_metadata,
+            **extra_attn_metadata_args)
 
         for layer_name in attn_group.layer_names:
             attn_metadata[layer_name] = attn_metadata_i
@@ -1918,7 +1907,7 @@ class NPUModelRunner(GPUModelRunner):
                                                    common_metadata)
         else:
             attn_metadata_full_attention = builder.build_for_graph_capture(
-                common_attn_metadata, attn_state, self.get_model())
+                common_attn_metadata, attn_state)
             for layer_name in kv_cache_group_spec.layer_names:
                 if "linear_attn" in layer_name:
                     attn_metadata[
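
Taken together, the builder-facing effect of change 2 is a signature alignment: AscendAttentionMetadataBuilder, AscendMLAMetadataBuilder, and AscendSFAMetadataBuilder now subclass the corresponding vLLM builders, and build() takes fast_build where it used to take model. A minimal, self-contained sketch of that shape (the Base class below is a stand-in for vLLM's AttentionMetadataBuilder, not the real one; the returned dict is illustrative only):

    from typing import Generic, TypeVar

    M = TypeVar("M")

    class Base(Generic[M]):
        # Stand-in for vllm.v1.attention.backends.utils.AttentionMetadataBuilder.
        def build(self, common_prefix_len: int, common_attn_metadata,
                  fast_build: bool = False) -> M:
            raise NotImplementedError

    class AscendBuilder(Base[dict]):
        # Cf. AscendAttentionMetadataBuilder: cos/sin now come from the
        # module-level cache (get_cos_and_sin_mla), so no model parameter.
        def build(self, common_prefix_len: int, common_attn_metadata,
                  fast_build: bool = False) -> dict:
            return {"prefix_len": common_prefix_len, "meta": common_attn_metadata}

    print(AscendBuilder().build(0, {"num_reqs": 1}))

Because the signatures now match, the model runner can call every builder uniformly (see the @@ -1122 hunk above) instead of branching on runner type to decide whether to pass the model.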