import torch.nn as nn

from vllm.logger import init_logger

logger = init_logger(__name__)


class CustomOp(nn.Module):
    """Custom-op hooks for the VACC backend.

    NOTE(review): this class defines only ``forward_vacc`` and
    ``dispatch_forward``. ``dispatch_forward`` also calls ``self.enabled()``,
    ``self.forward`` and ``self.forward_native``, none of which are defined
    here or on ``nn.Module``'s usual interface for ``enabled``. Presumably
    these come from vLLM's own ``CustomOp`` (via patching or a subclass) —
    confirm against the code that imports this module.
    """

    def forward_vacc(self, *args, **kwargs):
        """Backend-specific forward for VACC; concrete ops must override it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def dispatch_forward(self):
        """Choose the bound forward method this op should run.

        Returns:
            ``self.forward_native`` when ``self.enabled()`` is falsy,
            otherwise ``self.forward``.
        """
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        #
        # In this build every enabled op dispatches to ``self.forward``; no
        # per-platform variant (forward_cuda / forward_hip / forward_cpu / ...)
        # is selected here. The unreachable platform chain that previously
        # followed the final return (and used an unimported
        # ``current_platform``) has been removed.
        enabled = self.enabled()
        op_cls = self.__class__
        # ``name`` is not defined in this module (presumably it is attached
        # when the op is registered elsewhere); fall back to the Python class
        # name so a debug log line cannot raise AttributeError.
        op_name = getattr(op_cls, "name", op_cls.__name__)
        logger.debug("custom op %s %s", op_name,
                     "enabled" if enabled else "disabled")
        if not enabled:
            return self.forward_native
        return self.forward