[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Cache cos/sin in mla
2. AttentionBuilder inherits from the original class of vllm.
version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import (ReplicatedLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
from vllm_ascend import envs
|
||||
@@ -107,7 +108,7 @@ class AscendSFAMetadata:
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
|
||||
|
||||
class AscendSFAMetadataBuilder:
|
||||
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
@@ -117,14 +118,17 @@ class AscendSFAMetadataBuilder:
|
||||
"""
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[AscendSFAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendSFAMetadata # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: type[AscendSFAMetadata] | None = None,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
):
|
||||
self.metadata_cls = (metadata_cls if metadata_cls is not None else
|
||||
AscendSFAMetadata)
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
@@ -142,9 +146,6 @@ class AscendSFAMetadataBuilder:
|
||||
got {self.decode_threshold}"
|
||||
|
||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.cos_cache = None
|
||||
self.sin_cache = None
|
||||
|
||||
self.enable_sfa_cp = enable_sp() and \
|
||||
hasattr(self.model_config.hf_config, "index_topk")
|
||||
|
||||
@@ -163,7 +164,7 @@ class AscendSFAMetadataBuilder:
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
fast_build: bool = False,
|
||||
) -> AscendSFAMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
@@ -178,33 +179,12 @@ class AscendSFAMetadataBuilder:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
has_prefill = any(query_lens_cpu > self.decode_threshold)
|
||||
|
||||
if self.cos_cache is None:
|
||||
self.cos_cache = model.model.layers[
|
||||
model.model.start_layer].self_attn.rotary_emb.cos_cached
|
||||
self.sin_cache = model.model.layers[
|
||||
model.model.start_layer].self_attn.rotary_emb.sin_cached
|
||||
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
|
||||
self.cos_cache = self.cos_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
|
||||
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
|
||||
|
||||
cos, sin = get_cos_and_sin_mla()
|
||||
|
||||
assert self.cos_cache is not None and self.sin_cache is not None
|
||||
new_cos = self.cos_cache[input_positions][:, None, None]
|
||||
new_sin = self.sin_cache[input_positions][:, None, None]
|
||||
|
||||
if (cos is not None and sin is not None
|
||||
and num_input_tokens <= cos.shape[0]
|
||||
and num_input_tokens <= sin.shape[0]):
|
||||
cos[:num_input_tokens] = new_cos
|
||||
sin[:num_input_tokens] = new_sin
|
||||
if has_prefill:
|
||||
cos, sin = get_cos_and_sin_mla(input_positions)
|
||||
else:
|
||||
cos, sin = new_cos, new_sin
|
||||
cos, sin = get_cos_and_sin_mla(input_positions, True)
|
||||
|
||||
sfa_cp_context = None
|
||||
if self.enable_sfa_cp:
|
||||
@@ -299,7 +279,6 @@ class AscendSFAMetadataBuilder:
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
if attn_state in {
|
||||
AscendAttentionState.DecodeOnly,
|
||||
@@ -308,7 +287,6 @@ class AscendSFAMetadataBuilder:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
||||
Reference in New Issue
Block a user