diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py
index 7404c7e4e..724051eab 100644
--- a/examples/runtime/engine/offline_batch_inference.py
+++ b/examples/runtime/engine/offline_batch_inference.py
@@ -1,7 +1,13 @@
+import argparse
+import dataclasses
+
 import sglang as sgl
+from sglang.srt.server_args import ServerArgs
 
 
-def main():
+def main(
+    server_args: ServerArgs,
+):
     # Sample prompts.
     prompts = [
         "Hello, my name is",
@@ -13,7 +19,7 @@ def main():
     sampling_params = {"temperature": 0.8, "top_p": 0.95}
 
     # Create an LLM.
-    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+    llm = sgl.Engine(**dataclasses.asdict(server_args))
 
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -25,4 +31,8 @@ def main():
 # The __main__ condition is necessary here because we use "spawn" to create subprocesses
 # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    main(server_args)
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index d7db6036c..b0dfda3e8 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 3714f19b6..fd4edade9 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -993,7 +993,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1055,7 +1055,7 @@ class Scheduler:
 
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
@@ -1130,7 +1130,7 @@ class Scheduler:
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         self.stream_output(batch.reqs)
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index e4e20ad8f..6a453d2ad 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
@@ -73,7 +74,7 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
@@ -97,7 +98,7 @@ class TpModelWorkerClient:
 
     def forward_thread_func(self):
         try:
-            with torch.cuda.stream(self.forward_stream):
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
@@ -122,7 +123,7 @@ class TpModelWorkerClient:
 
             # Create event
            self.launch_done = threading.Event()
-            copy_done = torch.cuda.Event()
+            copy_done = torch.get_device_module(self.device).Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -190,7 +191,7 @@ class TpModelWorkerClient:
         )
 
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
+        torch.get_device_module(self.device).current_stream().synchronize()
 
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index b028309c7..646e71749 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch
 
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
@@ -129,6 +130,9 @@
         return select_index.to(self.device, non_blocking=True)
 
     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py
index a758e4f56..83ac3d867 100644
--- a/python/sglang/srt/models/commandr.py
+++ b/python/sglang/srt/models/commandr.py
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 3a0a99af9..c2e75a642 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -25,6 +25,7 @@ import torch
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
@@ -158,6 +159,8 @@ class ServerArgs:
             gpu_mem = get_amdgpu_memory_capacity()
         elif torch.cuda.is_available():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif self.device == "hpu":
+            gpu_mem = get_hpu_memory_capacity()
         else:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None
@@ -194,6 +197,10 @@ class ServerArgs:
             self.cuda_graph_max_bs = 160
 
         # Choose kernel backends
+        if self.device == "hpu":
+            self.attention_backend = "torch_native"
+            self.sampling_backend = "pytorch"
+
         if self.attention_backend is None:
             self.attention_backend = (
                 "flashinfer" if is_flashinfer_available() else "triton"
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 04372bac1..5c310136a 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory
 
+    elif device == "hpu":
+        num_gpus = torch.hpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.hpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
+                "which may cause useless memory allocation for torch HPU context.",
+            )
+
+        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
         )
 
 
+def get_hpu_memory_capacity():
+    try:
+        # Run hl-smi and capture the output
+        result = subprocess.run(
+            ["hl-smi --query | grep 'Total'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "hl-smi not found. Ensure Habana drivers are installed and accessible."
+        )
+
+
 # Copy from pytorch and OpenRLHF to allow creating multiple main groups.
 # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor
 
 
+def get_compiler_backend() -> str:
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return "hpu_backend"
+
+    return "inductor"
+
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
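
Usage sketch (not part of the diff above): with the argparse wiring added to offline_batch_inference.py, the example now accepts the same flags that ServerArgs.add_cli_args exposes, and sgl.Engine(**dataclasses.asdict(server_args)) simply forwards every ServerArgs field as a keyword argument. A minimal programmatic equivalent, assuming ServerArgs can be constructed directly with model_path and device (field names taken from the hunks above; other defaults may differ by version):

    import dataclasses

    import sglang as sgl
    from sglang.srt.server_args import ServerArgs

    if __name__ == "__main__":
        # Build ServerArgs in code instead of parsing CLI flags; device="hpu" mirrors
        # the self.device == "hpu" branches added in server_args.py.
        server_args = ServerArgs(
            model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", device="hpu"
        )
        llm = sgl.Engine(**dataclasses.asdict(server_args))
        outputs = llm.generate(["Hello, my name is"], {"temperature": 0.8, "top_p": 0.95})
        print(outputs[0]["text"])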
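Similarly, the repeated torch.get_device_module(self.device) and get_compiler_backend() substitutions keep stream synchronization and torch.compile device-agnostic. A small sketch of the same pattern in isolation, assuming a PyTorch build that provides torch.get_device_module and that importing habana_frameworks.torch has registered torch.hpu and the "hpu_backend" compile backend:

    import torch

    def sync_current_stream(device: str) -> None:
        # torch.get_device_module("cuda") returns torch.cuda, "hpu" returns torch.hpu,
        # so the same synchronize() call works on either accelerator.
        torch.get_device_module(device).current_stream().synchronize()

    # Same selection logic as the get_compiler_backend() helper added in utils.py.
    backend = "hpu_backend" if hasattr(torch, "hpu") and torch.hpu.is_available() else "inductor"

    @torch.compile(dynamic=True, backend=backend)
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2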