unify is_cuda and is_hip (#4321)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
_is_rocm = torch.cuda.is_available() and torch.version.hip
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class CustomOp(nn.Module):
|
||||
@@ -34,7 +36,7 @@ class CustomOp(nn.Module):
|
||||
def dispatch_forward(self):
|
||||
if _is_cuda:
|
||||
return self.forward_cuda
|
||||
elif _is_rocm:
|
||||
elif _is_hip:
|
||||
return self.forward_hip
|
||||
else:
|
||||
return self.forward_native
|
||||
|
||||
Reference in New Issue
Block a user