[Perf] Deepseekv3 performance optimization for eager mode (#598)

### What this PR does / why we need it?
Deepseek v3 now adopt vanilla chunked prefill on MLA part which is
ineffcient for computing but necessary for chunked prefill. Since PR
https://github.com/vllm-project/vllm-ascend/pull/543 bring v0 scheduler
into vllm-ascend, we can now adopt torch_npu._npu_flash_attention inside
the mla backend for more performance boost. Also there are some
redundant computation inside the rope, which is also removed. This PR
should bring some performance gain for deepseek eager mode inference.

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
Pleaplusone
2025-04-29 17:12:03 +08:00
committed by GitHub
parent 87975fa058
commit 0329fad927
4 changed files with 180 additions and 102 deletions

View File

@@ -25,35 +25,43 @@ from vllm.model_executor.layers.rotary_embedding import (
from vllm_ascend.platform import CUSTOM_OP_ENABLED
def custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and CUSTOM_OP_ENABLED
def rope_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
import torch_npu
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0:
return torch.ops._C.rotary_embedding(
if custom_rotary_embedding_enabled(query, neox_style, self.head_size):
query, key = torch.ops._C.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
# TODO: Remove the contiguous in the future.
query_shape, key_shape = query.shape, key.shape
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
@@ -62,33 +70,33 @@ def rope_forward_oot(
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
def native_rope_deepseek_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
):
# seq_len = positions.max() + 1
seq_len = self.max_position_embeddings
# x: [bs, num_attention_heads, seq_len, head_size]
# if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
# self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype)
self._set_cos_sin_cache(seq_len=seq_len,
device=query.device,
dtype=query.dtype)
cos = self.cos_cached[:seq_len].to(dtype=query.dtype)
sin = self.sin_cached[:seq_len].to(dtype=query.dtype)
q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions)
def native_rope_deepseek_forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
if len(key.shape) == 2:
key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style
# calculation method which is also more compute friendly to the ascend machine
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2, 2).transpose(3,
2).reshape(b, h_q, d)
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe
@@ -190,7 +198,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def _set_cos_sin_cache(self, seq_len, device, dtype):
seq_len = self.max_position_embeddings
self.max_seq_len_cached = seq_len
dim = self.rotary_dim
@@ -214,21 +221,53 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
t = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
# _mscale = float(
# yarn_get_mscale(self.scaling_factor, self.mscale)
# / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
# )
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype),
persistent=False)
self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype),
persistent=False)
cache = torch.cat([freqs.cos() * self.mscale,
freqs.sin() * self.mscale],
dim=-1).to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
def deepseek_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
_set_cos_sin_cache(self,
max_position_embeddings,
dtype=dtype,
device="npu")
# TODO: Patch when aclnn ops available
RotaryEmbedding.forward_oot = rope_forward_oot
# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
# its more ascend compute friendly
DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None