import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


class CustomOp(nn.Module):
    def forward_vacc(self, *args, **kwargs):
        # Backend-specific forward for the VACC platform; concrete ops are
        # expected to override this with a real kernel call.
        raise NotImplementedError

    def dispatch_forward(self):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.

        enabled = self.enabled()
        logger.debug("custom op %s %s", self.__class__.name,
                     "enabled" if enabled else "disabled")

        if not enabled:
            return self.forward_native

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_hpu():
            return self.forward_hpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_vacc():
            # Dispatch to the VACC-specific implementation defined above.
            return self.forward_vacc
        else:
            return self.forward_cuda
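

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of upstream vLLM). It
# assumes the rest of CustomOp behaves as in upstream vLLM -- i.e. that
# `enabled()` exists and that `__init__`/`forward` route calls through the
# method returned by dispatch_forward() -- and that this build targets VACC,
# so current_platform.is_vacc() is True. `ExampleOp` and its doubling
# "kernel" are hypothetical stand-ins.
if __name__ == "__main__":
    import torch

    class ExampleOp(CustomOp):
        """Toy op that doubles its input."""

        # Upstream vLLM assigns this via the CustomOp.register decorator;
        # a plain class attribute keeps the sketch self-contained.
        name = "example_op"

        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            # Portable PyTorch fallback, picked when the custom op is
            # disabled.
            return x * 2

        def forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
            # Picked by dispatch_forward() on a VACC build; a real op would
            # call into a VACC kernel here.
            return x * 2

    op = ExampleOp()
    fn = op.dispatch_forward()  # -> op.forward_vacc on a VACC build
    print(fn(torch.ones(2)))  # tensor([2., 2.])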