[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Cache cos/sin in mla
2. AttentionBuilder inherits from the original class of vllm.



version: release/v0.13.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-28 10:35:07 +08:00
committed by GitHub
parent 24328aaf00
commit dbe4c338f2
10 changed files with 167 additions and 224 deletions

View File

@@ -20,7 +20,6 @@ from typing import ClassVar, List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
@@ -90,7 +89,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: Optional[nn.Module] = None,
fast_build: bool = False,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens

View File

@@ -20,7 +20,6 @@ from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
@@ -29,7 +28,8 @@ from vllm.attention.backends.registry import (AttentionBackendEnum,
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -170,7 +170,7 @@ class AscendMetadata:
model_runner_type: str = ""
class AscendAttentionMetadataBuilder:
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
@@ -217,8 +217,8 @@ class AscendAttentionMetadataBuilder:
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: Optional[nn.Module] = None,
):
fast_build: bool = False,
) -> AscendMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -261,7 +261,6 @@ class AscendAttentionMetadataBuilder:
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.build(

View File

@@ -4,7 +4,6 @@ import numpy as np
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
@@ -50,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
understand this class
"""
def __init__(self,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None):
def __init__(
self,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
metadata_cls)
metadata_cls, supports_dcp_with_varlen)
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
@@ -92,7 +94,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendPCPMetadata | None:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert common_long_seq_metadata is not None
@@ -121,10 +122,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
chunked_context_metadata = super().build_chunked_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
if chunked_context_metadata is None:
return None
@@ -205,12 +205,11 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAPrefillMetadata:
prefill_metadata = super().build_prefill_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
prefill_metadata.pcp_metadata = self.build_cp_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
prefill_metadata.block_table = self.block_table[
self.num_decodes_flatten:, ...]
return prefill_metadata
@@ -219,10 +218,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLADecodeMetadata:
decode_metadata = super().build_decode_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None

View File

@@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
import numpy as np
import torch
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config
@@ -13,6 +12,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@@ -177,7 +177,7 @@ class AscendMLAMetadata:
M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder:
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
@@ -186,14 +186,17 @@ class AscendMLAMetadataBuilder:
understand this class
"""
def __init__(self,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None):
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
if metadata_cls is not None else AscendMLAMetadata # type: ignore
def __init__(
self,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
self.metadata_cls = (metadata_cls if metadata_cls is not None else
AscendMLAMetadata)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
@@ -384,7 +387,7 @@ class AscendMLAMetadataBuilder:
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
fast_build: bool = False,
) -> AscendMLAMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
@@ -400,17 +403,6 @@ class AscendMLAMetadataBuilder:
self.slot_mapping = common_attn_metadata.slot_mapping[:self.
num_actual_tokens]
if self.cos_cache is None:
self.cos_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.cos_cached
self.sin_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
self.query_lens = query_seq_lens_cpu[:num_reqs]
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
@@ -420,12 +412,12 @@ class AscendMLAMetadataBuilder:
prefill_metadata = None
if self.num_prefills > 0:
prefill_metadata = self.build_prefill_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
decode_metadata = None
if self.num_decodes > 0:
decode_metadata = self.build_decode_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=self.num_actual_tokens,
@@ -450,7 +442,6 @@ class AscendMLAMetadataBuilder:
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
if not self.chunked_prefill_enabled:
return None
@@ -520,7 +511,6 @@ class AscendMLAMetadataBuilder:
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAPrefillMetadata:
query_start_loc = common_attn_metadata.query_start_loc
@@ -530,7 +520,7 @@ class AscendMLAMetadataBuilder:
)
chunked_context_metadata = self.build_chunked_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
reqs_start = self.num_decodes # prefill_start
tokens_start = self.num_decode_tokens
max_query_len = self.query_lens[reqs_start:].max().item()
@@ -539,12 +529,7 @@ class AscendMLAMetadataBuilder:
reqs_start:] - query_start_loc[reqs_start]
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
return AscendMLAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
query_lens=self.query_lens[reqs_start:].to(torch.int32),
@@ -564,7 +549,6 @@ class AscendMLAMetadataBuilder:
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLADecodeMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -573,7 +557,6 @@ class AscendMLAMetadataBuilder:
num_actual_tokens].long(
)
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes +
1].tolist()
@@ -640,54 +623,25 @@ class AscendMLAMetadataBuilder:
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
common_attn_metadata)
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
assert self.cos_cache is not None
assert self.sin_cache is not None
if cos is None and sin is None:
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=self.block_table,
seq_lens=self.seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
else:
cos[:self.num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(
2)
sin[:self.num_decode_tokens,
...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=self.block_table,
seq_lens=self.seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:self.num_decode_tokens, ...],
cos=cos[:self.num_decode_tokens, ...],
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=self.block_table,
seq_lens=self.seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:self.num_decode_tokens, ...],
cos=cos[:self.num_decode_tokens, ...],
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
return decode_metadata
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state in {
AscendAttentionState.DecodeOnly,
@@ -696,7 +650,6 @@ class AscendMLAMetadataBuilder:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
model=model,
)
else:
raise NotImplementedError(

View File

@@ -12,6 +12,7 @@ from vllm.logger import logger
from vllm.model_executor.layers.linear import (ReplicatedLinear,
UnquantizedLinearMethod)
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
@@ -107,7 +108,7 @@ class AscendSFAMetadata:
M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder:
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
@@ -117,14 +118,17 @@ class AscendSFAMetadataBuilder:
"""
# _attn_mask_builder = None
def __init__(self,
kv_cache_spec,
layer_names,
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendSFAMetadata] = None):
self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
if metadata_cls is not None else AscendSFAMetadata # type: ignore
def __init__(
self,
kv_cache_spec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[AscendSFAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
self.metadata_cls = (metadata_cls if metadata_cls is not None else
AscendSFAMetadata)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
@@ -142,9 +146,6 @@ class AscendSFAMetadataBuilder:
got {self.decode_threshold}"
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
self.enable_sfa_cp = enable_sp() and \
hasattr(self.model_config.hf_config, "index_topk")
@@ -163,7 +164,7 @@ class AscendSFAMetadataBuilder:
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
fast_build: bool = False,
) -> AscendSFAMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -178,33 +179,12 @@ class AscendSFAMetadataBuilder:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
has_prefill = any(query_lens_cpu > self.decode_threshold)
if self.cos_cache is None:
self.cos_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.cos_cached
self.sin_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
cos, sin = get_cos_and_sin_mla()
assert self.cos_cache is not None and self.sin_cache is not None
new_cos = self.cos_cache[input_positions][:, None, None]
new_sin = self.sin_cache[input_positions][:, None, None]
if (cos is not None and sin is not None
and num_input_tokens <= cos.shape[0]
and num_input_tokens <= sin.shape[0]):
cos[:num_input_tokens] = new_cos
sin[:num_input_tokens] = new_sin
if has_prefill:
cos, sin = get_cos_and_sin_mla(input_positions)
else:
cos, sin = new_cos, new_sin
cos, sin = get_cos_and_sin_mla(input_positions, True)
sfa_cp_context = None
if self.enable_sfa_cp:
@@ -299,7 +279,6 @@ class AscendSFAMetadataBuilder:
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state in {
AscendAttentionState.DecodeOnly,
@@ -308,7 +287,6 @@ class AscendSFAMetadataBuilder:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
model=model,
)
else:
raise NotImplementedError(