[Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)

* [Kernel] add kernels to torch.ops

* [Misc] Specify that DS only supports --kv-cache-dtype bfloat16

---------

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-17 16:52:02 +08:00
committed by GitHub
parent 8988ad08b2
commit 71a5a04e0a
4 changed files with 325 additions and 15 deletions

View File

@@ -1,5 +1,4 @@
import torch
import xtorch_ops
def int8_mqa_logits(
q: torch.Tensor,
@@ -29,7 +28,7 @@ def int8_mqa_logits(
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
xtorch_ops.I8_mqa_logits(
torch.ops._C.I8_mqa_logits(
q=q,
fused_kv_cache=kv,
weights=weights,
@@ -99,7 +98,7 @@ def int8_paged_mqa_logits(
logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
xtorch_ops.I8_paged_mqa_logits(
torch.ops._C.I8_paged_mqa_logits(
q=q_fp8,
fused_kv_cache=kv_cache,
weights=weights,