[Bugsfix] Fix run failed (#198)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -386,8 +386,8 @@ def silu_and_mul_quant_xpu(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
import kunlun_ops # noqa: E402
|
||||||
import torch # noqa: E402
|
import torch # noqa: E402
|
||||||
import xtorch_ops # noqa: E402
|
|
||||||
from torch.library import custom_op, impl # noqa: E402
|
from torch.library import custom_op, impl # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@@ -1402,7 +1402,7 @@ def awq_dequantize_cuda(
|
|||||||
device=qweight.device,
|
device=qweight.device,
|
||||||
)
|
)
|
||||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||||
out = kunlun_ops.awq_dequantize(
|
kunlun_ops.awq_dequantize(
|
||||||
qweight=qweight,
|
qweight=qweight,
|
||||||
scales=scales,
|
scales=scales,
|
||||||
zeros=zeros,
|
zeros=zeros,
|
||||||
@@ -2034,8 +2034,7 @@ def I8_paged_mqa_logits(
|
|||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
kunlun_ops.sparse_prefill_fwd_opt(
|
kunlun_ops.I8_paged_mqa_logits(
|
||||||
.I8_paged_mqa_logits(
|
|
||||||
q=q,
|
q=q,
|
||||||
fused_kv_cache=fused_kv_cache,
|
fused_kv_cache=fused_kv_cache,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@@ -2351,10 +2350,7 @@ def fast_topkv2(
|
|||||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
topk_indices = kunlun_ops.fast_topkv2(
|
topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||||
score=score,
|
|
||||||
lengths=lengths,
|
|
||||||
topk=topk)
|
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
@@ -2363,10 +2359,7 @@ def fast_topkv2_cuda(
|
|||||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
topk_indices = kunlun_ops.fast_topkv2(
|
topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||||
score=score,
|
|
||||||
lengths=lengths,
|
|
||||||
topk=topk)
|
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
@@ -2805,7 +2798,7 @@ def lora_matmul_inplace(
|
|||||||
alpha: float = 1.0,
|
alpha: float = 1.0,
|
||||||
beta: float = 1.0,
|
beta: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.matmul(
|
kunlun_ops.matmul(
|
||||||
x=x.contiguous(),
|
x=x.contiguous(),
|
||||||
w=w.contiguous(),
|
w=w.contiguous(),
|
||||||
out=output_tensor,
|
out=output_tensor,
|
||||||
@@ -2826,7 +2819,7 @@ def lora_matmul_inplace_cuda(
|
|||||||
alpha: float = 1.0,
|
alpha: float = 1.0,
|
||||||
beta: float = 1.0,
|
beta: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.matmul(
|
kunlun_ops.matmul(
|
||||||
x=x.contiguous(),
|
x=x.contiguous(),
|
||||||
w=w.contiguous(),
|
w=w.contiguous(),
|
||||||
out=output_tensor,
|
out=output_tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user