[Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)
* [Kernel] add kernels to torch.ops
* [Misc] Specify that DS only supports --kv-cache-dtype bfloat16

---------

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
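The dtype constraint from the title can be made explicit at startup. A minimal sketch of such a guard, assuming a hypothetical check_kv_cache_dtype helper (the name and placement are illustrative; this diff does not show where the check lives):

import torch

def check_kv_cache_dtype(kv_cache_dtype: torch.dtype) -> None:
    # Hypothetical guard: DS32's kernels are stated to support only a
    # bfloat16 KV cache, so reject any other --kv-cache-dtype up front.
    if kv_cache_dtype is not torch.bfloat16:
        raise ValueError(
            f"DS32 only supports --kv-cache-dtype bfloat16, got {kv_cache_dtype}"
        )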
@@ -1,5 +1,4 @@
 import torch
-import xtorch_ops
 
 def int8_mqa_logits(
     q: torch.Tensor,
@@ -29,7 +28,7 @@ def int8_mqa_logits(
     context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
     context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
 
-    xtorch_ops.I8_mqa_logits(
+    torch.ops._C.I8_mqa_logits(
         q=q,
         fused_kv_cache=kv,
         weights=weights,
@@ -99,7 +98,7 @@ def int8_paged_mqa_logits(
 
     logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
 
-    xtorch_ops.I8_paged_mqa_logits(
+    torch.ops._C.I8_paged_mqa_logits(
         q=q_fp8,
         fused_kv_cache=kv_cache,
         weights=weights,
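Both call sites now go through torch.ops._C instead of importing the xtorch_ops Python module directly, which matches the "[Kernel] add kernels to torch.ops" part of the message: once a kernel is registered into the _C namespace, it becomes reachable as torch.ops._C.<name>. A minimal Python-side sketch of that registration path, using an illustrative op (the real I8_mqa_logits kernels are presumably registered from C++, not like this):

import torch
from torch.library import Library

# "FRAGMENT" extends the "_C" op namespace rather than claiming sole
# ownership of it, so it composes with registrations done elsewhere.
_lib = Library("_C", "FRAGMENT")

# Declare the op schema; "demo_scale" is illustrative, not a real op here.
_lib.define("demo_scale(Tensor x, float alpha) -> Tensor")

def _demo_scale_impl(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Stand-in body; a real backend would dispatch to a fused kernel.
    return x * alpha

# Bind the implementation for CPU tensors.
_lib.impl("demo_scale", _demo_scale_impl, "CPU")

# After registration the op is callable through the torch.ops namespace,
# the same way the diff above calls torch.ops._C.I8_mqa_logits.
out = torch.ops._C.demo_scale(torch.ones(4), 2.0)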