[Feature, Hardware] Enable SGLang on XPU GPUs via PyTorch (#1480)
This commit is contained in:
@@ -20,16 +20,25 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
|
||||
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
|
||||
"torch", "torchao", "uvicorn", "uvloop", "zmq",
|
||||
"vllm==0.5.5", "outlines>=0.0.44", "modelscope"]
|
||||
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
|
||||
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
|
||||
"torchao", "uvicorn", "uvloop", "zmq",
|
||||
"outlines>=0.0.44", "modelscope"]
|
||||
torch = ["torch"]
|
||||
# xpu is not enabled in public vllm and torch whl,
|
||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||
vllm = ["vllm==0.5.5"]
|
||||
srt = ["sglang[runtime_common]", "torch", "vllm"]
|
||||
srt_xpu = ["sglang[runtime_common]"]
|
||||
|
||||
openai = ["openai>=1.0", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||
dev = ["sglang[all]", "sglang[test]"]
|
||||
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://github.com/sgl-project/sglang"
|
||||
|
||||
@@ -288,8 +288,15 @@ def correctness_test(
|
||||
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
||||
|
||||
|
||||
def synchronize(device):
    """Block until all queued kernels on *device* have completed.

    Devices other than "cuda"/"xpu" are a deliberate no-op, so callers can
    pass the configured device string unconditionally.
    """
    if device in ("cuda", "xpu"):
        getattr(torch, device).synchronize()
|
||||
|
||||
|
||||
def latency_test_run_once(
|
||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
|
||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
|
||||
):
|
||||
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
||||
if batch_size > max_batch_size:
|
||||
@@ -312,10 +319,10 @@ def latency_test_run_once(
|
||||
tot_latency = 0
|
||||
|
||||
# Prefill
|
||||
torch.cuda.synchronize()
|
||||
synchronize(device)
|
||||
tic = time.time()
|
||||
next_token_ids, _, batch = extend(reqs, model_runner)
|
||||
torch.cuda.synchronize()
|
||||
synchronize(device)
|
||||
prefill_latency = time.time() - tic
|
||||
tot_latency += prefill_latency
|
||||
throughput = input_len * batch_size / prefill_latency
|
||||
@@ -328,10 +335,10 @@ def latency_test_run_once(
|
||||
# Decode
|
||||
decode_latencies = []
|
||||
for i in range(output_len - 1):
|
||||
torch.cuda.synchronize()
|
||||
synchronize(device)
|
||||
tic = time.time()
|
||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||
torch.cuda.synchronize()
|
||||
synchronize(device)
|
||||
latency = time.time() - tic
|
||||
tot_latency += latency
|
||||
throughput = batch_size / latency
|
||||
@@ -387,6 +394,7 @@ def latency_test(
|
||||
bench_args.batch_size[0],
|
||||
bench_args.input_len[0],
|
||||
8, # shorter decoding to speed up the warmup
|
||||
server_args.device,
|
||||
)
|
||||
rank_print("Benchmark ...")
|
||||
|
||||
@@ -397,7 +405,14 @@ def latency_test(
|
||||
):
|
||||
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
|
||||
ret = latency_test_run_once(
|
||||
bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
|
||||
bench_args.run_name,
|
||||
model_runner,
|
||||
rank_print,
|
||||
reqs,
|
||||
bs,
|
||||
il,
|
||||
ol,
|
||||
server_args.device,
|
||||
)
|
||||
if ret is not None:
|
||||
result_list.append(ret)
|
||||
|
||||
@@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
||||
|
||||
self.device = model_runner.device
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
|
||||
@@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
attn_logits = torch.empty(
|
||||
(self.num_head, total_num_tokens),
|
||||
dtype=self.reduce_dtype,
|
||||
device="cuda",
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
max_seq_len = torch.max(forward_batch.seq_lens).item()
|
||||
@@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
||||
|
||||
self.cuda_graph_start_loc = torch.zeros(
|
||||
(max_bs,), dtype=torch.int32, device="cuda"
|
||||
(max_bs,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.cuda_graph_attn_logits = torch.empty(
|
||||
(
|
||||
|
||||
@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -286,12 +288,12 @@ def extend_attention_fwd(
|
||||
BLOCK_DPE = 0
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
if CUDA_CAPABILITY[0] >= 9:
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||
if Lq <= 256:
|
||||
BLOCK_M, BLOCK_N = (128, 64)
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (32, 64)
|
||||
elif CUDA_CAPABILITY[0] >= 8:
|
||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
if Lq <= 128:
|
||||
BLOCK_M, BLOCK_N = (128, 128)
|
||||
elif Lq <= 256:
|
||||
|
||||
@@ -24,7 +24,9 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -145,7 +147,7 @@ def _fwd_kernel(
|
||||
|
||||
|
||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
if CUDA_CAPABILITY[0] >= 8:
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
BLOCK = 128
|
||||
else:
|
||||
BLOCK = 64
|
||||
|
||||
@@ -118,7 +118,7 @@ class ForwardBatch:
|
||||
batch: ModelWorkerBatch,
|
||||
model_runner: ModelRunner,
|
||||
):
|
||||
device = "cuda"
|
||||
device = model_runner.device
|
||||
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
|
||||
@@ -138,6 +138,7 @@ class ModelRunner:
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
else:
|
||||
self.cuda_graph_runner = None
|
||||
self.init_attention_backend()
|
||||
|
||||
def init_torch_distributed(self):
|
||||
@@ -146,6 +147,11 @@ class ModelRunner:
|
||||
if self.device == "cuda":
|
||||
torch.cuda.set_device(self.gpu_id)
|
||||
backend = "nccl"
|
||||
# TODO(liangan1): Just use gloo to bypass the initialization failure.
|
||||
# Need to use xccl for xpu backend in the future
|
||||
elif self.device == "xpu":
|
||||
torch.xpu.set_device(self.gpu_id)
|
||||
backend = "gloo"
|
||||
|
||||
if not self.server_args.enable_p2p_check:
|
||||
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
||||
|
||||
@@ -242,7 +242,7 @@ class ServerArgs:
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
choices=["cuda", "xpu"],
|
||||
help="The device type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user