init
0	torch_vacc/fused_ops/__init__.py	Normal file
BIN	torch_vacc/fused_ops/__pycache__/__init__.cpython-312.pyc	Normal file (binary file not shown)
BIN	torch_vacc/fused_ops/__pycache__/rms_norm.cpython-312.pyc	Normal file (binary file not shown)
BIN	torch_vacc/fused_ops/__pycache__/rope_emb.cpython-312.pyc	Normal file (binary file not shown)
42	torch_vacc/fused_ops/rms_norm.py	Normal file
@@ -0,0 +1,42 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRMSNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        # The fused kernel returns the normalized output together with the
        # reciprocal sigma and variance that the backward pass needs.
        output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
        ctx.save_for_backward(input, weight, rsigma, var)
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, weight, rsigma, var = ctx.saved_tensors
        grad_input, grad_weight = _torch_vacc.rms_norm_backward(
            grad_output, input, weight, rsigma, var, ctx.eps
        )
        # eps is a non-tensor argument, hence the trailing None.
        return grad_input, grad_weight, None


def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    return FusedRMSNormFunction.apply(input, weight, eps)


class FusedRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps  # duplicate alias of self.eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Run the fused kernel in float32, then restore the input dtype.
        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        output = FusedRMSNormFunction.apply(hidden_states, self.weight, self.eps)
        return output.to(dtype)
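
A minimal usage sketch for the module above. The device string "vacc:0" and the 2-D input shape are assumptions for illustration; the actual device name and supported shapes depend on the torch_vacc build and are not stated in this diff.

import torch
from torch_vacc.fused_ops.rms_norm import FusedRMSNorm, rms_norm

device = "vacc:0"  # hypothetical device string

hidden = torch.randn(8, 4096, device=device, requires_grad=True)
norm = FusedRMSNorm(hidden_size=4096).to(device)

out = norm(hidden)    # fused forward in float32, cast back to the input dtype
out.sum().backward()  # exercises FusedRMSNormFunction.backward

# Functional form; equivalent apart from the dtype handling in the module.
out2 = rms_norm(hidden.float(), norm.weight, eps=1e-6)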
32	torch_vacc/fused_ops/rope_emb.py	Normal file
@@ -0,0 +1,32 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRopeEmbFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
        qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
        ctx.offset = offset
        return qemb, kemb

    @staticmethod
    def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
        grad_q, grad_k = _torch_vacc.rope_backward(
            q_out_grad, k_out_grad, ctx.offset
        )
        # The integer offset gets no gradient, hence the trailing None.
        return grad_q, grad_k, None


def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
    # The autograd.Function path is kept for reference; the custom op
    # below is used instead.
    # return FusedRopeEmbFunction.apply(q, k, offset)
    return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)


class RopeEmb(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
        return rope_emb(q, k, offset)
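
A similar sketch for the RoPE wrapper. It assumes the same hypothetical "vacc:0" device, a [batch, heads, seq, head_dim] layout for q/k, and that RotaryPosEmbedding returns the rotated (q, k) pair the way the autograd path does; none of these are confirmed by the diff.

import torch
from torch_vacc.fused_ops.rope_emb import RopeEmb

device = "vacc:0"  # hypothetical device string
q = torch.randn(1, 32, 128, 128, device=device)
k = torch.randn(1, 32, 128, 128, device=device)

rope = RopeEmb()
# offset shifts the starting position index, e.g. for decoding with a KV cache.
q_emb, k_emb = rope(q, k, offset=0)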