[Perf] Deepseekv3 performance optimization for eager mode (#598)
### What this PR does / why we need it? Deepseek v3 now adopt vanilla chunked prefill on MLA part which is ineffcient for computing but necessary for chunked prefill. Since PR https://github.com/vllm-project/vllm-ascend/pull/543 bring v0 scheduler into vllm-ascend, we can now adopt torch_npu._npu_flash_attention inside the mla backend for more performance boost. Also there are some redundant computation inside the rope, which is also removed. This PR should bring some performance gain for deepseek eager mode inference. --------- Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -25,35 +25,43 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
from vllm_ascend.platform import CUSTOM_OP_ENABLED
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and CUSTOM_OP_ENABLED
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
import torch_npu
|
||||
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0:
|
||||
return torch.ops._C.rotary_embedding(
|
||||
if custom_rotary_embedding_enabled(query, neox_style, self.head_size):
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
@@ -62,33 +70,33 @@ def rope_forward_oot(
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def native_rope_deepseek_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# seq_len = positions.max() + 1
|
||||
seq_len = self.max_position_embeddings
|
||||
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
|
||||
# self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype)
|
||||
self._set_cos_sin_cache(seq_len=seq_len,
|
||||
device=query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
cos = self.cos_cached[:seq_len].to(dtype=query.dtype)
|
||||
sin = self.sin_cached[:seq_len].to(dtype=query.dtype)
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions)
|
||||
|
||||
def native_rope_deepseek_forward(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
||||
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
# calculation method which is also more compute friendly to the ascend machine
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_q, d)
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
@@ -190,7 +198,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
seq_len = self.max_position_embeddings
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.rotary_dim
|
||||
|
||||
@@ -214,21 +221,53 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
# _mscale = float(
|
||||
# yarn_get_mscale(self.scaling_factor, self.mscale)
|
||||
# / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
|
||||
# )
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype),
|
||||
persistent=False)
|
||||
self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype),
|
||||
persistent=False)
|
||||
cache = torch.cat([freqs.cos() * self.mscale,
|
||||
freqs.sin() * self.mscale],
|
||||
dim=-1).to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
|
||||
def deepseek_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
self.max_seq_len = max_position_embeddings
|
||||
_set_cos_sin_cache(self,
|
||||
max_position_embeddings,
|
||||
dtype=dtype,
|
||||
device="npu")
|
||||
|
||||
|
||||
# TODO: Patch when aclnn ops available
|
||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||
|
||||
# Note: we adopt the native huggingface deepseek rope initialization code from
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
|
||||
# its more ascend compute friendly
|
||||
DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func
|
||||
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
||||
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
|
||||
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None
|
||||
|
||||
Reference in New Issue
Block a user