Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
-import xtorch_ops
|
||||
+import kunlun_ops
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
@@ -31,7 +31,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
|
||||
-o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
||||
+o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
||||
q.contiguous(),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import Optional
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
-import xtorch_ops
|
||||
+import kunlun_ops
|
||||
|
||||
|
||||
BT_LIST = [8, 16, 32, 64, 128]
|
||||
@@ -149,5 +149,5 @@ def l2norm_fwd(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: Optional[torch.dtype] = None):
|
||||
out = torch.empty_like(x)
|
||||
-xtorch_ops.l2norm(x, out, eps)
|
||||
+kunlun_ops.l2norm(x, out, eps)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user