Add initial support for intel Gaudi accelerators (#2121)

This commit is contained in:
Ankur Neog
2024-11-23 09:52:23 +05:30
committed by GitHub
parent 66d4859acf
commit 865233e256
4 changed files with 10 additions and 7 deletions

View File

@@ -176,14 +176,15 @@ class ModelRunner:
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
# Init torch distributed
torch.get_device_module(self.device).set_device(self.gpu_id)
if self.device == "cuda":
torch.cuda.set_device(self.gpu_id)
backend = "nccl"
# ToDO(liangan1):Just use gloo to bypass the initilization fail
# Need to use xccl for xpu backend in the future
elif self.device == "xpu":
torch.xpu.set_device(self.gpu_id)
backend = "gloo"
elif self.device == "hpu":
backend = "hccl"
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)