import torch

import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRopeEmbFunction(torch.autograd.Function):
    """Autograd wrapper around the fused VACC rotary position embedding kernels."""

    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
        # Apply the fused RoPE kernel to the query and key tensors.
        qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
        # Only the integer offset is needed for backward; no tensors are saved.
        ctx.offset = offset
        return qemb, kemb

    @staticmethod
    def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
        grad_q, grad_k = _torch_vacc.rope_backward(
            q_out_grad, k_out_grad, ctx.offset
        )
        # forward() took (q, k, offset); offset is a non-differentiable int,
        # so its gradient slot is None.
        return grad_q, grad_k, None


def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
    # Dispatch to the registered custom op rather than the raw autograd function.
    # return FusedRopeEmbFunction.apply(q, k, offset)
    return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)


class RopeEmb(torch.nn.Module):
    """Module wrapper so the fused RoPE op can be used inside nn.Sequential/models."""

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
        return rope_emb(q, k, offset)
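

# Minimal usage sketch (illustration only, not part of the module). The tensor
# shapes and the "vacc" device string below are assumptions -- the layout the
# VACC RoPE kernel expects (e.g. [seq, batch, heads, head_dim]) depends on the
# torch_vacc build, and running this requires a VACC device with torch_vacc
# installed.
if __name__ == "__main__":
    rope = RopeEmb()
    # Hypothetical shapes: seq len 128, batch 2, 8 heads, head_dim 64.
    q = torch.randn(128, 2, 8, 64, device="vacc")
    k = torch.randn(128, 2, 8, 64, device="vacc")
    q_emb, k_emb = rope(q, k, offset=0)
    print(q_emb.shape, k_emb.shape)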