# Fused rotary position embedding (RoPE) exposed through the torch_vacc backend.
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRopeEmbFunction(torch.autograd.Function):
    """Autograd wrapper around the fused VACC RoPE forward/backward kernels."""

    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
        # Apply the rotary embedding to q and k in a single fused kernel.
        qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
        # Stash the position offset for the backward pass.
        ctx.offset = offset
        return qemb, kemb

    @staticmethod
    def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
        grad_q, grad_k = _torch_vacc.rope_backward(
            q_out_grad, k_out_grad, ctx.offset
        )
        # offset is a non-differentiable int, so its gradient slot is None.
        return grad_q, grad_k, None


def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
    # The autograd.Function path is kept for reference; the registered
    # custom op is used instead.
    # return FusedRopeEmbFunction.apply(q, k, offset)
    return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)


class RopeEmb(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
        return rope_emb(q, k, offset)
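
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. The
    # [seq, batch, heads, head_dim] shapes and the "vacc" device string are
    # assumptions inferred from the torch_vacc package name and common RoPE
    # layouts; adjust both to match the actual runtime.
    q = torch.randn(128, 1, 8, 64, device="vacc")
    k = torch.randn(128, 1, 8, 64, device="vacc")
    q_emb, k_emb = RopeEmb()(q, k, offset=0)
    # RoPE rotates q/k in place along head_dim, so shapes are preserved.
    assert q_emb.shape == q.shape and k_emb.shape == k.shape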