[Bugfix] Fix run failure (#198)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -386,8 +386,8 @@ def silu_and_mul_quant_xpu(
|
||||
pass
|
||||
|
||||
|
||||
import kunlun_ops # noqa: E402
|
||||
import torch # noqa: E402
|
||||
import xtorch_ops # noqa: E402
|
||||
from torch.library import custom_op, impl # noqa: E402
|
||||
|
||||
|
||||
@@ -1402,7 +1402,7 @@ def awq_dequantize_cuda(
|
||||
device=qweight.device,
|
||||
)
|
||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||
out = kunlun_ops.awq_dequantize(
|
||||
kunlun_ops.awq_dequantize(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
zeros=zeros,
|
||||
@@ -2034,8 +2034,7 @@ def I8_paged_mqa_logits(
|
||||
out: torch.Tensor,
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
) -> None:
|
||||
kunlun_ops.sparse_prefill_fwd_opt(
|
||||
.I8_paged_mqa_logits(
|
||||
kunlun_ops.I8_paged_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=fused_kv_cache,
|
||||
weights=weights,
|
||||
@@ -2351,10 +2350,7 @@ def fast_topkv2(
|
||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||
) -> torch.Tensor:
|
||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||
topk_indices = kunlun_ops.fast_topkv2(
|
||||
score=score,
|
||||
lengths=lengths,
|
||||
topk=topk)
|
||||
topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
return topk_indices
|
||||
|
||||
|
||||
@@ -2363,10 +2359,7 @@ def fast_topkv2_cuda(
|
||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||
) -> torch.Tensor:
|
||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||
topk_indices = kunlun_ops.fast_topkv2(
|
||||
score=score,
|
||||
lengths=lengths,
|
||||
topk=topk)
|
||||
topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
return topk_indices
|
||||
|
||||
|
||||
@@ -2805,7 +2798,7 @@ def lora_matmul_inplace(
|
||||
alpha: float = 1.0,
|
||||
beta: float = 1.0,
|
||||
) -> None:
|
||||
xtorch_ops.matmul(
|
||||
kunlun_ops.matmul(
|
||||
x=x.contiguous(),
|
||||
w=w.contiguous(),
|
||||
out=output_tensor,
|
||||
@@ -2826,7 +2819,7 @@ def lora_matmul_inplace_cuda(
|
||||
alpha: float = 1.0,
|
||||
beta: float = 1.0,
|
||||
) -> None:
|
||||
xtorch_ops.matmul(
|
||||
kunlun_ops.matmul(
|
||||
x=x.contiguous(),
|
||||
w=w.contiguous(),
|
||||
out=output_tensor,
|
||||
|
||||
Reference in New Issue
Block a user