diff --git a/tests/ops/test_rotary_embedding.py b/tests/ops/test_rotary_embedding.py index 800960b..2ab0420 100644 --- a/tests/ops/test_rotary_embedding.py +++ b/tests/ops/test_rotary_embedding.py @@ -136,11 +136,6 @@ class RotaryEmbedding(nn.Module): # test with leading dimension and merge seqlen and batch_size as num_tokens -# TODO(ganyi): open this test in the future -@pytest.mark.skip( - reason= - "skip this test by default for now because of ci issue, will enable it in the future" -) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 537c700..64a5431 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata: input_positions: torch.Tensor block_table: torch.Tensor max_query_len: int - max_context_len: int + max_seq_lens: int @dataclass @@ -65,6 +65,7 @@ class AscendMLADecodeMetadata: input_positions: torch.Tensor block_table: torch.Tensor seq_lens: torch.Tensor + max_seq_lens: int @dataclass @@ -131,11 +132,6 @@ class AscendMLAMetadataBuilder: self.runner = runner scheduler_config = runner.scheduler_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - # self.attn_mask = None - # if AscendMLAMetadataBuilder._attn_mask_builder is None: - # AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( - # 128, self.runner.model_config.dtype - # ) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -222,12 +218,14 @@ class AscendMLAMetadataBuilder: num_reqs] seq_lens = seq_lens_cpu max_query_len = query_lens.max().item() - max_context_len = seq_lens.max().item() + max_seq_lens = seq_lens.max().item() prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start tokens_start = self._num_decode_tokens + max_query_len = query_lens[tokens_start:].max().item() + max_seq_lens = seq_lens[tokens_start:].max().item() prefill_metadata = AscendMLAPrefillMetadata( attn_mask=self.runner.attn_mask, @@ -236,15 +234,17 @@ class AscendMLAMetadataBuilder: input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], max_query_len=max_query_len, - max_context_len=max_context_len, + max_seq_lens=max_seq_lens, ) decode_metadata = None if self._num_decodes > 0: + max_seq_lens = seq_lens[:self._num_decodes].max().item() decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions[:self._num_decode_tokens], block_table=block_table[:self._num_decode_tokens, ...], - seq_lens=seq_lens[:self._num_decode_tokens]) + seq_lens=seq_lens[:self._num_decode_tokens], + max_seq_lens=max_seq_lens) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -306,12 +306,18 @@ class AscendMLAImpl(MLAAttentionImpl): self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim + # TODO: below padding should be removed after kernel is ready + # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here + # and slice the final result to guarantee its functionality. + self.padding_head_dim = ( + (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 + + 1) * 128 # Hack for V1 for now to avoid torch library overhead (since we are # already inside an attention custom op), pull out the forward # method from the rotary embedding and call it directly # TODO(lucas): we should probably find a cleaner way to do this - self.rotary_emb = rotary_emb.forward_native + self.rotary_emb = rotary_emb self.q_proj = q_proj self.kv_b_proj = kv_b_proj @@ -409,37 +415,73 @@ class AscendMLAImpl(MLAAttentionImpl): ) -> torch.Tensor: assert attn_metadata.prefill is not None - # TODO: enable this compute for flash attention computation - # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - # -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]], - # value=0) num_tokens = query.size(0) - attn_output = torch.empty(num_tokens, - self.num_heads, - self.v_head_dim, - dtype=query.dtype, - device=query.device) - # current requests is chunked in prefill, disable flash attention with chunked prefill - vanilla_chunked_prefill_mla( - output=attn_output, - query=query, - kv_cache=kv_c_and_k_pe_cache, - block_tables=attn_metadata.prefill.block_table, - query_lens=attn_metadata.prefill.query_lens, - context_lens=attn_metadata.prefill.context_lens, - kv_b_proj=self.kv_b_proj, - max_query_len=attn_metadata.prefill.max_query_len, - max_context_len=attn_metadata.prefill.max_context_len, - nope_dim=self.qk_nope_head_dim, - rope_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - scale=self.scale, - alibi_slopes=None, - causal=True) - attn_output = attn_output.view( + attn_output = None + # Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly + if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill: + attn_output = torch.empty(num_tokens, + self.num_heads * self.v_head_dim, + dtype=query.dtype, + device=query.device) + # current requests is chunked in prefill, disable flash attention with chunked prefill + vanilla_chunked_prefill_mla( + output=attn_output, + query=query, + kv_cache=kv_c_and_k_pe_cache, + block_tables=attn_metadata.prefill.block_table, + query_lens=attn_metadata.prefill.query_lens, + context_lens=attn_metadata.prefill.context_lens, + kv_b_proj=self.kv_b_proj, + max_query_len=attn_metadata.prefill.max_query_len, + max_context_len=attn_metadata.prefill.max_seq_lens, + nope_dim=self.qk_nope_head_dim, + rope_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + scale=self.scale, + alibi_slopes=None, + causal=True) + elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + attn_output = torch.empty(num_tokens, + self.num_heads, + self.padding_head_dim, + dtype=query.dtype, + device=query.device) + k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + pad_query = torch.nn.functional.pad(query, [ + 0, self.padding_head_dim - self.qk_rope_head_dim - + self.qk_nope_head_dim + ], + value=0) + pad_key = torch.nn.functional.pad(key, [ + 0, self.padding_head_dim - self.qk_rope_head_dim - + self.qk_nope_head_dim + ], + value=0) + pad_value = torch.nn.functional.pad( + value, [0, self.padding_head_dim - self.v_head_dim], value=0) + torch_npu._npu_flash_attention( + query=pad_query, + key=pad_key, + value=pad_value, + mask=attn_metadata.attn_mask, + seq_len=attn_metadata.prefill.context_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_heads, + out=attn_output) + attn_output = attn_output.view( + -1, self.num_heads, + self.padding_head_dim)[:, :, :self.v_head_dim] + else: + raise RuntimeError( + "Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !" + ) + attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) return self.o_proj(attn_output)[0] @@ -457,7 +499,7 @@ class AscendMLAImpl(MLAAttentionImpl): q = torch.cat([q_nope, q_pe], dim=-1) num_tokens = q.size(0) - attn_output = torch.randn( + attn_output = torch.empty( [num_tokens, self.num_heads, self.kv_lora_rank], dtype=q.dtype, device=q.device) @@ -522,8 +564,10 @@ class AscendMLAImpl(MLAAttentionImpl): decode_ql_nope, decode_q_pe = \ self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe.contiguous(), - decode_k_pe) + attn_metadata.decode.input_positions, + decode_q_pe.contiguous(), + decode_k_pe, + max_seq_len=attn_metadata.decode.max_seq_lens) if has_prefill: assert attn_metadata.prefill is not None @@ -533,7 +577,9 @@ class AscendMLAImpl(MLAAttentionImpl): prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), prefill_k_pe) + prefill_q_pe.contiguous(), + prefill_k_pe, + max_seq_len=attn_metadata.prefill.max_seq_lens) if kv_cache.numel() > 0: key = torch.cat([ diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index f830364..fbaddc5 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -25,35 +25,43 @@ from vllm.model_executor.layers.rotary_embedding import ( from vllm_ascend.platform import CUSTOM_OP_ENABLED +def custom_rotary_embedding_enabled(query, neox_style, head_size): + return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and CUSTOM_OP_ENABLED + + def rope_forward_oot( self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None ) -> Tuple[torch.Tensor, torch.Tensor]: import torch_npu - + query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: self.cos_sin_cache = self.cos_sin_cache.to(query.device) if self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + neox_style = self.is_neox_style + if is_neox_style_override is not None: + neox_style = is_neox_style_override # adopt custom kernel path for rotary_embedding - if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0: - return torch.ops._C.rotary_embedding( + if custom_rotary_embedding_enabled(query, neox_style, self.head_size): + query, key = torch.ops._C.rotary_embedding( positions, query, key, self.head_size, self.cos_sin_cache, - self.is_neox_style, + neox_style, ) + return query.view(query_shape), key.view(key_shape) if offsets is not None: raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: # TODO: Remove the contiguous in the future. - query_shape, key_shape = query.shape, key.shape query = query.contiguous().view(query.shape[0], -1) key = key.contiguous().view(key.shape[0], -1) torch_npu._npu_rotary_embedding( @@ -62,33 +70,33 @@ def rope_forward_oot( key, self.head_size, self.cos_sin_cache, - self.is_neox_style, + neox_style, ) return query.view(query_shape), key.view(key_shape) -def native_rope_deepseek_forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, -): - # seq_len = positions.max() + 1 - seq_len = self.max_position_embeddings - - # x: [bs, num_attention_heads, seq_len, head_size] - # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - # self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype) - self._set_cos_sin_cache(seq_len=seq_len, - device=query.device, - dtype=query.dtype) - - cos = self.cos_cached[:seq_len].to(dtype=query.dtype) - sin = self.sin_cached[:seq_len].to(dtype=query.dtype) - - q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions) - +def native_rope_deepseek_forward(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + max_seq_len: Optional[int] = None): + if max_seq_len is not None and max_seq_len > self.max_seq_len: + self._set_cos_sin_cache(max_seq_len, query.device, query.dtype) + if len(key.shape) == 2: + key = key[:, None, :] + # Note: we implement the non neox_style method with shuffle the last dim and neox style + # calculation method which is also more compute friendly to the ascend machine + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py + neox_style = True + if self.is_neox_style is False: + b, h_q, d = query.shape + query = query.view(b, h_q, d // 2, 2).transpose(3, + 2).reshape(b, h_q, d) + b, h_k, d = key.shape + key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) + q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, + neox_style) return q_pe, k_pe @@ -190,7 +198,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def _set_cos_sin_cache(self, seq_len, device, dtype): - seq_len = self.max_position_embeddings self.max_seq_len_cached = seq_len dim = self.rotary_dim @@ -214,21 +221,53 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): t = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) - - # _mscale = float( - # yarn_get_mscale(self.scaling_factor, self.mscale) - # / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - # ) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), - persistent=False) - self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), - persistent=False) + cache = torch.cat([freqs.cos() * self.mscale, + freqs.sin() * self.mscale], + dim=-1).to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + +def deepseek_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, +) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super(DeepseekScalingRotaryEmbedding, + self).__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + self.max_seq_len = max_position_embeddings + _set_cos_sin_cache(self, + max_position_embeddings, + dtype=dtype, + device="npu") -# TODO: Patch when aclnn ops available RotaryEmbedding.forward_oot = rope_forward_oot + +# Note: we adopt the native huggingface deepseek rope initialization code from +# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for +# its more ascend compute friendly +DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward -DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache -DeepseekScalingRotaryEmbedding.max_seq_len_cached = None diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 79e9486..4e4a397 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -31,15 +31,12 @@ try: # register custom ops into torch_library here import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 -except ImportError as e: - if not str( - e - ) == "dynamic module does not define module export function (PyInit_vllm_ascend_C)": - logging.warning( - "Warning: Failed to register custom ops, all custom ops will be disabled" - ) - else: - CUSTOM_OP_ENABLED = True +except ImportError: + logging.warning( + "Warning: Failed to register custom ops, all custom ops will be disabled" + ) +else: + CUSTOM_OP_ENABLED = True if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -180,9 +177,10 @@ class NPUPlatform(Platform): if envs.VLLM_USE_V1: # Activate custom ops for v1. vllm_config.compilation_config.custom_ops = ["all"] - additional_config = vllm_config.additional_config # If ascend_scheduler_config exists in additional_config, # extents original scheduler_config to use AscendScheduler. + + additional_config = vllm_config.additional_config if additional_config and additional_config.get( "ascend_scheduler_config", None) is not None: additional_scheduler_config = additional_config.get(