[Feature, Hardware] Enable SGLang on XPU GPUs via PyTorch (#1480)

This commit is contained in:
Zhang, Liangang
2024-10-13 02:10:32 +08:00
committed by GitHub
parent e37cdab0c6
commit 5d638c92f5
8 changed files with 55 additions and 19 deletions

View File

@@ -118,7 +118,7 @@ class ForwardBatch:
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
device = "cuda"
device = model_runner.device
ret = cls(
forward_mode=batch.forward_mode,

View File

@@ -138,6 +138,7 @@ class ModelRunner:
self.init_attention_backend()
self.init_cuda_graphs()
else:
self.cuda_graph_runner = None
self.init_attention_backend()
def init_torch_distributed(self):
@@ -146,6 +147,11 @@ class ModelRunner:
if self.device == "cuda":
torch.cuda.set_device(self.gpu_id)
backend = "nccl"
# TODO(liangan1): Just use gloo to bypass the initialization failure.
# Need to use xccl for the XPU backend in the future.
elif self.device == "xpu":
torch.xpu.set_device(self.gpu_id)
backend = "gloo"
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)