diff --git a/python/pyproject.toml b/python/pyproject.toml index 2a0fd138f..473e4b6b5 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,16 +20,25 @@ dependencies = [ ] [project.optional-dependencies] -srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", - "packaging", "pillow", "psutil", "pydantic", "python-multipart", - "torch", "torchao", "uvicorn", "uvloop", "zmq", - "vllm==0.5.5", "outlines>=0.0.44", "modelscope"] +runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", + "packaging", "pillow", "psutil", "pydantic", "python-multipart", + "torchao", "uvicorn", "uvloop", "zmq", + "outlines>=0.0.44", "modelscope"] +torch = ["torch"] +# xpu is not enabled in public vllm and torch whl, +# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm +vllm = ["vllm==0.5.5"] +srt = ["sglang[runtime_common]", "torch", "vllm"] +srt_xpu = ["sglang[runtime_common]"] + openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] +all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] +dev_xpu = ["sglang[all_xpu]", "sglang[test]"] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 9540e2266..6194ca1d1 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -288,8 +288,15 @@ def correctness_test( rank_print(tokenizer.decode(output_ids[i]), "\n") +def synchronize(device): + if device == "cuda": + torch.cuda.synchronize() + elif device == "xpu": + torch.xpu.synchronize() + + def latency_test_run_once( - run_name, model_runner, rank_print, reqs, 
batch_size, input_len, output_len + run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device ): max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len) if batch_size > max_batch_size: @@ -312,10 +319,10 @@ def latency_test_run_once( tot_latency = 0 # Prefill - torch.cuda.synchronize() + synchronize(device) tic = time.time() next_token_ids, _, batch = extend(reqs, model_runner) - torch.cuda.synchronize() + synchronize(device) prefill_latency = time.time() - tic tot_latency += prefill_latency throughput = input_len * batch_size / prefill_latency @@ -328,10 +335,10 @@ def latency_test_run_once( # Decode decode_latencies = [] for i in range(output_len - 1): - torch.cuda.synchronize() + synchronize(device) tic = time.time() next_token_ids, _ = decode(next_token_ids, batch, model_runner) - torch.cuda.synchronize() + synchronize(device) latency = time.time() - tic tot_latency += latency throughput = batch_size / latency @@ -387,6 +394,7 @@ def latency_test( bench_args.batch_size[0], bench_args.input_len[0], 8, # shorter decoding to speed up the warmup + server_args.device, ) rank_print("Benchmark ...") @@ -397,7 +405,14 @@ def latency_test( ): reqs = prepare_synthetic_inputs_for_latency_test(bs, il) ret = latency_test_run_once( - bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol + bench_args.run_name, + model_runner, + rank_print, + reqs, + bs, + il, + ol, + server_args.device, ) if ret is not None: result_list.append(ret) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 82b9596bf..14c849a9f 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend): self.cuda_graph_max_seq_len = model_runner.model_config.context_len + self.device = model_runner.device + def init_forward_metadata(self, forward_batch: 
ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend): attn_logits = torch.empty( (self.num_head, total_num_tokens), dtype=self.reduce_dtype, - device="cuda", + device=self.device, ) max_seq_len = torch.max(forward_batch.seq_lens).item() @@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len self.cuda_graph_start_loc = torch.zeros( - (max_bs,), dtype=torch.int32, device="cuda" + (max_bs,), dtype=torch.int32, device=self.device ) self.cuda_graph_attn_logits = torch.empty( ( diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 919ef3d2e..52a72d7fe 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import ( context_attention_fwd, ) -CUDA_CAPABILITY = torch.cuda.get_device_capability() +is_cuda_available = torch.cuda.is_available() +if is_cuda_available: + CUDA_CAPABILITY = torch.cuda.get_device_capability() @triton.jit @@ -286,12 +288,12 @@ def extend_attention_fwd( BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - if CUDA_CAPABILITY[0] >= 9: + if is_cuda_available and CUDA_CAPABILITY[0] >= 9: if Lq <= 256: BLOCK_M, BLOCK_N = (128, 64) else: BLOCK_M, BLOCK_N = (32, 64) - elif CUDA_CAPABILITY[0] >= 8: + elif is_cuda_available and CUDA_CAPABILITY[0] >= 8: if Lq <= 128: BLOCK_M, BLOCK_N = (128, 128) elif Lq <= 256: diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index e19e73ec1..79ad35e44 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ 
b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -24,7 +24,9 @@ import torch import triton import triton.language as tl -CUDA_CAPABILITY = torch.cuda.get_device_capability() +is_cuda_available = torch.cuda.is_available() +if is_cuda_available: + CUDA_CAPABILITY = torch.cuda.get_device_capability() @triton.jit @@ -145,7 +147,7 @@ def _fwd_kernel( def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - if CUDA_CAPABILITY[0] >= 8: + if is_cuda_available and CUDA_CAPABILITY[0] >= 8: BLOCK = 128 else: BLOCK = 64 diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index d76b981d5..be6a72afd 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -118,7 +118,7 @@ class ForwardBatch: batch: ModelWorkerBatch, model_runner: ModelRunner, ): - device = "cuda" + device = model_runner.device ret = cls( forward_mode=batch.forward_mode, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5f0675de5..5ca49cc17 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -138,6 +138,7 @@ class ModelRunner: self.init_attention_backend() self.init_cuda_graphs() else: + self.cuda_graph_runner = None self.init_attention_backend() def init_torch_distributed(self): @@ -146,6 +147,11 @@ class ModelRunner: if self.device == "cuda": torch.cuda.set_device(self.gpu_id) backend = "nccl" + # TODO(liangan1): Just use gloo to bypass the initialization failure + # Need to use xccl for xpu backend in the future + elif self.device == "xpu": + torch.xpu.set_device(self.gpu_id) + backend = "gloo" if not self.server_args.enable_p2p_check: monkey_patch_vllm_p2p_access_check(self.gpu_id) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4b70b393e..a07bc04be 
100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -242,7 +242,7 @@ class ServerArgs: "--device", type=str, default="cuda", - choices=["cuda"], + choices=["cuda", "xpu"], help="The device type.", ) parser.add_argument(