Files
enginex-vastai-va16-vllm/torch_vacc/fused_ops/rope_emb.py
2026-04-02 04:55:00 +00:00

33 lines
980 B
Python

import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
class FusedRopeEmbFunction(torch.autograd.Function):
    """Autograd wrapper around the VACC fused rotary-position-embedding kernels.

    ``forward`` delegates to ``_torch_vacc.rope_forward`` and ``backward`` to
    ``_torch_vacc.rope_backward``. The integer ``offset`` is stashed on
    ``ctx`` so the backward kernel receives the same value as the forward one.
    """

    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
        """Run the fused RoPE forward kernel on ``q`` and ``k``.

        Args:
            ctx: autograd context; only ``offset`` is recorded on it.
            q: query tensor (layout is whatever the kernel expects -- TODO
                confirm against the ``rope_forward`` definition).
            k: key tensor (same caveat as ``q``).
            offset: integer passed verbatim to the kernel; presumably the
                starting position index -- verify.

        Returns:
            The ``(q, k)`` pair produced by ``rope_forward``, as a tuple.
        """
        q_embedded, k_embedded = _torch_vacc.rope_forward(q, k, offset)
        # offset is a plain Python int, not a tensor, so it is stored as a ctx
        # attribute rather than through save_for_backward (tensors only).
        ctx.offset = offset
        return q_embedded, k_embedded

    @staticmethod
    def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
        """Run the fused RoPE backward kernel.

        Args:
            ctx: autograd context carrying the ``offset`` saved in ``forward``.
            q_out_grad: gradient w.r.t. the first output of ``forward``.
            k_out_grad: gradient w.r.t. the second output of ``forward``.

        Returns:
            One entry per ``forward`` input: the gradients for ``q`` and ``k``
            from ``rope_backward``, then ``None`` for the non-tensor ``offset``.
        """
        saved_offset = ctx.offset
        grad_q, grad_k = _torch_vacc.rope_backward(
            q_out_grad, k_out_grad, saved_offset
        )
        return grad_q, grad_k, None
def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
    """Apply rotary position embedding to ``q`` and ``k`` via a VACC custom op.

    Dispatches to ``torch_vacc.vacc.custom_ops.RotaryPosEmbedding``. The
    autograd-based ``FusedRopeEmbFunction`` defined in this module is not
    used on this path.

    Args:
        q: query tensor (layout expected by the op -- TODO confirm).
        k: key tensor (same caveat as ``q``).
        offset: integer forwarded to the op as ``offset``.

    Returns:
        Whatever ``RotaryPosEmbedding`` returns (presumably the embedded
        ``(q, k)`` pair -- confirm against the op's definition).
    """
    custom_ops = torch_vacc.vacc.custom_ops
    return custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)
class RopeEmb(torch.nn.Module):
    """``nn.Module`` front-end for :func:`rope_emb`.

    Holds no parameters or buffers. ``forward`` simply passes its arguments
    through to ``rope_emb``. The inherited ``nn.Module.__init__`` is all the
    setup this module needs, so no constructor is defined here.
    """

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
        """Return ``rope_emb(q, k, offset)``."""
        return rope_emb(q=q, k=k, offset=offset)