unify is_cuda and is_hip (#4321)

This commit is contained in:
Yineng Zhang
2025-03-11 18:12:56 -07:00
committed by GitHub
parent 1cf63485c1
commit d1da58e275
18 changed files with 104 additions and 92 deletions

View File

@@ -1,8 +1,10 @@
import torch
from torch import nn
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip
from sglang.srt.utils import is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
class CustomOp(nn.Module):
@@ -34,7 +36,7 @@ class CustomOp(nn.Module):
def dispatch_forward(self):
if _is_cuda:
return self.forward_cuda
elif _is_rocm:
elif _is_hip:
return self.forward_hip
else:
return self.forward_native