[Perf] Deepseekv3 performance optimization for eager mode (#598)

### What this PR does / why we need it?
Deepseek v3 currently adopts vanilla chunked prefill in the MLA part, which is
computationally inefficient but necessary for chunked prefill. Since PR
https://github.com/vllm-project/vllm-ascend/pull/543 brought the v0 scheduler
into vllm-ascend, we can now adopt torch_npu._npu_flash_attention inside
the MLA backend for a further performance boost. There was also some
redundant computation inside the rope, which has been removed as well. This PR
should bring some performance gain for Deepseek eager mode inference.

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
Pleaplusone
2025-04-29 17:12:03 +08:00
committed by GitHub
parent 87975fa058
commit 0329fad927
4 changed files with 180 additions and 102 deletions

View File

@@ -136,11 +136,6 @@ class RotaryEmbedding(nn.Module):
# test with leading dimension and merge seqlen and batch_size as num_tokens # test with leading dimension and merge seqlen and batch_size as num_tokens
# TODO(ganyi): open this test in the future
@pytest.mark.skip(
reason=
"skip this test by default for now because of ci issue, will enable it in the future"
)
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("seq_len", SEQ_LENS)

View File

@@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata:
input_positions: torch.Tensor input_positions: torch.Tensor
block_table: torch.Tensor block_table: torch.Tensor
max_query_len: int max_query_len: int
max_context_len: int max_seq_lens: int
@dataclass @dataclass
@@ -65,6 +65,7 @@ class AscendMLADecodeMetadata:
input_positions: torch.Tensor input_positions: torch.Tensor
block_table: torch.Tensor block_table: torch.Tensor
seq_lens: torch.Tensor seq_lens: torch.Tensor
max_seq_lens: int
@dataclass @dataclass
@@ -131,11 +132,6 @@ class AscendMLAMetadataBuilder:
self.runner = runner self.runner = runner
scheduler_config = runner.scheduler_config scheduler_config = runner.scheduler_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
# self.attn_mask = None
# if AscendMLAMetadataBuilder._attn_mask_builder is None:
# AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
# 128, self.runner.model_config.dtype
# )
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
@@ -222,12 +218,14 @@ class AscendMLAMetadataBuilder:
num_reqs] num_reqs]
seq_lens = seq_lens_cpu seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item() max_query_len = query_lens.max().item()
max_context_len = seq_lens.max().item() max_seq_lens = seq_lens.max().item()
prefill_metadata = None prefill_metadata = None
if self._num_prefills > 0: if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start reqs_start = self._num_decodes # prefill_start
tokens_start = self._num_decode_tokens tokens_start = self._num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item()
max_seq_lens = seq_lens[tokens_start:].max().item()
prefill_metadata = AscendMLAPrefillMetadata( prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask, attn_mask=self.runner.attn_mask,
@@ -236,15 +234,17 @@ class AscendMLAMetadataBuilder:
input_positions=input_positions[tokens_start:], input_positions=input_positions[tokens_start:],
block_table=block_table[reqs_start:, ...], block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len, max_query_len=max_query_len,
max_context_len=max_context_len, max_seq_lens=max_seq_lens,
) )
decode_metadata = None decode_metadata = None
if self._num_decodes > 0: if self._num_decodes > 0:
max_seq_lens = seq_lens[:self._num_decodes].max().item()
decode_metadata = AscendMLADecodeMetadata( decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions[:self._num_decode_tokens], input_positions=input_positions[:self._num_decode_tokens],
block_table=block_table[:self._num_decode_tokens, ...], block_table=block_table[:self._num_decode_tokens, ...],
seq_lens=seq_lens[:self._num_decode_tokens]) seq_lens=seq_lens[:self._num_decode_tokens],
max_seq_lens=max_seq_lens)
return self.metadata_cls( # type: ignore return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
@@ -306,12 +306,18 @@ class AscendMLAImpl(MLAAttentionImpl):
self.qk_rope_head_dim = qk_rope_head_dim self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim self.v_head_dim = v_head_dim
# TODO: below padding should be removed after kernel is ready
# we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
# and slice the final result to guarantee its functionality.
self.padding_head_dim = (
(self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
1) * 128
# Hack for V1 for now to avoid torch library overhead (since we are # Hack for V1 for now to avoid torch library overhead (since we are
# already inside an attention custom op), pull out the forward # already inside an attention custom op), pull out the forward
# method from the rotary embedding and call it directly # method from the rotary embedding and call it directly
# TODO(lucas): we should probably find a cleaner way to do this # TODO(lucas): we should probably find a cleaner way to do this
self.rotary_emb = rotary_emb.forward_native self.rotary_emb = rotary_emb
self.q_proj = q_proj self.q_proj = q_proj
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
@@ -409,17 +415,12 @@ class AscendMLAImpl(MLAAttentionImpl):
) -> torch.Tensor: ) -> torch.Tensor:
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
# TODO: enable this compute for flash attention computation
# kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
# k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
# value=0)
num_tokens = query.size(0) num_tokens = query.size(0)
attn_output = None
# Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly
if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
attn_output = torch.empty(num_tokens, attn_output = torch.empty(num_tokens,
self.num_heads, self.num_heads * self.v_head_dim,
self.v_head_dim,
dtype=query.dtype, dtype=query.dtype,
device=query.device) device=query.device)
# current requests is chunked in prefill, disable flash attention with chunked prefill # current requests is chunked in prefill, disable flash attention with chunked prefill
@@ -432,14 +433,55 @@ class AscendMLAImpl(MLAAttentionImpl):
context_lens=attn_metadata.prefill.context_lens, context_lens=attn_metadata.prefill.context_lens,
kv_b_proj=self.kv_b_proj, kv_b_proj=self.kv_b_proj,
max_query_len=attn_metadata.prefill.max_query_len, max_query_len=attn_metadata.prefill.max_query_len,
max_context_len=attn_metadata.prefill.max_context_len, max_context_len=attn_metadata.prefill.max_seq_lens,
nope_dim=self.qk_nope_head_dim, nope_dim=self.qk_nope_head_dim,
rope_dim=self.qk_rope_head_dim, rope_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim, v_head_dim=self.v_head_dim,
scale=self.scale, scale=self.scale,
alibi_slopes=None, alibi_slopes=None,
causal=True) causal=True)
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
attn_output = torch.empty(num_tokens,
self.num_heads,
self.padding_head_dim,
dtype=query.dtype,
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
pad_query = torch.nn.functional.pad(query, [
0, self.padding_head_dim - self.qk_rope_head_dim -
self.qk_nope_head_dim
],
value=0)
pad_key = torch.nn.functional.pad(key, [
0, self.padding_head_dim - self.qk_rope_head_dim -
self.qk_nope_head_dim
],
value=0)
pad_value = torch.nn.functional.pad(
value, [0, self.padding_head_dim - self.v_head_dim], value=0)
torch_npu._npu_flash_attention(
query=pad_query,
key=pad_key,
value=pad_value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
attn_output = attn_output.view( attn_output = attn_output.view(
-1, self.num_heads,
self.padding_head_dim)[:, :, :self.v_head_dim]
else:
raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
)
attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim]) [num_tokens, self.num_heads * self.v_head_dim])
return self.o_proj(attn_output)[0] return self.o_proj(attn_output)[0]
@@ -457,7 +499,7 @@ class AscendMLAImpl(MLAAttentionImpl):
q = torch.cat([q_nope, q_pe], dim=-1) q = torch.cat([q_nope, q_pe], dim=-1)
num_tokens = q.size(0) num_tokens = q.size(0)
attn_output = torch.randn( attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank], [num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype, dtype=q.dtype,
device=q.device) device=q.device)
@@ -522,8 +564,10 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_ql_nope, decode_q_pe = \ decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c) self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe.contiguous(), attn_metadata.decode.input_positions,
decode_k_pe) decode_q_pe.contiguous(),
decode_k_pe,
max_seq_len=attn_metadata.decode.max_seq_lens)
if has_prefill: if has_prefill:
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
@@ -533,7 +577,9 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions, attn_metadata.prefill.input_positions,
prefill_q_pe.contiguous(), prefill_k_pe) prefill_q_pe.contiguous(),
prefill_k_pe,
max_seq_len=attn_metadata.prefill.max_seq_lens)
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
key = torch.cat([ key = torch.cat([

View File

@@ -25,35 +25,43 @@ from vllm.model_executor.layers.rotary_embedding import (
from vllm_ascend.platform import CUSTOM_OP_ENABLED from vllm_ascend.platform import CUSTOM_OP_ENABLED
def custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and CUSTOM_OP_ENABLED
def rope_forward_oot( def rope_forward_oot(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
import torch_npu import torch_npu
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device: if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device) self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype: if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding # adopt custom kernel path for rotary_embedding
if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0: if custom_rotary_embedding_enabled(query, neox_style, self.head_size):
return torch.ops._C.rotary_embedding( query, key = torch.ops._C.rotary_embedding(
positions, positions,
query, query,
key, key,
self.head_size, self.head_size,
self.cos_sin_cache, self.cos_sin_cache,
self.is_neox_style, neox_style,
) )
return query.view(query_shape), key.view(key_shape)
if offsets is not None: if offsets is not None:
raise NotImplementedError( raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.") "Batched rotary embedding is currently not supported on NPU.")
else: else:
# TODO: Remove the contiguous in the future. # TODO: Remove the contiguous in the future.
query_shape, key_shape = query.shape, key.shape
query = query.contiguous().view(query.shape[0], -1) query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1) key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding( torch_npu._npu_rotary_embedding(
@@ -62,33 +70,33 @@ def rope_forward_oot(
key, key,
self.head_size, self.head_size,
self.cos_sin_cache, self.cos_sin_cache,
self.is_neox_style, neox_style,
) )
return query.view(query_shape), key.view(key_shape) return query.view(query_shape), key.view(key_shape)
def native_rope_deepseek_forward( def native_rope_deepseek_forward(self,
self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
): max_seq_len: Optional[int] = None):
# seq_len = positions.max() + 1 if max_seq_len is not None and max_seq_len > self.max_seq_len:
seq_len = self.max_position_embeddings self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
if len(key.shape) == 2:
# x: [bs, num_attention_heads, seq_len, head_size] key = key[:, None, :]
# if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: # Note: we implement the non neox_style method with shuffle the last dim and neox style
# self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype) # calculation method which is also more compute friendly to the ascend machine
self._set_cos_sin_cache(seq_len=seq_len, # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
device=query.device, neox_style = True
dtype=query.dtype) if self.is_neox_style is False:
b, h_q, d = query.shape
cos = self.cos_cached[:seq_len].to(dtype=query.dtype) query = query.view(b, h_q, d // 2, 2).transpose(3,
sin = self.sin_cached[:seq_len].to(dtype=query.dtype) 2).reshape(b, h_q, d)
b, h_k, d = key.shape
q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions) key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe return q_pe, k_pe
@@ -190,7 +198,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
seq_len = self.max_position_embeddings
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
dim = self.rotary_dim dim = self.rotary_dim
@@ -214,21 +221,53 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
t = torch.arange(seq_len, device=device, dtype=torch.float32) t = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq) freqs = torch.outer(t, inv_freq)
cache = torch.cat([freqs.cos() * self.mscale,
# _mscale = float( freqs.sin() * self.mscale],
# yarn_get_mscale(self.scaling_factor, self.mscale) dim=-1).to(dtype)
# / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) self.register_buffer("cos_sin_cache", cache, persistent=False)
# )
emb = torch.cat((freqs, freqs), dim=-1) def deepseek_rope_init_func(
self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), self,
persistent=False) head_size: int,
self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), rotary_dim: int,
persistent=False) max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
_set_cos_sin_cache(self,
max_position_embeddings,
dtype=dtype,
device="npu")
# TODO: Patch when aclnn ops available
RotaryEmbedding.forward_oot = rope_forward_oot RotaryEmbedding.forward_oot = rope_forward_oot
# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
# its more ascend compute friendly
DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None

View File

@@ -31,14 +31,11 @@ try:
# register custom ops into torch_library here # register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
except ImportError as e: except ImportError:
if not str(
e
) == "dynamic module does not define module export function (PyInit_vllm_ascend_C)":
logging.warning( logging.warning(
"Warning: Failed to register custom ops, all custom ops will be disabled" "Warning: Failed to register custom ops, all custom ops will be disabled"
) )
else: else:
CUSTOM_OP_ENABLED = True CUSTOM_OP_ENABLED = True
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -180,9 +177,10 @@ class NPUPlatform(Platform):
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
# Activate custom ops for v1. # Activate custom ops for v1.
vllm_config.compilation_config.custom_ops = ["all"] vllm_config.compilation_config.custom_ops = ["all"]
additional_config = vllm_config.additional_config
# If ascend_scheduler_config exists in additional_config, # If ascend_scheduler_config exists in additional_config,
# extents original scheduler_config to use AscendScheduler. # extents original scheduler_config to use AscendScheduler.
additional_config = vllm_config.additional_config
if additional_config and additional_config.get( if additional_config and additional_config.get(
"ascend_scheduler_config", None) is not None: "ascend_scheduler_config", None) is not None:
additional_scheduler_config = additional_config.get( additional_scheduler_config = additional_config.get(