init
This commit is contained in:
32
torch_vacc/fused_ops/rope_emb.py
Normal file
32
torch_vacc/fused_ops/rope_emb.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
import torch_vacc
|
||||
from torch_vacc._vacc_libs import _torch_vacc
|
||||
|
||||
|
||||
class FusedRopeEmbFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
|
||||
qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
|
||||
ctx.offset = offset
|
||||
|
||||
return qemb, kemb
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
|
||||
grad_input, grad_rope = _torch_vacc.rope_backward(
|
||||
q_out_grad, k_out_grad, ctx.offset
|
||||
)
|
||||
return grad_input, grad_rope, None
|
||||
|
||||
|
||||
def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
|
||||
# return FusedRopeEmbFunction.apply(q, k, offset)
|
||||
return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)
|
||||
|
||||
|
||||
class RopeEmb(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
|
||||
return rope_emb(q, k, offset)
|
||||
Reference in New Issue
Block a user