fix: custom op fallback forward native when lower sm80 (#1177)

This commit is contained in:
Yineng Zhang
2024-08-22 07:26:35 +10:00
committed by GitHub
parent bea2bb9eea
commit 1fb9459908
2 changed files with 10 additions and 0 deletions

View File

@@ -20,11 +20,18 @@ from vllm.model_executor.custom_op import CustomOp
class SiluAndMul(CustomOp):
def __init__(self, **kwargs):
super().__init__()
self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
if self.is_lower_sm80:
return self.forward_native(x)
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

View File

@@ -32,12 +32,15 @@ class RMSNorm(CustomOp):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.is_lower_sm80:
return self.forward_native(x, residual)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)