2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

@@ -0,0 +1,42 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRMSNormFunction(torch.autograd.Function):
    """Autograd wrapper around the fused VACC RMSNorm kernels."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
        ctx.save_for_backward(input, weight, rsigma, var)
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, weight, rsigma, var = ctx.saved_tensors
        grad_input, grad_weight = _torch_vacc.rms_norm_backward(
            grad_output, input, weight, rsigma, var, ctx.eps
        )
        # Three forward inputs (input, weight, eps) -> three gradients;
        # eps is a plain float and is not differentiable.
        return grad_input, grad_weight, None


def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    return FusedRMSNormFunction.apply(input, weight, eps)


class FusedRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Run the kernel in float32 for numerical stability, then cast
        # back to the caller's dtype.
        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        output = FusedRMSNormFunction.apply(hidden_states, self.weight, self.eps)
        return output.to(dtype)
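A quick way to exercise the module above is to compare it against the textbook RMSNorm formula, x / sqrt(mean(x^2) + eps) * weight. The sketch below is illustrative only: the "vacc" device string and the tensor shapes are assumptions, and reference_rms_norm is a hypothetical helper written here for the comparison, not part of torch_vacc.

import torch

def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain-PyTorch RMSNorm: x / sqrt(mean(x^2) over the last dim + eps) * weight,
    # computed in float32 to mirror the fused module's upcast.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)

hidden_size = 4096
x = torch.randn(2, 16, hidden_size, device="vacc", requires_grad=True)  # assumed device name
norm = FusedRMSNorm(hidden_size).to("vacc")

out = norm(x)
out.sum().backward()  # also exercises FusedRMSNormFunction.backward

# The fused output should agree with the reference within kernel tolerance.
torch.testing.assert_close(
    out, reference_rms_norm(x.detach(), norm.weight.detach()), rtol=1e-3, atol=1e-3
)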

@@ -0,0 +1,32 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRopeEmbFunction(torch.autograd.Function):
    """Autograd wrapper around the fused VACC rotary position embedding kernels."""

    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
        qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
        ctx.offset = offset
        return qemb, kemb

    @staticmethod
    def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
        grad_q, grad_k = _torch_vacc.rope_backward(
            q_out_grad, k_out_grad, ctx.offset
        )
        # Three forward inputs (q, k, offset) -> three gradients; the
        # integer offset is not differentiable.
        return grad_q, grad_k, None


def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
    # The autograd.Function path is left disabled in favor of the custom op:
    # return FusedRopeEmbFunction.apply(q, k, offset)
    return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)


class RopeEmb(torch.nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
        return rope_emb(q, k, offset)
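For completeness, a minimal call-site sketch. The (seq_len, batch, heads, head_dim) layout and the "vacc" device string are assumptions; the layout the VACC kernel actually expects is not documented in this file.

import torch

seq_len, batch, num_heads, head_dim = 128, 2, 8, 64  # assumed layout
q = torch.randn(seq_len, batch, num_heads, head_dim, device="vacc")  # assumed device name
k = torch.randn(seq_len, batch, num_heads, head_dim, device="vacc")

# offset shifts the starting position index, e.g. when decoding with a
# KV cache the new tokens continue at position offset = cached length.
q_rot, k_rot = rope_emb(q, k, offset=0)
assert q_rot.shape == q.shape and k_rot.shape == k.shape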