[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Cache cos/sin in mla
2. AttentionBuilder inherits from the original class of vllm.
version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
@@ -13,6 +12,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.utils.math_utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
|
||||
@@ -177,7 +177,7 @@ class AscendMLAMetadata:
|
||||
M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
|
||||
class AscendMLAMetadataBuilder:
|
||||
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
@@ -186,14 +186,17 @@ class AscendMLAMetadataBuilder:
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
kv_cache_spec: MLAAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[AscendMLAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: MLAAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: type[AscendMLAMetadata] | None = None,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
):
|
||||
self.metadata_cls = (metadata_cls if metadata_cls is not None else
|
||||
AscendMLAMetadata)
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
@@ -384,7 +387,7 @@ class AscendMLAMetadataBuilder:
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
fast_build: bool = False,
|
||||
) -> AscendMLAMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
@@ -400,17 +403,6 @@ class AscendMLAMetadataBuilder:
|
||||
self.slot_mapping = common_attn_metadata.slot_mapping[:self.
|
||||
num_actual_tokens]
|
||||
|
||||
if self.cos_cache is None:
|
||||
self.cos_cache = model.model.layers[
|
||||
model.model.start_layer].self_attn.rotary_emb.cos_cached
|
||||
self.sin_cache = model.model.layers[
|
||||
model.model.start_layer].self_attn.rotary_emb.sin_cached
|
||||
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
|
||||
self.cos_cache = self.cos_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
self.query_lens = query_seq_lens_cpu[:num_reqs]
|
||||
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
@@ -420,12 +412,12 @@ class AscendMLAMetadataBuilder:
|
||||
prefill_metadata = None
|
||||
if self.num_prefills > 0:
|
||||
prefill_metadata = self.build_prefill_metadata(
|
||||
common_prefix_len, common_attn_metadata, model)
|
||||
common_prefix_len, common_attn_metadata)
|
||||
|
||||
decode_metadata = None
|
||||
if self.num_decodes > 0:
|
||||
decode_metadata = self.build_decode_metadata(
|
||||
common_prefix_len, common_attn_metadata, model)
|
||||
common_prefix_len, common_attn_metadata)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens_pcp_padded=self.num_actual_tokens,
|
||||
@@ -450,7 +442,6 @@ class AscendMLAMetadataBuilder:
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
):
|
||||
if not self.chunked_prefill_enabled:
|
||||
return None
|
||||
@@ -520,7 +511,6 @@ class AscendMLAMetadataBuilder:
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendMLAPrefillMetadata:
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
|
||||
@@ -530,7 +520,7 @@ class AscendMLAMetadataBuilder:
|
||||
)
|
||||
|
||||
chunked_context_metadata = self.build_chunked_metadata(
|
||||
common_prefix_len, common_attn_metadata, model)
|
||||
common_prefix_len, common_attn_metadata)
|
||||
reqs_start = self.num_decodes # prefill_start
|
||||
tokens_start = self.num_decode_tokens
|
||||
max_query_len = self.query_lens[reqs_start:].max().item()
|
||||
@@ -539,12 +529,7 @@ class AscendMLAMetadataBuilder:
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos = self.cos_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
|
||||
return AscendMLAPrefillMetadata(
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
query_lens=self.query_lens[reqs_start:].to(torch.int32),
|
||||
@@ -564,7 +549,6 @@ class AscendMLAMetadataBuilder:
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendMLADecodeMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
@@ -573,7 +557,6 @@ class AscendMLAMetadataBuilder:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
cos, sin = get_cos_and_sin_mla()
|
||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||
actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes +
|
||||
1].tolist()
|
||||
@@ -640,54 +623,25 @@ class AscendMLAMetadataBuilder:
|
||||
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
|
||||
common_attn_metadata)
|
||||
|
||||
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
|
||||
assert self.cos_cache is not None
|
||||
assert self.sin_cache is not None
|
||||
if cos is None and sin is None:
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=self.block_table,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
cp_seq_len=cp_seq_len,
|
||||
batch_seq_mask=batch_seq_mask)
|
||||
else:
|
||||
cos[:self.num_decode_tokens,
|
||||
...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(
|
||||
2)
|
||||
sin[:self.num_decode_tokens,
|
||||
...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
|
||||
2)
|
||||
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=self.block_table,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:self.num_decode_tokens, ...],
|
||||
cos=cos[:self.num_decode_tokens, ...],
|
||||
cp_seq_len=cp_seq_len,
|
||||
batch_seq_mask=batch_seq_mask)
|
||||
cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=self.block_table,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:self.num_decode_tokens, ...],
|
||||
cos=cos[:self.num_decode_tokens, ...],
|
||||
cp_seq_len=cp_seq_len,
|
||||
batch_seq_mask=batch_seq_mask)
|
||||
return decode_metadata
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
if attn_state in {
|
||||
AscendAttentionState.DecodeOnly,
|
||||
@@ -696,7 +650,6 @@ class AscendMLAMetadataBuilder:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
||||
Reference in New Issue
Block a user