hotfix: add CustomOp abstraction (#1027)

This commit is contained in:
Yineng Zhang
2024-08-11 17:45:59 +08:00
committed by GitHub
parent 9dae407812
commit c245b78973
2 changed files with 7 additions and 4 deletions

View File

@@ -13,15 +13,17 @@ limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul
from vllm.model_executor.custom_op import CustomOp
class SiluAndMul(nn.Module):
class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

View File

@@ -18,9 +18,10 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from vllm.model_executor.custom_op import CustomOp
class RMSNorm(nn.Module):
class RMSNorm(CustomOp):
def __init__(
self,
hidden_size: int,
@@ -30,7 +31,7 @@ class RMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,