[Refactor] cache cos/sin in mla & remove parameter model in builder. (#5277)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Cache cos/sin in mla
2. AttentionBuilder inherits from the original class of vllm.



version: release/v0.13.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-28 10:35:07 +08:00
committed by GitHub
parent 24328aaf00
commit dbe4c338f2
10 changed files with 167 additions and 224 deletions

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple
import einops
import torch
import torch_npu
from vllm.config import CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
@@ -40,13 +39,15 @@ from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
# attn_metadata. This causes that rope in GQA models must pass cos && sin
# by different approaches.
_cos_mla: Optional[torch.Tensor] = None
_sin_mla: Optional[torch.Tensor] = None
_cos_sin_cache: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos_slice: Optional[torch.Tensor] = None
_sin_slice: Optional[torch.Tensor] = None
_cos_mla: torch.Tensor = None
_sin_mla: torch.Tensor = None
_cos_cache: torch.Tensor = None
_sin_cache: torch.Tensor = None
_cos_sin_cache: torch.Tensor = None
_cos: torch.Tensor = None
_sin: torch.Tensor = None
_cos_slice: torch.Tensor = None
_sin_slice: torch.Tensor = None
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
@@ -62,25 +63,23 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
_sin is not None:
return
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
if model_config.use_mla:
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_batched_tokens,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_batched_tokens,
1,
1,
rope_dim,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and has_rope(vllm_config):
rope_dim = model_config.get_head_size()
# For models using partial rope like Qwen3-Next.
@@ -101,8 +100,19 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
device=device)
def get_cos_and_sin_mla():
return _cos_mla, _sin_mla
def get_cos_and_sin_mla(positions, use_cache=False):
global _cos_cache
global _sin_cache
cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
if not use_cache:
return cos, sin
global _cos_mla
global _sin_mla
num_tokens = positions.size(0)
_cos_mla[:num_tokens, ...] = cos
_sin_mla[:num_tokens, ...] = sin
return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...]
def _record_cos_sin_cache(cos_sin_cache):
@@ -112,6 +122,13 @@ def _record_cos_sin_cache(cos_sin_cache):
_cos_sin_cache = cos_sin_cache
def _record_cos_and_sin_cache(cos_cache, sin_cache):
global _cos_cache
global _sin_cache
_cos_cache = cos_cache
_sin_cache = sin_cache
def update_cos_sin(positions):
global _cos
global _sin
@@ -469,6 +486,8 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
_record_cos_sin_cache(cache)
_record_cos_and_sin_cache(cos_cached, sin_cached)
def forward(self,
positions: torch.Tensor,