[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Cache cos/sin in mla
2. AttentionBuilder inherits from the original class of vllm.
version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -4,7 +4,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_dcp_group,
|
||||
get_decode_context_model_parallel_rank,
|
||||
@@ -50,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
kv_cache_spec: MLAAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[AscendMLAMetadata] = None):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: MLAAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: type[AscendMLAMetadata] | None = None,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
|
||||
metadata_cls)
|
||||
metadata_cls, supports_dcp_with_varlen)
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group(
|
||||
@@ -92,7 +94,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendPCPMetadata | None:
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
assert common_long_seq_metadata is not None
|
||||
@@ -121,10 +122,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
):
|
||||
chunked_context_metadata = super().build_chunked_metadata(
|
||||
common_prefix_len, common_attn_metadata, model)
|
||||
common_prefix_len, common_attn_metadata)
|
||||
if chunked_context_metadata is None:
|
||||
return None
|
||||
|
||||
@@ -205,12 +205,11 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendMLAPrefillMetadata:
|
||||
prefill_metadata = super().build_prefill_metadata(
|
||||
common_prefix_len, common_attn_metadata, model)
|
||||
common_prefix_len, common_attn_metadata)
|
||||
prefill_metadata.pcp_metadata = self.build_cp_metadata(
|
||||
common_prefix_len, common_attn_metadata, model)
|
||||
common_prefix_len, common_attn_metadata)
|
||||
prefill_metadata.block_table = self.block_table[
|
||||
self.num_decodes_flatten:, ...]
|
||||
return prefill_metadata
|
||||
@@ -219,10 +218,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendMLADecodeMetadata:
|
||||
decode_metadata = super().build_decode_metadata(
|
||||
common_prefix_len, common_attn_metadata, model)
|
||||
common_prefix_len, common_attn_metadata)
|
||||
|
||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
assert long_seq_metadata is not None
|
||||
|
||||
Reference in New Issue
Block a user