diff --git a/vllm_kunlun/ops/attention/flashmla.py b/vllm_kunlun/ops/attention/flashmla.py
index f53dfcf..2375799 100644
--- a/vllm_kunlun/ops/attention/flashmla.py
+++ b/vllm_kunlun/ops/attention/flashmla.py
@@ -171,7 +171,7 @@ def kunlun_flash_mla_with_kvcache(
     p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
                          dtype=torch.float32,
                          device=q.device)
-    xtorch_ops.fwd_kvcache_mla(
+    torch.ops._C.fwd_kvcache_mla(
         q_c=q,
         kv_cache=k_cache,
         indices=indices,
@@ -224,7 +224,7 @@ def flash_mla_sparse_prefill(
     max_logits = torch.zeros([s_q, h_q], dtype=torch.float32,
                              device=q.device)
     lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
-    xtorch_ops.sparse_prefill_fwd_opt(
+    torch.ops._C.sparse_prefill_fwd_opt(
         q=q,
         kv=kv,
         indices=indices,
diff --git a/vllm_kunlun/ops/deep_gemm.py b/vllm_kunlun/ops/deep_gemm.py
index af0640f..bd1318c 100644
--- a/vllm_kunlun/ops/deep_gemm.py
+++ b/vllm_kunlun/ops/deep_gemm.py
@@ -1,5 +1,4 @@
 import torch
-import xtorch_ops
 
 def int8_mqa_logits(
     q: torch.Tensor,
@@ -29,7 +28,7 @@ def int8_mqa_logits(
     context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
     context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32,
                                       device=cu_seqlen_ks.device)
-    xtorch_ops.I8_mqa_logits(
+    torch.ops._C.I8_mqa_logits(
         q=q,
         fused_kv_cache=kv,
         weights=weights,
@@ -99,7 +98,7 @@ def int8_paged_mqa_logits(
     logits = torch.empty((batch_size, next_n, max_model_len),
                          dtype=torch.float32,
                          device=q_fp8.device)
-    xtorch_ops.I8_paged_mqa_logits(
+    torch.ops._C.I8_paged_mqa_logits(
         q=q_fp8,
         fused_kv_cache=kv_cache,
         weights=weights,
diff --git a/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py b/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
index 3a3a4b5..d5eea23 100644
--- a/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
@@ -731,22 +731,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         q = torch.cat([ql_nope, q_pe], dim=-1)
 
-        # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
-            torch.ops._C.concat_and_cache_mla(
-                kv_c=k_c_normed,
-                k_pe=k_pe.squeeze(1),
-                kv_cache=kv_cache,
-                slot_mapping=attn_metadata.slot_mapping.flatten(),
-            )
-
         if self.kv_cache_dtype != "fp8_ds_mla":
+            # write the latent and rope to kv cache
+            if kv_cache.numel() > 0:
+                torch.ops._C.concat_and_cache_mla(
+                    kv_c=k_c_normed,
+                    k_pe=k_pe.squeeze(1),
+                    kv_cache=kv_cache,
+                    slot_mapping=attn_metadata.slot_mapping.flatten(),
+                )
             attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                              attn_metadata)
         else:
             # attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
             #                                 attn_metadata)
-            raise NotImplementedError
+            raise NotImplementedError("FlashMLA sparse only supports --kv-cache-dtype bfloat16")
 
         self._v_up_proj(attn_out, out=output[:num_actual_toks])
 
         return output
diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py
index cac5cf4..ca7273c 100644
--- a/vllm_kunlun/vllm_utils_wrapper.py
+++ b/vllm_kunlun/vllm_utils_wrapper.py
@@ -1923,3 +1923,315 @@ def apply_repetition_penalties_(
     # If logits are positive, divide by penalty, otherwise multiply by penalty.
     scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
     logits *= scaling
+
+
+# torch.library registrations that expose the Kunlun xtorch_ops attention
+# kernels under the torch.ops._C namespace, so call sites no longer need to
+# import xtorch_ops directly. All of these ops write their results into
+# caller-allocated output tensors and return None.
+
+##################################################
+# --------------- I8_mqa_logits -----------------
+##################################################
+@custom_op("_C::I8_mqa_logits", mutates_args=("logits",))
+def I8_mqa_logits(
+        q: torch.Tensor,
+        fused_kv_cache: List[torch.Tensor],
+        weights: torch.Tensor,
+        context_q_lens: List[torch.Tensor],
+        context_k_lens: List[torch.Tensor],
+        logits: torch.Tensor,
+        clean_logits: bool,
+        max_seq_q: Optional[int] = 0,
+        max_seq_k: Optional[int] = 0,
+        is_causal: Optional[bool] = False,
+        use_xfa_boost: Optional[bool] = False,
+) -> None:
+    # The kernel fills the caller-provided `logits` tensor in place.
+    xtorch_ops.I8_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_q_lens=context_q_lens,
+        context_k_lens=context_k_lens,
+        logits=logits,
+        clean_logits=clean_logits,
+        max_seq_q=max_seq_q,
+        max_seq_k=max_seq_k,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost,
+    )
+    return None
+
+@impl("_C::I8_mqa_logits", "CUDA")
+def I8_mqa_logits_cuda(
+        q: torch.Tensor,
+        fused_kv_cache: List[torch.Tensor],
+        weights: torch.Tensor,
+        context_q_lens: List[torch.Tensor],
+        context_k_lens: List[torch.Tensor],
+        logits: torch.Tensor,
+        clean_logits: bool,
+        max_seq_q: Optional[int] = 0,
+        max_seq_k: Optional[int] = 0,
+        is_causal: Optional[bool] = False,
+        use_xfa_boost: Optional[bool] = False,
+) -> None:
+    xtorch_ops.I8_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_q_lens=context_q_lens,
+        context_k_lens=context_k_lens,
+        logits=logits,
+        clean_logits=clean_logits,
+        max_seq_q=max_seq_q,
+        max_seq_k=max_seq_k,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost,
+    )
+    return None
+
+def _fake_I8_mqa_logits(
+        q: torch.Tensor,
+        fused_kv_cache: List[torch.Tensor],
+        weights: torch.Tensor,
+        context_q_lens: List[torch.Tensor],
+        context_k_lens: List[torch.Tensor],
+        logits: torch.Tensor,
+        clean_logits: bool,
+        max_seq_q: Optional[int] = 0,
+        max_seq_k: Optional[int] = 0,
+        is_causal: Optional[bool] = False,
+        use_xfa_boost: Optional[bool] = False,
+) -> None:
+    # Fake (meta) implementation for tracing: the op only mutates `logits`.
+    return None
+
+I8_mqa_logits.register_fake(_fake_I8_mqa_logits)
+
+##################################################
+# ------------- I8_paged_mqa_logits --------------
+##################################################
+@custom_op("_C::I8_paged_mqa_logits", mutates_args=("out",))
+def I8_paged_mqa_logits(
+        q: torch.Tensor,
+        fused_kv_cache: List[torch.Tensor],
+        weights: torch.Tensor,
+        context_lens: List[torch.Tensor],
+        block_table: torch.Tensor,
+        max_context_len: int,
+        clean_logits: bool,
+        out: torch.Tensor,
+        use_xfa_boost: Optional[bool] = False) -> None:
+    # The kernel fills the caller-provided `out` tensor in place.
+    xtorch_ops.I8_paged_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_lens=context_lens,
+        block_table=block_table,
+        max_context_len=max_context_len,
+        clean_logits=clean_logits,
+        out=out,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+@impl("_C::I8_paged_mqa_logits", "CUDA")
+def I8_paged_mqa_logits_cuda(
+        q: torch.Tensor,
+        fused_kv_cache: List[torch.Tensor],
+        weights: torch.Tensor,
+        context_lens: List[torch.Tensor],
+        block_table: torch.Tensor,
+        max_context_len: int,
+        clean_logits: bool,
+        out: torch.Tensor,
+        use_xfa_boost: Optional[bool] = False) -> None:
+    xtorch_ops.I8_paged_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_lens=context_lens,
+        block_table=block_table,
+        max_context_len=max_context_len,
+        clean_logits=clean_logits,
+        out=out,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+def _fake_I8_paged_mqa_logits(
+        q: torch.Tensor,
+        fused_kv_cache: List[torch.Tensor],
+        weights: torch.Tensor,
+        context_lens: List[torch.Tensor],
+        block_table: torch.Tensor,
+        max_context_len: int,
+        clean_logits: bool,
+        out: torch.Tensor,
+        use_xfa_boost: Optional[bool] = False) -> None:
+    return None
+
+I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits)
+
+##################################################
+# ----------- sparse_prefill_fwd_opt -------------
+##################################################
+@custom_op("_C::sparse_prefill_fwd_opt", mutates_args=("out", "max_logits", "lse"))
+def sparse_prefill_fwd_opt(
+        q: torch.Tensor,
+        kv: torch.Tensor,
+        indices: torch.Tensor,
+        out: torch.Tensor,
+        max_logits: torch.Tensor,
+        lse: torch.Tensor,
+        sm_scale: float,
+        qlod_cpu: Optional[torch.Tensor] = None,
+        qlod_xpu: Optional[torch.Tensor] = None,
+        kvlod_cpu: Optional[torch.Tensor] = None,
+        kvlod_xpu: Optional[torch.Tensor] = None,
+        d_v: Optional[int] = -1,
+        is_causal: Optional[bool] = True,
+        use_xfa_boost: Optional[bool] = False) -> None:
+    # The kernel fills the caller-provided `out`, `max_logits` and `lse` tensors in place.
+    xtorch_ops.sparse_prefill_fwd_opt(
+        q=q,
+        kv=kv,
+        indices=indices,
+        out=out,
+        max_logits=max_logits,
+        lse=lse,
+        sm_scale=sm_scale,
+        qlod_cpu=qlod_cpu,
+        qlod_xpu=qlod_xpu,
+        kvlod_cpu=kvlod_cpu,
+        kvlod_xpu=kvlod_xpu,
+        d_v=d_v,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+@impl("_C::sparse_prefill_fwd_opt", "CUDA")
+def sparse_prefill_fwd_opt_cuda(
+        q: torch.Tensor,
+        kv: torch.Tensor,
+        indices: torch.Tensor,
+        out: torch.Tensor,
+        max_logits: torch.Tensor,
+        lse: torch.Tensor,
+        sm_scale: float,
+        qlod_cpu: Optional[torch.Tensor] = None,
+        qlod_xpu: Optional[torch.Tensor] = None,
+        kvlod_cpu: Optional[torch.Tensor] = None,
+        kvlod_xpu: Optional[torch.Tensor] = None,
+        d_v: Optional[int] = -1,
+        is_causal: Optional[bool] = True,
+        use_xfa_boost: Optional[bool] = False) -> None:
+    xtorch_ops.sparse_prefill_fwd_opt(
+        q=q,
+        kv=kv,
+        indices=indices,
+        out=out,
+        max_logits=max_logits,
+        lse=lse,
+        sm_scale=sm_scale,
+        qlod_cpu=qlod_cpu,
+        qlod_xpu=qlod_xpu,
+        kvlod_cpu=kvlod_cpu,
+        kvlod_xpu=kvlod_xpu,
+        d_v=d_v,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+def _fake_sparse_prefill_fwd_opt(
+        q: torch.Tensor,
+        kv: torch.Tensor,
+        indices: torch.Tensor,
+        out: torch.Tensor,
+        max_logits: torch.Tensor,
+        lse: torch.Tensor,
+        sm_scale: float,
+        qlod_cpu: Optional[torch.Tensor] = None,
+        qlod_xpu: Optional[torch.Tensor] = None,
+        kvlod_cpu: Optional[torch.Tensor] = None,
+        kvlod_xpu: Optional[torch.Tensor] = None,
+        d_v: Optional[int] = -1,
+        is_causal: Optional[bool] = True,
+        use_xfa_boost: Optional[bool] = False) -> None:
+    return None
+
+sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt)
+
+##################################################
+# ------------------ fwd_kvcache_mla -------------
+##################################################
+@custom_op("_C::fwd_kvcache_mla", mutates_args=("out", "max_logits", "p_sums"))
+def fwd_kvcache_mla(
+        q_c: torch.Tensor,
+        kv_cache: torch.Tensor,
+        indices: torch.Tensor,
+        kv_lod_cpu: torch.Tensor,
+        out: torch.Tensor,
+        max_logits: torch.Tensor,
+        p_sums: torch.Tensor,
+        softmax_scale: float,
+        max_seq_kv: int,
+        q_r: Optional[torch.Tensor] = None,
+        pe_cache: Optional[torch.Tensor] = None,
+        use_xfa_boost: Optional[bool] = False,
+        kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    # The kernel fills the caller-provided `out`, `max_logits` and `p_sums` tensors in place.
+    xtorch_ops.fwd_kvcache_mla(
+        q_c=q_c,
+        kv_cache=kv_cache,
+        indices=indices,
+        kv_lod_cpu=kv_lod_cpu,
+        out=out,
+        max_logits=max_logits,
+        p_sums=p_sums,
+        softmax_scale=softmax_scale,
+        max_seq_kv=max_seq_kv,
+        q_r=q_r,
+        pe_cache=pe_cache,
+        use_xfa_boost=use_xfa_boost,
+        kv_lod_xpu=kv_lod_xpu)
+    return None
+
+@impl("_C::fwd_kvcache_mla", "CUDA")
+def fwd_kvcache_mla_cuda(
+        q_c: torch.Tensor,
+        kv_cache: torch.Tensor,
+        indices: torch.Tensor,
+        kv_lod_cpu: torch.Tensor,
+        out: torch.Tensor,
+        max_logits: torch.Tensor,
+        p_sums: torch.Tensor,
+        softmax_scale: float,
+        max_seq_kv: int,
+        q_r: Optional[torch.Tensor] = None,
+        pe_cache: Optional[torch.Tensor] = None,
+        use_xfa_boost: Optional[bool] = False,
+        kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    xtorch_ops.fwd_kvcache_mla(
+        q_c=q_c,
+        kv_cache=kv_cache,
+        indices=indices,
+        kv_lod_cpu=kv_lod_cpu,
+        out=out,
+        max_logits=max_logits,
+        p_sums=p_sums,
+        softmax_scale=softmax_scale,
+        max_seq_kv=max_seq_kv,
+        q_r=q_r,
+        pe_cache=pe_cache,
+        use_xfa_boost=use_xfa_boost,
+        kv_lod_xpu=kv_lod_xpu)
+    return None
+
+def _fake_fwd_kvcache_mla(
+        q_c: torch.Tensor,
+        kv_cache: torch.Tensor,
+        indices: torch.Tensor,
+        kv_lod_cpu: torch.Tensor,
+        out: torch.Tensor,
+        max_logits: torch.Tensor,
+        p_sums: torch.Tensor,
+        softmax_scale: float,
+        max_seq_kv: int,
+        q_r: Optional[torch.Tensor] = None,
+        pe_cache: Optional[torch.Tensor] = None,
+        use_xfa_boost: Optional[bool] = False,
+        kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    return None
+
+fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
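
All four registrations above follow the same torch.library pattern: an op that writes its results into caller-allocated output tensors and returns None, paired with a fake (meta) implementation so the op can be traced by torch.compile without launching the XPU kernel. Below is a minimal, self-contained sketch of that pattern with a toy op; the `_C_demo` namespace, the `scale_into` name, and the shapes are illustrative only and not part of this change.

import torch
from torch.library import custom_op


# Toy out-parameter op: writes q * scale into the caller-provided `out` tensor,
# mirroring the mutate-in-place / return-None convention used above.
@custom_op("_C_demo::scale_into", mutates_args=("out",))
def scale_into(q: torch.Tensor, out: torch.Tensor, scale: float = 1.0) -> None:
    out.copy_(q * scale)


# Fake (meta) implementation: the real op only mutates `out`, so there is
# nothing to compute for shape propagation during tracing.
@scale_into.register_fake
def _(q: torch.Tensor, out: torch.Tensor, scale: float = 1.0) -> None:
    return None


q = torch.randn(4, 8)
out = torch.empty_like(q)
torch.ops._C_demo.scale_into(q, out, scale=2.0)  # dispatch through torch.ops
assert torch.allclose(out, q * 2.0)

The call sites changed in this diff follow the same shape: pre-allocate the outputs (e.g. `p_sums`, `max_logits`, `lse`, `logits`) and invoke the op through `torch.ops._C.<name>` instead of calling `xtorch_ops` directly.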