[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Cache cos/sin in MLA.
2. Make AttentionBuilder inherit from the original vLLM class.
version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
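[Note] For readers skimming the diff: the core pattern is to precompute full cos/sin tables once and gather rows by token position, instead of rebuilding rope values inside each MLA layer. A minimal standalone sketch of that pattern; the sizes, rope base, and interleaving below are illustrative assumptions, not values from this PR:

import torch

max_positions, rope_dim = 8192, 64  # assumed example sizes
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
freqs = torch.outer(torch.arange(max_positions).float(), inv_freq)
cos_cache = freqs.cos().repeat_interleave(2, dim=-1)  # [max_positions, rope_dim]
sin_cache = freqs.sin().repeat_interleave(2, dim=-1)

positions = torch.tensor([0, 1, 7, 7])                # one entry per token
cos = cos_cache[positions].unsqueeze(1).unsqueeze(2)  # [4, 1, 1, rope_dim]
sin = sin_cache[positions].unsqueeze(1).unsqueeze(2)
assert cos.shape == (4, 1, 1, rope_dim)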
@@ -21,7 +21,6 @@ from typing import Optional, Tuple
 import einops
 import torch
 import torch_npu
-from vllm.config import CUDAGraphMode
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
@@ -40,13 +39,15 @@ from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
 # AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
 # attn_metadata. This causes that rope in GQA models must pass cos && sin
 # by different approaches.
-_cos_mla: Optional[torch.Tensor] = None
-_sin_mla: Optional[torch.Tensor] = None
-_cos_sin_cache: Optional[torch.Tensor] = None
-_cos: Optional[torch.Tensor] = None
-_sin: Optional[torch.Tensor] = None
-_cos_slice: Optional[torch.Tensor] = None
-_sin_slice: Optional[torch.Tensor] = None
+_cos_mla: torch.Tensor = None
+_sin_mla: torch.Tensor = None
+_cos_cache: torch.Tensor = None
+_sin_cache: torch.Tensor = None
+_cos_sin_cache: torch.Tensor = None
+_cos: torch.Tensor = None
+_sin: torch.Tensor = None
+_cos_slice: torch.Tensor = None
+_sin_slice: torch.Tensor = None


 def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
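[Note] The globals above are filled in exactly once by set_cos_and_sin (the early return in the next hunk makes the call idempotent) and are then read by every layer. A standalone model of that init-once pattern, with assumed shapes:

import torch
from typing import Optional

_cos_mla: Optional[torch.Tensor] = None

def init_cos_mla(num_tokens: int, rope_dim: int) -> None:
    global _cos_mla
    if _cos_mla is not None:  # already initialized; mirrors the early return
        return
    _cos_mla = torch.ones(num_tokens, 1, 1, rope_dim)

init_cos_mla(16, 64)
init_cos_mla(32, 64)  # no-op: the first allocation wins
assert _cos_mla.shape == (16, 1, 1, 64)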
@@ -62,25 +63,23 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
             _sin is not None:
         return

-    compilation_config = vllm_config.compilation_config
     model_config = vllm_config.model_config
+    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

     if model_config.use_mla:
-        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-            rope_dim = model_config.hf_text_config.qk_rope_head_dim
-            _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
-                                  1,
-                                  1,
-                                  rope_dim,
-                                  dtype=dtype,
-                                  device=device)
-            _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
-                                   1,
-                                   1,
-                                   rope_dim,
-                                   dtype=dtype,
-                                   device=device)
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+        _cos_mla = torch.ones(max_num_batched_tokens,
+                              1,
+                              1,
+                              rope_dim,
+                              dtype=dtype,
+                              device=device)
+        _sin_mla = torch.zeros(max_num_batched_tokens,
+                               1,
+                               1,
+                               rope_dim,
+                               dtype=dtype,
+                               device=device)
     elif not is_vl_model(vllm_config) and has_rope(vllm_config):
         rope_dim = model_config.get_head_size()
         # For models using partial rope like Qwen3-Next.
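[Note] The resizing in this hunk is the functional change: max_num_reqs * decode_token_per_req only bounds pure-decode steps, while max_num_batched_tokens bounds any scheduled step, prefill included, so the buffers no longer need the FULL_DECODE_ONLY guard. A numeric sanity check with assumed values (not from this PR):

max_num_reqs = 256
decode_token_per_req = 2          # e.g. one token plus one speculative token
max_num_batched_tokens = 8192

decode_only_bound = max_num_reqs * decode_token_per_req  # 512 tokens
assert max_num_batched_tokens >= decode_only_bound
# A prefill step may schedule up to 8192 tokens at once; the old 512-token
# bound would be too small outside decode-only graph capture.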
@@ -101,8 +100,19 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
                                device=device)


-def get_cos_and_sin_mla():
-    return _cos_mla, _sin_mla
+def get_cos_and_sin_mla(positions, use_cache=False):
+    global _cos_cache
+    global _sin_cache
+    cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
+    sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
+    if not use_cache:
+        return cos, sin
+    global _cos_mla
+    global _sin_mla
+    num_tokens = positions.size(0)
+    _cos_mla[:num_tokens, ...] = cos
+    _sin_mla[:num_tokens, ...] = sin
+    return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...]


 def _record_cos_sin_cache(cos_sin_cache):
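[Note] The two paths of the new get_cos_and_sin_mla, exercised against mock tables. With use_cache=True the gathered values are copied into the preallocated buffers and slices of those buffers are returned, so the output storage address stays stable across calls; reading that as support for graph capture/replay is my interpretation, not stated in the diff:

import torch

_cos_cache = torch.randn(128, 64)    # [max_position, rope_dim], mock table
_sin_cache = torch.randn(128, 64)
_cos_mla = torch.ones(32, 1, 1, 64)  # preallocated persistent buffers
_sin_mla = torch.zeros(32, 1, 1, 64)

positions = torch.tensor([3, 3, 9])

# use_cache=False path: fresh tensors from the gather, new storage each call.
cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)  # [3, 1, 1, 64]
sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)

# use_cache=True path: copy into the persistent buffers, return slices.
n = positions.size(0)
_cos_mla[:n, ...] = cos
_sin_mla[:n, ...] = sin
out_cos = _cos_mla[:n, ...]
assert out_cos.data_ptr() == _cos_mla.data_ptr()  # stable storage address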
@@ -112,6 +122,13 @@ def _record_cos_sin_cache(cos_sin_cache):
     _cos_sin_cache = cos_sin_cache


+def _record_cos_and_sin_cache(cos_cache, sin_cache):
+    global _cos_cache
+    global _sin_cache
+    _cos_cache = cos_cache
+    _sin_cache = sin_cache
+
+
 def update_cos_sin(positions):
     global _cos
     global _sin
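[Note] The _record_* helpers store references, not copies: the globals alias the tensors the rope module created, so later in-place updates by the owner are visible through the global. A tiny standalone demonstration:

import torch

_cos_cache = None

def _record(cos_cache):
    global _cos_cache
    _cos_cache = cos_cache  # reference assignment, no copy

table = torch.zeros(4, 8)
_record(table)
table[0, 0] = 1.0               # in-place update by the owner...
assert _cos_cache[0, 0] == 1.0  # ...is visible through the global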
@@ -469,6 +486,8 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        self.register_buffer("cos_cached", cos_cached, persistent=False)
+        self.register_buffer("sin_cached", sin_cached, persistent=False)
         _record_cos_sin_cache(cache)
+        _record_cos_and_sin_cache(cos_cached, sin_cached)

     def forward(self,
                 positions: torch.Tensor,
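[Note] The added recording call closes the loop: constructing the rope module publishes its cos/sin tables, and MLA layers can then look values up by position without holding a reference to the module. A hedged end-to-end mock; the real class derives its tables from the DeepSeek scaling schedule, this stand-in just builds plain rope tables:

import torch

_cos_cache = _sin_cache = None

def _record_cos_and_sin_cache(cos_cache, sin_cache):
    global _cos_cache, _sin_cache
    _cos_cache, _sin_cache = cos_cache, sin_cache

class MockRope:  # stands in for AscendDeepseekScalingRotaryEmbedding
    def __init__(self, max_pos=64, rope_dim=16):
        inv = 1.0 / 10000.0 ** (torch.arange(0, rope_dim, 2).float() / rope_dim)
        f = torch.outer(torch.arange(max_pos).float(), inv)
        _record_cos_and_sin_cache(f.cos().repeat_interleave(2, dim=-1),
                                  f.sin().repeat_interleave(2, dim=-1))

MockRope()                        # __init__ publishes the tables
positions = torch.tensor([0, 5, 5])
cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
assert cos.shape == (3, 1, 1, 16)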