[Misc]Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)

* [Kernel] add kernels to torch.ops

* [Misc]Specify that DS only supports --kv-cache-dtype bfloat16

---------

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-17 16:52:02 +08:00
committed by GitHub
parent 8988ad08b2
commit 71a5a04e0a
4 changed files with 325 additions and 15 deletions

View File

@@ -171,7 +171,7 @@ def kunlun_flash_mla_with_kvcache(
p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
dtype=torch.float32, device=q.device)
-    xtorch_ops.fwd_kvcache_mla(
+    torch.ops._C.fwd_kvcache_mla(
q_c=q,
kv_cache=k_cache,
indices=indices,
@@ -224,7 +224,7 @@ def flash_mla_sparse_prefill(
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
-    xtorch_ops.sparse_prefill_fwd_opt(
+    torch.ops._C.sparse_prefill_fwd_opt(
q=q,
kv=kv,
indices=indices,