[Misc]Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)
* [Kernel] add kernels to torch.ops * [Misc]Specify that DS only supports --kv-cache-dtype bfloat16 --------- Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -171,7 +171,7 @@ def kunlun_flash_mla_with_kvcache(
|
||||
p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
|
||||
dtype=torch.float32, device=q.device)
|
||||
|
||||
xtorch_ops.fwd_kvcache_mla(
|
||||
torch.ops._C.fwd_kvcache_mla(
|
||||
q_c=q,
|
||||
kv_cache=k_cache,
|
||||
indices=indices,
|
||||
@@ -224,7 +224,7 @@ def flash_mla_sparse_prefill(
|
||||
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||
|
||||
xtorch_ops.sparse_prefill_fwd_opt(
|
||||
torch.ops._C.sparse_prefill_fwd_opt(
|
||||
q=q,
|
||||
kv=kv,
|
||||
indices=indices,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import xtorch_ops
|
||||
|
||||
def int8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
@@ -29,7 +28,7 @@ def int8_mqa_logits(
|
||||
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
|
||||
xtorch_ops.I8_mqa_logits(
|
||||
torch.ops._C.I8_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=kv,
|
||||
weights=weights,
|
||||
@@ -99,7 +98,7 @@ def int8_paged_mqa_logits(
|
||||
|
||||
logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
|
||||
|
||||
xtorch_ops.I8_paged_mqa_logits(
|
||||
torch.ops._C.I8_paged_mqa_logits(
|
||||
q=q_fp8,
|
||||
fused_kv_cache=kv_cache,
|
||||
weights=weights,
|
||||
|
||||
Reference in New Issue
Block a user