[Refactor] Cache cos/sin in MLA & remove the `model` parameter from the builder. (#5277)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Cache the rotary-embedding cos/sin tensors in the MLA backend instead of recomputing them on every forward pass (see the sketch below).
2. Make the Ascend attention metadata builder inherit directly from vLLM's original builder class; as a result, the unused `model` parameter is dropped from the `build_*` methods.
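
For context, a minimal sketch of the idea behind point 1. The names here (`RotaryCosSinCache`, `get`) are hypothetical illustrations, not the classes touched by this PR: the cos/sin tensors are computed once at construction, and the hot path only gathers them by position.

import torch

class RotaryCosSinCache:
    """Illustrative only: precompute rotary cos/sin once and reuse them
    on every attention step instead of recomputing per forward pass."""

    def __init__(self, head_dim: int, max_positions: int, base: float = 10000.0):
        inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)  # [max_positions, head_dim // 2]
        # Cached once; later calls only index into these tensors.
        self.cos_cached = freqs.cos()
        self.sin_cached = freqs.sin()

    def get(self, positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Cheap gather per step; no trig recomputation in the hot path.
        return self.cos_cached[positions], self.sin_cached[positions]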



version: release/v0.13.0
vLLM main: ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Author: weijinqian0
Date: 2025-12-28 10:35:07 +08:00 (committed by GitHub)
Parent: 24328aaf00
Commit: dbe4c338f2
10 changed files with 167 additions and 224 deletions


@@ -4,7 +4,6 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch_npu
-from torch import nn
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
@@ -50,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
     understand this class
     """

-    def __init__(self,
-                 kv_cache_spec: MLAAttentionSpec,
-                 layer_names: list[str],
-                 vllm_config: VllmConfig,
-                 device: torch.device,
-                 metadata_cls: Optional[AscendMLAMetadata] = None):
+    def __init__(
+        self,
+        kv_cache_spec: MLAAttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+        metadata_cls: type[AscendMLAMetadata] | None = None,
+        supports_dcp_with_varlen: bool = False,
+    ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
-                         metadata_cls)
+                         metadata_cls, supports_dcp_with_varlen)
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group(
@@ -92,7 +94,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendPCPMetadata | None:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert common_long_seq_metadata is not None
@@ -121,10 +122,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ):
         chunked_context_metadata = super().build_chunked_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         if chunked_context_metadata is None:
             return None
@@ -205,12 +205,11 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLAPrefillMetadata:
         prefill_metadata = super().build_prefill_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.pcp_metadata = self.build_cp_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         prefill_metadata.block_table = self.block_table[
             self.num_decodes_flatten:, ...]
         return prefill_metadata
@@ -219,10 +218,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
     ) -> AscendMLADecodeMetadata:
         decode_metadata = super().build_decode_metadata(
-            common_prefix_len, common_attn_metadata, model)
+            common_prefix_len, common_attn_metadata)
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert long_seq_metadata is not None
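
Taken together, the hunks above mean call sites construct the builder with the signature inherited from vLLM and no longer thread a `model` argument through the `build_*` methods. A hedged usage sketch: the surrounding variables (`kv_cache_spec`, `layer_names`, `vllm_config`, `common_prefix_len`, `common_attn_metadata`) are assumed to exist in the model runner and are not defined by this diff.

# Illustrative call site after this change; variable names are assumed.
builder = AscendMlaCPMetadataBuilder(
    kv_cache_spec=kv_cache_spec,
    layer_names=layer_names,
    vllm_config=vllm_config,
    device=torch.device("npu"),
    supports_dcp_with_varlen=False,
)
# No `model` argument anymore:
prefill_metadata = builder.build_prefill_metadata(common_prefix_len,
                                                  common_attn_metadata)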