[Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)
* [Kernel] add kernels to torch.ops
* [Misc] Specify that DS only supports --kv-cache-dtype bfloat16

---------

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
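For illustration, a minimal sketch of the kv-cache dtype constraint being documented, assuming a hypothetical validation helper; the function name and the model-name check are placeholders, not this repo's code:

    def validate_kv_cache_dtype(model_name: str, kv_cache_dtype: str) -> None:
        # Hypothetical check mirroring the documented constraint:
        # DS32 only supports --kv-cache-dtype bfloat16.
        if "ds32" in model_name.lower() and kv_cache_dtype != "bfloat16":
            raise ValueError(
                "DS32 only supports --kv-cache-dtype bfloat16, got "
                f"{kv_cache_dtype!r}"
            )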
@@ -171,7 +171,7 @@ def kunlun_flash_mla_with_kvcache(
     p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
                          dtype=torch.float32, device=q.device)
 
-    xtorch_ops.fwd_kvcache_mla(
+    torch.ops._C.fwd_kvcache_mla(
         q_c=q,
         kv_cache=k_cache,
         indices=indices,
@@ -224,7 +224,7 @@ def flash_mla_sparse_prefill(
     max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
     lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
 
-    xtorch_ops.sparse_prefill_fwd_opt(
+    torch.ops._C.sparse_prefill_fwd_opt(
         q=q,
         kv=kv,
         indices=indices,
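The diff replaces direct calls into the xtorch_ops Python module with ops registered under torch.ops._C. A minimal sketch of how a custom op becomes callable through the torch.ops namespace, using torch.library from Python; the "demo" namespace and "scale" op are placeholders (this repo registers its real kernels elsewhere, e.g. from C++):

    import torch

    # Define and implement a custom op so it resolves through torch.ops,
    # analogous to fwd_kvcache_mla becoming torch.ops._C.fwd_kvcache_mla.
    lib = torch.library.Library("demo", "DEF")
    lib.define("scale(Tensor x, float alpha) -> Tensor")

    def scale_cpu(x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Reference implementation bound to the CPU dispatch key.
        return x * alpha

    lib.impl("scale", scale_cpu, "CPU")

    # After registration, the op is reachable through torch.ops:
    out = torch.ops.demo.scale(torch.ones(3), 2.0)  # tensor([2., 2., 2.])

Routing calls through torch.ops rather than a raw extension module keeps the kernels visible to TorchScript and torch.compile, which dispatch through registered op schemas.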