[Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)

* [Kernel] Add kernels to torch.ops

* [Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16

---------

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
fromck
2026-01-17 16:52:02 +08:00
committed by GitHub
parent 8988ad08b2
commit 71a5a04e0a
4 changed files with 325 additions and 15 deletions

View File

@@ -171,7 +171,7 @@ def kunlun_flash_mla_with_kvcache(
     p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
                          dtype=torch.float32, device=q.device)
-    xtorch_ops.fwd_kvcache_mla(
+    torch.ops._C.fwd_kvcache_mla(
         q_c=q,
         kv_cache=k_cache,
         indices=indices,
@@ -224,7 +224,7 @@ def flash_mla_sparse_prefill(
     max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
     lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
-    xtorch_ops.sparse_prefill_fwd_opt(
+    torch.ops._C.sparse_prefill_fwd_opt(
         q=q,
         kv=kv,
         indices=indices,
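
Both call sites now dispatch through the PyTorch operator registry instead of invoking the xtorch_ops extension module directly, which makes the kernels visible to the dispatcher, to fake-tensor tracing, and to torch.compile. A minimal sketch of the same pattern, using a hypothetical demo::scale_add op rather than any kernel from this commit:

import torch
from torch.library import custom_op

# Hypothetical example op; "demo" stands in for the "_C" namespace
# this commit registers its kernels under.
@custom_op("demo::scale_add", mutates_args=())
def scale_add(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # A real backend would call into its extension here.
    return x + alpha * y

x, y = torch.randn(4), torch.randn(4)
# The call site no longer imports or touches the extension module:
out = torch.ops.demo.scale_add(x, y, alpha=0.5)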

View File

@@ -1,5 +1,4 @@
 import torch
-import xtorch_ops

 def int8_mqa_logits(
     q: torch.Tensor,
@@ -29,7 +28,7 @@ def int8_mqa_logits(
     context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
     context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
-    xtorch_ops.I8_mqa_logits(
+    torch.ops._C.I8_mqa_logits(
         q=q,
         fused_kv_cache=kv,
         weights=weights,
@@ -99,7 +98,7 @@ def int8_paged_mqa_logits(
     logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
-    xtorch_ops.I8_paged_mqa_logits(
+    torch.ops._C.I8_paged_mqa_logits(
         q=q_fp8,
         fused_kv_cache=kv_cache,
         weights=weights,
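
Since every call in this file now resolves through torch.ops._C, the module-level import xtorch_ops is dropped; the extension only needs to be imported by the module that performs the registrations. As an illustrative guard (not code from this commit), a caller can probe the namespace to confirm the registrations ran:

import torch

def mqa_ops_registered() -> bool:
    # torch.ops._C is populated when the registration module is imported;
    # looking up an unregistered op raises AttributeError, so use hasattr.
    return hasattr(torch.ops._C, "I8_mqa_logits")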

View File

@@ -731,22 +731,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         q = torch.cat([ql_nope, q_pe], dim=-1)
-        # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
-            torch.ops._C.concat_and_cache_mla(
-                kv_c=k_c_normed,
-                k_pe=k_pe.squeeze(1),
-                kv_cache=kv_cache,
-                slot_mapping=attn_metadata.slot_mapping.flatten(),
-            )
         if self.kv_cache_dtype != "fp8_ds_mla":
+            # write the latent and rope to kv cache
+            if kv_cache.numel() > 0:
+                torch.ops._C.concat_and_cache_mla(
+                    kv_c=k_c_normed,
+                    k_pe=k_pe.squeeze(1),
+                    kv_cache=kv_cache,
+                    slot_mapping=attn_metadata.slot_mapping.flatten(),
+                )
             attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                              attn_metadata)
         else:
-            # attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
-            #                                 attn_metadata)
-            raise NotImplementedError
+            raise NotImplementedError("Only support --kv-cache-dtype bfloat16")
         self._v_up_proj(attn_out, out=output[:num_actual_toks])
         return output
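
The reshuffled branch both gates the KV-cache write on the bfloat16 path and replaces the bare NotImplementedError with an explicit message naming the supported flag. The same fail-fast idea can be hoisted to configuration time; a minimal sketch assuming a hypothetical validate_kv_cache_dtype helper that is not part of this commit:

# Illustrative only: the commit raises inside the attention forward,
# but an equivalent check could run once when the engine is configured.
SUPPORTED_KV_CACHE_DTYPES = ("bfloat16",)  # assumption: bf16 only, per this commit

def validate_kv_cache_dtype(kv_cache_dtype: str) -> None:
    if kv_cache_dtype not in SUPPORTED_KV_CACHE_DTYPES:
        raise NotImplementedError(
            f"--kv-cache-dtype {kv_cache_dtype!r} is not supported; "
            "only --kv-cache-dtype bfloat16 is supported")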

View File

@@ -1923,3 +1923,315 @@ def apply_repetition_penalties_(
     # If logits are positive, divide by penalty, otherwise multiply by penalty.
     scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
     logits *= scaling
+
+
+##################################################
+# --------------- I8_mqa_logits -----------------
+##################################################
+@custom_op("_C::I8_mqa_logits", mutates_args=())
+def I8_mqa_logits(
+    q: torch.Tensor,
+    fused_kv_cache: List[torch.Tensor],
+    weights: torch.Tensor,
+    context_q_lens: List[torch.Tensor],
+    context_k_lens: List[torch.Tensor],
+    logits: torch.Tensor,
+    clean_logits: bool,
+    max_seq_q: Optional[int] = 0,
+    max_seq_k: Optional[int] = 0,
+    is_causal: Optional[bool] = False,
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
+    xtorch_ops.I8_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_q_lens=context_q_lens,
+        context_k_lens=context_k_lens,
+        logits=logits,
+        clean_logits=clean_logits,
+        max_seq_q=max_seq_q,
+        max_seq_k=max_seq_k,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost,
+    )
+    return None
+
+
+@impl("_C::I8_mqa_logits", "CUDA")
+def I8_mqa_logits_cuda(
+    q: torch.Tensor,
+    fused_kv_cache: List[torch.Tensor],
+    weights: torch.Tensor,
+    context_q_lens: List[torch.Tensor],
+    context_k_lens: List[torch.Tensor],
+    logits: torch.Tensor,
+    clean_logits: bool,
+    max_seq_q: Optional[int] = 0,
+    max_seq_k: Optional[int] = 0,
+    is_causal: Optional[bool] = False,
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
+    xtorch_ops.I8_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_q_lens=context_q_lens,
+        context_k_lens=context_k_lens,
+        logits=logits,
+        clean_logits=clean_logits,
+        max_seq_q=max_seq_q,
+        max_seq_k=max_seq_k,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost,
+    )
+    return None
+
+
+def _fake_I8_mqa_logits(
+    q: torch.Tensor,
+    fused_kv_cache: List[torch.Tensor],
+    weights: torch.Tensor,
+    context_q_lens: List[torch.Tensor],
+    context_k_lens: List[torch.Tensor],
+    logits: torch.Tensor,
+    clean_logits: bool,
+    max_seq_q: Optional[int] = 0,
+    max_seq_k: Optional[int] = 0,
+    is_causal: Optional[bool] = False,
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
+    return None
+
+
+I8_mqa_logits.register_fake(_fake_I8_mqa_logits)
+
+
+##################################################
+# ------------- I8_paged_mqa_logits --------------
+##################################################
+@custom_op("_C::I8_paged_mqa_logits", mutates_args=())
+def I8_paged_mqa_logits(
+    q: torch.Tensor,
+    fused_kv_cache: List[torch.Tensor],
+    weights: torch.Tensor,
+    context_lens: List[torch.Tensor],
+    block_table: torch.Tensor,
+    max_context_len: int,
+    clean_logits: bool,
+    out: torch.Tensor,
+    use_xfa_boost: Optional[bool] = False) -> None:
+    xtorch_ops.I8_paged_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_lens=context_lens,
+        block_table=block_table,
+        max_context_len=max_context_len,
+        clean_logits=clean_logits,
+        out=out,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+
+@impl("_C::I8_paged_mqa_logits", "CUDA")
+def I8_paged_mqa_logits_cuda(
+    q: torch.Tensor,
+    fused_kv_cache: List[torch.Tensor],
+    weights: torch.Tensor,
+    context_lens: List[torch.Tensor],
+    block_table: torch.Tensor,
+    max_context_len: int,
+    clean_logits: bool,
+    out: torch.Tensor,
+    use_xfa_boost: Optional[bool] = False) -> None:
+    xtorch_ops.I8_paged_mqa_logits(
+        q=q,
+        fused_kv_cache=fused_kv_cache,
+        weights=weights,
+        context_lens=context_lens,
+        block_table=block_table,
+        max_context_len=max_context_len,
+        clean_logits=clean_logits,
+        out=out,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+
+def _fake_I8_paged_mqa_logits(
+    q: torch.Tensor,
+    fused_kv_cache: List[torch.Tensor],
+    weights: torch.Tensor,
+    context_lens: List[torch.Tensor],
+    block_table: torch.Tensor,
+    max_context_len: int,
+    clean_logits: bool,
+    out: torch.Tensor,
+    use_xfa_boost: Optional[bool] = False) -> None:
+    return None
+
+
+I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits)
+
+
+##################################################
+# ----------- sparse_prefill_fwd_opt -------------
+##################################################
+@custom_op("_C::sparse_prefill_fwd_opt", mutates_args=())
+def sparse_prefill_fwd_opt(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    lse: torch.Tensor,
+    sm_scale: float,
+    qlod_cpu: Optional[torch.Tensor] = None,
+    qlod_xpu: Optional[torch.Tensor] = None,
+    kvlod_cpu: Optional[torch.Tensor] = None,
+    kvlod_xpu: Optional[torch.Tensor] = None,
+    d_v: Optional[int] = -1,
+    is_causal: Optional[bool] = True,
+    use_xfa_boost: Optional[bool] = False) -> None:
+    xtorch_ops.sparse_prefill_fwd_opt(
+        q=q,
+        kv=kv,
+        indices=indices,
+        out=out,
+        max_logits=max_logits,
+        lse=lse,
+        sm_scale=sm_scale,
+        qlod_cpu=qlod_cpu,
+        qlod_xpu=qlod_xpu,
+        kvlod_cpu=kvlod_cpu,
+        kvlod_xpu=kvlod_xpu,
+        d_v=d_v,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+
+@impl("_C::sparse_prefill_fwd_opt", "CUDA")
+def sparse_prefill_fwd_opt_cuda(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    lse: torch.Tensor,
+    sm_scale: float,
+    qlod_cpu: Optional[torch.Tensor] = None,
+    qlod_xpu: Optional[torch.Tensor] = None,
+    kvlod_cpu: Optional[torch.Tensor] = None,
+    kvlod_xpu: Optional[torch.Tensor] = None,
+    d_v: Optional[int] = -1,
+    is_causal: Optional[bool] = True,
+    use_xfa_boost: Optional[bool] = False) -> None:
+    xtorch_ops.sparse_prefill_fwd_opt(
+        q=q,
+        kv=kv,
+        indices=indices,
+        out=out,
+        max_logits=max_logits,
+        lse=lse,
+        sm_scale=sm_scale,
+        qlod_cpu=qlod_cpu,
+        qlod_xpu=qlod_xpu,
+        kvlod_cpu=kvlod_cpu,
+        kvlod_xpu=kvlod_xpu,
+        d_v=d_v,
+        is_causal=is_causal,
+        use_xfa_boost=use_xfa_boost)
+    return None
+
+
+def _fake_sparse_prefill_fwd_opt(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    lse: torch.Tensor,
+    sm_scale: float,
+    qlod_cpu: Optional[torch.Tensor] = None,
+    qlod_xpu: Optional[torch.Tensor] = None,
+    kvlod_cpu: Optional[torch.Tensor] = None,
+    kvlod_xpu: Optional[torch.Tensor] = None,
+    d_v: Optional[int] = -1,
+    is_causal: Optional[bool] = True,
+    use_xfa_boost: Optional[bool] = False) -> None:
+    return None
+
+
+sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt)
+
+
+##################################################
+# ------------------ fwd_kvcache_mla -------------
+##################################################
+@custom_op("_C::fwd_kvcache_mla", mutates_args=())
+def fwd_kvcache_mla(
+    q_c: torch.Tensor,
+    kv_cache: torch.Tensor,
+    indices: torch.Tensor,
+    kv_lod_cpu: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    p_sums: torch.Tensor,
+    softmax_scale: float,
+    max_seq_kv: int,
+    q_r: Optional[torch.Tensor] = None,
+    pe_cache: Optional[torch.Tensor] = None,
+    use_xfa_boost: Optional[bool] = False,
+    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    xtorch_ops.fwd_kvcache_mla(
+        q_c=q_c,
+        kv_cache=kv_cache,
+        indices=indices,
+        kv_lod_cpu=kv_lod_cpu,
+        out=out,
+        max_logits=max_logits,
+        p_sums=p_sums,
+        softmax_scale=softmax_scale,
+        max_seq_kv=max_seq_kv,
+        q_r=q_r,
+        pe_cache=pe_cache,
+        use_xfa_boost=use_xfa_boost,
+        kv_lod_xpu=kv_lod_xpu)
+    return None
+
+
+@impl("_C::fwd_kvcache_mla", "CUDA")
+def fwd_kvcache_mla_cuda(
+    q_c: torch.Tensor,
+    kv_cache: torch.Tensor,
+    indices: torch.Tensor,
+    kv_lod_cpu: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    p_sums: torch.Tensor,
+    softmax_scale: float,
+    max_seq_kv: int,
+    q_r: Optional[torch.Tensor] = None,
+    pe_cache: Optional[torch.Tensor] = None,
+    use_xfa_boost: Optional[bool] = False,
+    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    xtorch_ops.fwd_kvcache_mla(
+        q_c=q_c,
+        kv_cache=kv_cache,
+        indices=indices,
+        kv_lod_cpu=kv_lod_cpu,
+        out=out,
+        max_logits=max_logits,
+        p_sums=p_sums,
+        softmax_scale=softmax_scale,
+        max_seq_kv=max_seq_kv,
+        q_r=q_r,
+        pe_cache=pe_cache,
+        use_xfa_boost=use_xfa_boost,
+        kv_lod_xpu=kv_lod_xpu)
+    return None
+
+
+def _fake_fwd_kvcache_mla(
+    q_c: torch.Tensor,
+    kv_cache: torch.Tensor,
+    indices: torch.Tensor,
+    kv_lod_cpu: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    p_sums: torch.Tensor,
+    softmax_scale: float,
+    max_seq_kv: int,
+    q_r: Optional[torch.Tensor] = None,
+    pe_cache: Optional[torch.Tensor] = None,
+    use_xfa_boost: Optional[bool] = False,
+    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    return None
+
+
+fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
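
Every registration above ends with register_fake pointing at a fake implementation that simply returns None: these ops write all of their results in place into caller-provided tensors, so there is nothing for shape inference to produce. That no-op fake is what lets fake-tensor tracing (and therefore torch.compile) step over the kernels without executing them. A self-contained sketch of the pattern with a hypothetical demo::double_ op:

import torch
from torch.library import custom_op
from torch._subclasses.fake_tensor import FakeTensorMode  # internal API; path may change

# Hypothetical op in the same style as the registrations above:
# the real body runs the computation, writing in place into `out`.
@custom_op("demo::double_", mutates_args=("out",))
def double_(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(2 * x)

@double_.register_fake
def _(x: torch.Tensor, out: torch.Tensor) -> None:
    return None  # nothing to allocate: the op returns no new tensors

x, out = torch.ones(3), torch.empty(3)
torch.ops.demo.double_(x, out)        # real path: out becomes [2., 2., 2.]

with FakeTensorMode():
    fx, fout = torch.empty(3), torch.empty(3)
    torch.ops.demo.double_(fx, fout)  # fake path: traced, kernel never runs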