From 0329fad9276f4b29a4766edc9c00539b05e0592c Mon Sep 17 00:00:00 2001 From: Pleaplusone <38376071+ganyi1996ppo@users.noreply.github.com> Date: Tue, 29 Apr 2025 17:12:03 +0800 Subject: [PATCH] [Perf] Deepseekv3 performance optimization for eager mode (#598) ### What this PR does / why we need it? Deepseek v3 now adopt vanilla chunked prefill on MLA part which is ineffcient for computing but necessary for chunked prefill. Since PR https://github.com/vllm-project/vllm-ascend/pull/543 bring v0 scheduler into vllm-ascend, we can now adopt torch_npu._npu_flash_attention inside the mla backend for more performance boost. Also there are some redundant computation inside the rope, which is also removed. This PR should bring some performance gain for deepseek eager mode inference. --------- Signed-off-by: ganyi --- tests/ops/test_rotary_embedding.py | 5 -- vllm_ascend/attention/mla_v1.py | 134 +++++++++++++++++++--------- vllm_ascend/ops/rotary_embedding.py | 125 +++++++++++++++++--------- vllm_ascend/platform.py | 18 ++-- 4 files changed, 180 insertions(+), 102 deletions(-) diff --git a/tests/ops/test_rotary_embedding.py b/tests/ops/test_rotary_embedding.py index 800960b..2ab0420 100644 --- a/tests/ops/test_rotary_embedding.py +++ b/tests/ops/test_rotary_embedding.py @@ -136,11 +136,6 @@ class RotaryEmbedding(nn.Module): # test with leading dimension and merge seqlen and batch_size as num_tokens -# TODO(ganyi): open this test in the future -@pytest.mark.skip( - reason= - "skip this test by default for now because of ci issue, will enable it in the future" -) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 537c700..64a5431 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata: input_positions: torch.Tensor block_table: torch.Tensor max_query_len: int - max_context_len: int + max_seq_lens: int @dataclass @@ -65,6 +65,7 @@ class AscendMLADecodeMetadata: input_positions: torch.Tensor block_table: torch.Tensor seq_lens: torch.Tensor + max_seq_lens: int @dataclass @@ -131,11 +132,6 @@ class AscendMLAMetadataBuilder: self.runner = runner scheduler_config = runner.scheduler_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - # self.attn_mask = None - # if AscendMLAMetadataBuilder._attn_mask_builder is None: - # AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( - # 128, self.runner.model_config.dtype - # ) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -222,12 +218,14 @@ class AscendMLAMetadataBuilder: num_reqs] seq_lens = seq_lens_cpu max_query_len = query_lens.max().item() - max_context_len = seq_lens.max().item() + max_seq_lens = seq_lens.max().item() prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start tokens_start = self._num_decode_tokens + max_query_len = query_lens[tokens_start:].max().item() + max_seq_lens = seq_lens[tokens_start:].max().item() prefill_metadata = AscendMLAPrefillMetadata( attn_mask=self.runner.attn_mask, @@ -236,15 +234,17 @@ class AscendMLAMetadataBuilder: input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], max_query_len=max_query_len, - max_context_len=max_context_len, + max_seq_lens=max_seq_lens, ) decode_metadata = None if self._num_decodes > 0: + max_seq_lens = seq_lens[:self._num_decodes].max().item() decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions[:self._num_decode_tokens], block_table=block_table[:self._num_decode_tokens, ...], - seq_lens=seq_lens[:self._num_decode_tokens]) + seq_lens=seq_lens[:self._num_decode_tokens], + max_seq_lens=max_seq_lens) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -306,12 +306,18 @@ class AscendMLAImpl(MLAAttentionImpl): self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim + # TODO: below padding should be removed after kernel is ready + # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here + # and slice the final result to guarantee its functionality. + self.padding_head_dim = ( + (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 + + 1) * 128 # Hack for V1 for now to avoid torch library overhead (since we are # already inside an attention custom op), pull out the forward # method from the rotary embedding and call it directly # TODO(lucas): we should probably find a cleaner way to do this - self.rotary_emb = rotary_emb.forward_native + self.rotary_emb = rotary_emb self.q_proj = q_proj self.kv_b_proj = kv_b_proj @@ -409,37 +415,73 @@ class AscendMLAImpl(MLAAttentionImpl): ) -> torch.Tensor: assert attn_metadata.prefill is not None - # TODO: enable this compute for flash attention computation - # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - # -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]], - # value=0) num_tokens = query.size(0) - attn_output = torch.empty(num_tokens, - self.num_heads, - self.v_head_dim, - dtype=query.dtype, - device=query.device) - # current requests is chunked in prefill, disable flash attention with chunked prefill - vanilla_chunked_prefill_mla( - output=attn_output, - query=query, - kv_cache=kv_c_and_k_pe_cache, - block_tables=attn_metadata.prefill.block_table, - query_lens=attn_metadata.prefill.query_lens, - context_lens=attn_metadata.prefill.context_lens, - kv_b_proj=self.kv_b_proj, - max_query_len=attn_metadata.prefill.max_query_len, - max_context_len=attn_metadata.prefill.max_context_len, - nope_dim=self.qk_nope_head_dim, - rope_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - scale=self.scale, - alibi_slopes=None, - causal=True) - attn_output = attn_output.view( + attn_output = None + # Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly + if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill: + attn_output = torch.empty(num_tokens, + self.num_heads * self.v_head_dim, + dtype=query.dtype, + device=query.device) + # current requests is chunked in prefill, disable flash attention with chunked prefill + vanilla_chunked_prefill_mla( + output=attn_output, + query=query, + kv_cache=kv_c_and_k_pe_cache, + block_tables=attn_metadata.prefill.block_table, + query_lens=attn_metadata.prefill.query_lens, + context_lens=attn_metadata.prefill.context_lens, + kv_b_proj=self.kv_b_proj, + max_query_len=attn_metadata.prefill.max_query_len, + max_context_len=attn_metadata.prefill.max_seq_lens, + nope_dim=self.qk_nope_head_dim, + rope_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + scale=self.scale, + alibi_slopes=None, + causal=True) + elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + attn_output = torch.empty(num_tokens, + self.num_heads, + self.padding_head_dim, + dtype=query.dtype, + device=query.device) + k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + pad_query = torch.nn.functional.pad(query, [ + 0, self.padding_head_dim - self.qk_rope_head_dim - + self.qk_nope_head_dim + ], + value=0) + pad_key = torch.nn.functional.pad(key, [ + 0, self.padding_head_dim - self.qk_rope_head_dim - + self.qk_nope_head_dim + ], + value=0) + pad_value = torch.nn.functional.pad( + value, [0, self.padding_head_dim - self.v_head_dim], value=0) + torch_npu._npu_flash_attention( + query=pad_query, + key=pad_key, + value=pad_value, + mask=attn_metadata.attn_mask, + seq_len=attn_metadata.prefill.context_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_heads, + out=attn_output) + attn_output = attn_output.view( + -1, self.num_heads, + self.padding_head_dim)[:, :, :self.v_head_dim] + else: + raise RuntimeError( + "Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !" + ) + attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) return self.o_proj(attn_output)[0] @@ -457,7 +499,7 @@ class AscendMLAImpl(MLAAttentionImpl): q = torch.cat([q_nope, q_pe], dim=-1) num_tokens = q.size(0) - attn_output = torch.randn( + attn_output = torch.empty( [num_tokens, self.num_heads, self.kv_lora_rank], dtype=q.dtype, device=q.device) @@ -522,8 +564,10 @@ class AscendMLAImpl(MLAAttentionImpl): decode_ql_nope, decode_q_pe = \ self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe.contiguous(), - decode_k_pe) + attn_metadata.decode.input_positions, + decode_q_pe.contiguous(), + decode_k_pe, + max_seq_len=attn_metadata.decode.max_seq_lens) if has_prefill: assert attn_metadata.prefill is not None @@ -533,7 +577,9 @@ class AscendMLAImpl(MLAAttentionImpl): prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), prefill_k_pe) + prefill_q_pe.contiguous(), + prefill_k_pe, + max_seq_len=attn_metadata.prefill.max_seq_lens) if kv_cache.numel() > 0: key = torch.cat([ diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index f830364..fbaddc5 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -25,35 +25,43 @@ from vllm.model_executor.layers.rotary_embedding import ( from vllm_ascend.platform import CUSTOM_OP_ENABLED +def custom_rotary_embedding_enabled(query, neox_style, head_size): + return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and CUSTOM_OP_ENABLED + + def rope_forward_oot( self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None ) -> Tuple[torch.Tensor, torch.Tensor]: import torch_npu - + query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: self.cos_sin_cache = self.cos_sin_cache.to(query.device) if self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + neox_style = self.is_neox_style + if is_neox_style_override is not None: + neox_style = is_neox_style_override # adopt custom kernel path for rotary_embedding - if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0: - return torch.ops._C.rotary_embedding( + if custom_rotary_embedding_enabled(query, neox_style, self.head_size): + query, key = torch.ops._C.rotary_embedding( positions, query, key, self.head_size, self.cos_sin_cache, - self.is_neox_style, + neox_style, ) + return query.view(query_shape), key.view(key_shape) if offsets is not None: raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: # TODO: Remove the contiguous in the future. - query_shape, key_shape = query.shape, key.shape query = query.contiguous().view(query.shape[0], -1) key = key.contiguous().view(key.shape[0], -1) torch_npu._npu_rotary_embedding( @@ -62,33 +70,33 @@ def rope_forward_oot( key, self.head_size, self.cos_sin_cache, - self.is_neox_style, + neox_style, ) return query.view(query_shape), key.view(key_shape) -def native_rope_deepseek_forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, -): - # seq_len = positions.max() + 1 - seq_len = self.max_position_embeddings - - # x: [bs, num_attention_heads, seq_len, head_size] - # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - # self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype) - self._set_cos_sin_cache(seq_len=seq_len, - device=query.device, - dtype=query.dtype) - - cos = self.cos_cached[:seq_len].to(dtype=query.dtype) - sin = self.sin_cached[:seq_len].to(dtype=query.dtype) - - q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions) - +def native_rope_deepseek_forward(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + max_seq_len: Optional[int] = None): + if max_seq_len is not None and max_seq_len > self.max_seq_len: + self._set_cos_sin_cache(max_seq_len, query.device, query.dtype) + if len(key.shape) == 2: + key = key[:, None, :] + # Note: we implement the non neox_style method with shuffle the last dim and neox style + # calculation method which is also more compute friendly to the ascend machine + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py + neox_style = True + if self.is_neox_style is False: + b, h_q, d = query.shape + query = query.view(b, h_q, d // 2, 2).transpose(3, + 2).reshape(b, h_q, d) + b, h_k, d = key.shape + key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) + q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, + neox_style) return q_pe, k_pe @@ -190,7 +198,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def _set_cos_sin_cache(self, seq_len, device, dtype): - seq_len = self.max_position_embeddings self.max_seq_len_cached = seq_len dim = self.rotary_dim @@ -214,21 +221,53 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): t = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) - - # _mscale = float( - # yarn_get_mscale(self.scaling_factor, self.mscale) - # / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - # ) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), - persistent=False) - self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), - persistent=False) + cache = torch.cat([freqs.cos() * self.mscale, + freqs.sin() * self.mscale], + dim=-1).to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + +def deepseek_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, +) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super(DeepseekScalingRotaryEmbedding, + self).__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + self.max_seq_len = max_position_embeddings + _set_cos_sin_cache(self, + max_position_embeddings, + dtype=dtype, + device="npu") -# TODO: Patch when aclnn ops available RotaryEmbedding.forward_oot = rope_forward_oot + +# Note: we adopt the native huggingface deepseek rope initialization code from +# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for +# its more ascend compute friendly +DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward -DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache -DeepseekScalingRotaryEmbedding.max_seq_len_cached = None diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 79e9486..4e4a397 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -31,15 +31,12 @@ try: # register custom ops into torch_library here import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 -except ImportError as e: - if not str( - e - ) == "dynamic module does not define module export function (PyInit_vllm_ascend_C)": - logging.warning( - "Warning: Failed to register custom ops, all custom ops will be disabled" - ) - else: - CUSTOM_OP_ENABLED = True +except ImportError: + logging.warning( + "Warning: Failed to register custom ops, all custom ops will be disabled" + ) +else: + CUSTOM_OP_ENABLED = True if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -180,9 +177,10 @@ class NPUPlatform(Platform): if envs.VLLM_USE_V1: # Activate custom ops for v1. vllm_config.compilation_config.custom_ops = ["all"] - additional_config = vllm_config.additional_config # If ascend_scheduler_config exists in additional_config, # extents original scheduler_config to use AscendScheduler. + + additional_config = vllm_config.additional_config if additional_config and additional_config.get( "ascend_scheduler_config", None) is not None: additional_scheduler_config = additional_config.get(