[Feature, Hardware] Enable SGLang on XPU GPUs via PyTorch (#1480)
This commit is contained in:
@@ -20,16 +20,25 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
|
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
|
||||||
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
|
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
|
||||||
"torch", "torchao", "uvicorn", "uvloop", "zmq",
|
"torchao", "uvicorn", "uvloop", "zmq",
|
||||||
"vllm==0.5.5", "outlines>=0.0.44", "modelscope"]
|
"outlines>=0.0.44", "modelscope"]
|
||||||
|
torch = ["torch"]
|
||||||
|
# XPU is not enabled in the public vllm and torch wheels,
|
||||||
|
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||||
|
vllm = ["vllm==0.5.5"]
|
||||||
|
srt = ["sglang[runtime_common]", "torch", "vllm"]
|
||||||
|
srt_xpu = ["sglang[runtime_common]"]
|
||||||
|
|
||||||
openai = ["openai>=1.0", "tiktoken"]
|
openai = ["openai>=1.0", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0"]
|
anthropic = ["anthropic>=0.20.0"]
|
||||||
litellm = ["litellm>=1.0.0"]
|
litellm = ["litellm>=1.0.0"]
|
||||||
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
|
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||||
|
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||||
dev = ["sglang[all]", "sglang[test]"]
|
dev = ["sglang[all]", "sglang[test]"]
|
||||||
|
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
"Homepage" = "https://github.com/sgl-project/sglang"
|
"Homepage" = "https://github.com/sgl-project/sglang"
|
||||||
|
|||||||
@@ -288,8 +288,15 @@ def correctness_test(
|
|||||||
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def synchronize(device):
    """Call the synchronize routine of the torch backend named by ``device``.

    The latency benchmark calls this immediately before and after each timed
    prefill/decode step so that the wall-clock measurement brackets the
    device work rather than just the kernel launches.

    Args:
        device: Backend name. ``"cuda"`` dispatches to
            ``torch.cuda.synchronize()`` and ``"xpu"`` dispatches to
            ``torch.xpu.synchronize()``. Any other value is a no-op.

    Returns:
        None.
    """
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()
|
||||||
|
|
||||||
|
|
||||||
def latency_test_run_once(
|
def latency_test_run_once(
|
||||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
|
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
|
||||||
):
|
):
|
||||||
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
||||||
if batch_size > max_batch_size:
|
if batch_size > max_batch_size:
|
||||||
@@ -312,10 +319,10 @@ def latency_test_run_once(
|
|||||||
tot_latency = 0
|
tot_latency = 0
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
torch.cuda.synchronize()
|
synchronize(device)
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
next_token_ids, _, batch = extend(reqs, model_runner)
|
next_token_ids, _, batch = extend(reqs, model_runner)
|
||||||
torch.cuda.synchronize()
|
synchronize(device)
|
||||||
prefill_latency = time.time() - tic
|
prefill_latency = time.time() - tic
|
||||||
tot_latency += prefill_latency
|
tot_latency += prefill_latency
|
||||||
throughput = input_len * batch_size / prefill_latency
|
throughput = input_len * batch_size / prefill_latency
|
||||||
@@ -328,10 +335,10 @@ def latency_test_run_once(
|
|||||||
# Decode
|
# Decode
|
||||||
decode_latencies = []
|
decode_latencies = []
|
||||||
for i in range(output_len - 1):
|
for i in range(output_len - 1):
|
||||||
torch.cuda.synchronize()
|
synchronize(device)
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||||
torch.cuda.synchronize()
|
synchronize(device)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
tot_latency += latency
|
tot_latency += latency
|
||||||
throughput = batch_size / latency
|
throughput = batch_size / latency
|
||||||
@@ -387,6 +394,7 @@ def latency_test(
|
|||||||
bench_args.batch_size[0],
|
bench_args.batch_size[0],
|
||||||
bench_args.input_len[0],
|
bench_args.input_len[0],
|
||||||
8, # shorter decoding to speed up the warmup
|
8, # shorter decoding to speed up the warmup
|
||||||
|
server_args.device,
|
||||||
)
|
)
|
||||||
rank_print("Benchmark ...")
|
rank_print("Benchmark ...")
|
||||||
|
|
||||||
@@ -397,7 +405,14 @@ def latency_test(
|
|||||||
):
|
):
|
||||||
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
|
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
|
||||||
ret = latency_test_run_once(
|
ret = latency_test_run_once(
|
||||||
bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
|
bench_args.run_name,
|
||||||
|
model_runner,
|
||||||
|
rank_print,
|
||||||
|
reqs,
|
||||||
|
bs,
|
||||||
|
il,
|
||||||
|
ol,
|
||||||
|
server_args.device,
|
||||||
)
|
)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
result_list.append(ret)
|
result_list.append(ret)
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
||||||
|
|
||||||
|
self.device = model_runner.device
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
|
||||||
@@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
(self.num_head, total_num_tokens),
|
(self.num_head, total_num_tokens),
|
||||||
dtype=self.reduce_dtype,
|
dtype=self.reduce_dtype,
|
||||||
device="cuda",
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_seq_len = torch.max(forward_batch.seq_lens).item()
|
max_seq_len = torch.max(forward_batch.seq_lens).item()
|
||||||
@@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
||||||
|
|
||||||
self.cuda_graph_start_loc = torch.zeros(
|
self.cuda_graph_start_loc = torch.zeros(
|
||||||
(max_bs,), dtype=torch.int32, device="cuda"
|
(max_bs,), dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
self.cuda_graph_attn_logits = torch.empty(
|
self.cuda_graph_attn_logits = torch.empty(
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
|||||||
context_attention_fwd,
|
context_attention_fwd,
|
||||||
)
|
)
|
||||||
|
|
||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
is_cuda_available = torch.cuda.is_available()
|
||||||
|
if is_cuda_available:
|
||||||
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -286,12 +288,12 @@ def extend_attention_fwd(
|
|||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
if CUDA_CAPABILITY[0] >= 9:
|
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||||
if Lq <= 256:
|
if Lq <= 256:
|
||||||
BLOCK_M, BLOCK_N = (128, 64)
|
BLOCK_M, BLOCK_N = (128, 64)
|
||||||
else:
|
else:
|
||||||
BLOCK_M, BLOCK_N = (32, 64)
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
elif CUDA_CAPABILITY[0] >= 8:
|
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||||
if Lq <= 128:
|
if Lq <= 128:
|
||||||
BLOCK_M, BLOCK_N = (128, 128)
|
BLOCK_M, BLOCK_N = (128, 128)
|
||||||
elif Lq <= 256:
|
elif Lq <= 256:
|
||||||
|
|||||||
@@ -24,7 +24,9 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
is_cuda_available = torch.cuda.is_available()
|
||||||
|
if is_cuda_available:
|
||||||
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -145,7 +147,7 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||||
if CUDA_CAPABILITY[0] >= 8:
|
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||||
BLOCK = 128
|
BLOCK = 128
|
||||||
else:
|
else:
|
||||||
BLOCK = 64
|
BLOCK = 64
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ class ForwardBatch:
|
|||||||
batch: ModelWorkerBatch,
|
batch: ModelWorkerBatch,
|
||||||
model_runner: ModelRunner,
|
model_runner: ModelRunner,
|
||||||
):
|
):
|
||||||
device = "cuda"
|
device = model_runner.device
|
||||||
|
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
|
|||||||
@@ -138,6 +138,7 @@ class ModelRunner:
|
|||||||
self.init_attention_backend()
|
self.init_attention_backend()
|
||||||
self.init_cuda_graphs()
|
self.init_cuda_graphs()
|
||||||
else:
|
else:
|
||||||
|
self.cuda_graph_runner = None
|
||||||
self.init_attention_backend()
|
self.init_attention_backend()
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
@@ -146,6 +147,11 @@ class ModelRunner:
|
|||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
torch.cuda.set_device(self.gpu_id)
|
torch.cuda.set_device(self.gpu_id)
|
||||||
backend = "nccl"
|
backend = "nccl"
|
||||||
|
# TODO(liangan1): Just use gloo to bypass the initialization failure.
|
||||||
|
# Need to use xccl for xpu backend in the future
|
||||||
|
elif self.device == "xpu":
|
||||||
|
torch.xpu.set_device(self.gpu_id)
|
||||||
|
backend = "gloo"
|
||||||
|
|
||||||
if not self.server_args.enable_p2p_check:
|
if not self.server_args.enable_p2p_check:
|
||||||
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
||||||
|
|||||||
@@ -242,7 +242,7 @@ class ServerArgs:
|
|||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
default="cuda",
|
default="cuda",
|
||||||
choices=["cuda"],
|
choices=["cuda", "xpu"],
|
||||||
help="The device type.",
|
help="The device type.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user