diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f4d41feef..92f998b2c 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -423,6 +423,9 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
+    # device
+    device: str = "cuda"
+
     # Has regex
     has_regex: bool = False
 
@@ -439,6 +442,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            device=req_to_token_pool.device,
             has_regex=has_regex,
         )
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 698d35d99..c2f2368e4 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -81,10 +81,11 @@ class ModelRunner:
         # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
+        self.device = server_args.device
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.nccl_port = nccl_port
+        self.dist_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
@@ -132,39 +133,45 @@ class ModelRunner:
             server_args.max_running_requests,
             server_args.max_total_tokens,
         )
-        self.init_cublas()
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        if self.device == "cuda":
+            self.init_cublas()
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+        else:
+            self.init_attention_backend()
 
     def init_torch_distributed(self):
+        logger.info("Init torch distributed begin.")
         # Init torch distributed
-        torch.cuda.set_device(self.gpu_id)
-        logger.info("Init nccl begin.")
+        if self.device == "cuda":
+            torch.cuda.set_device(self.gpu_id)
+            backend = "nccl"
 
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
-
         if self.server_args.dist_init_addr:
-            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
+            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
-            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
-            backend="nccl",
+            backend=backend,
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method,
+            distributed_init_method=dist_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         min_per_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
 
         # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
         # so we disable padding in cuda graph.
-        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+        if self.device == "cuda" and not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        ):
             self.server_args.disable_cuda_graph_padding = True
             logger.info(
                 "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
@@ -172,7 +179,7 @@ class ModelRunner:
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
@@ -182,23 +189,22 @@ class ModelRunner:
 
     def load_model(self):
         logger.info(
-            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
         # This can reduce thread conflicts and speed up weight loading.
         torch.set_num_threads(1)
-
-        if torch.cuda.get_device_capability()[0] < 8:
-            logger.info(
-                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-            )
-            self.server_args.dtype = "float16"
-        if torch.cuda.get_device_capability()[1] < 5:
-            raise RuntimeError("SGLang only supports sm75 and above.")
+        if self.device == "cuda":
+            if torch.cuda.get_device_capability()[0] < 8:
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
+                self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")
+
         # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
-        self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
         self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
@@ -220,7 +226,7 @@ class ModelRunner:
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
-            device_config=self.device_config,
+            device_config=DeviceConfig(self.device),
             parallel_config=None,
             scheduler_config=None,
             lora_config=None,
@@ -240,7 +246,7 @@ class ModelRunner:
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
     def update_weights(self, model_path: str, load_format: str):
@@ -254,10 +260,10 @@ class ModelRunner:
 
         logger.info(
             f"Update weights begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-        target_device = torch.device(self.device_config.device)
+        target_device = torch.device(self.device)
 
         try:
             # TODO: Use a better method to check this
@@ -343,7 +349,7 @@ class ModelRunner:
 
     def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -409,11 +415,10 @@ class ModelRunner:
             4096,
         )
 
-        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
-            device=device,
+            device=self.device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -425,7 +430,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -434,11 +439,11 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
             )
         logger.info(
             f"Memory pool end. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
     def init_cublas(self):
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index de781acb3..c89608c36 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
+    # Device
+    device: str = "cuda"
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -62,6 +65,7 @@ class SamplingBatchInfo:
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
+            device=batch.input_ids.device,
         )
 
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -75,7 +79,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device="cuda",
+            device=batch.input_ids.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
@@ -107,7 +111,7 @@ class SamplingBatchInfo:
                     self.linear_penalties = torch.zeros(
                         (bs, self.vocab_size),
                         dtype=torch.float32,
-                        device="cuda",
+                        device=self.device,
                     )
                 self.linear_penalties = penalizer.apply(self.linear_penalties)
 
@@ -119,7 +123,10 @@ class SamplingBatchInfo:
 
         if has_regex:
             self.vocab_mask = torch.zeros(
-                len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.temperatures),
+                self.vocab_size,
+                dtype=torch.bool,
+                device=self.device,
             )
             for i, regex_fsm in enumerate(self.regex_fsms):
                 if regex_fsm is not None:
@@ -144,7 +151,12 @@ class SamplingBatchInfo:
 
     @staticmethod
     def merge_bias_tensor(
-        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+        lhs: torch.Tensor,
+        rhs: torch.Tensor,
+        bs1: int,
+        bs2: int,
+        device: str,
+        default: int = 0,
     ):
         # bias tensor can be None
         if lhs is not None or rhs is not None:
@@ -155,9 +167,9 @@ class SamplingBatchInfo:
                 shape, dtype = rhs.shape[1:], rhs.dtype
             with torch.dtype(dtype):
                 if lhs is None:
-                    lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
+                    lhs = torch.empty((bs1, *shape), device=device).fill_(default)
                 if rhs is None:
-                    rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
+                    rhs = torch.empty((bs2, *shape), device=device).fill_(default)
                 return torch.cat([lhs, rhs])
 
         return None
@@ -176,5 +188,5 @@ class SamplingBatchInfo:
                 setattr(self, item, torch.concat([self_val, other_val]))
 
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other)
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c6e4b2406..82e588d1c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -36,6 +36,7 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    device: str = "cuda"
     kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
@@ -237,6 +238,13 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda"],
+            help="The device type.",
+        )
         parser.add_argument(
"--kv-cache-dtype", type=str, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ff8b4575c..f0ac21fb1 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0): return wrapper -def get_available_gpu_memory(gpu_id, distributed=False): +def get_available_gpu_memory(device, gpu_id, distributed=False): """ Get available memory for cuda:gpu_id device. When distributed is True, the available memory is the minimum available memory of all GPUs. """ - num_gpus = torch.cuda.device_count() - assert gpu_id < num_gpus + if device == "cuda": + num_gpus = torch.cuda.device_count() + assert gpu_id < num_gpus - if torch.cuda.current_device() != gpu_id: - print( - f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ", - "which may cause useless memory allocation for torch CUDA context.", - ) + if torch.cuda.current_device() != gpu_id: + print( + f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ", + "which may cause useless memory allocation for torch CUDA context.", + ) - torch.cuda.empty_cache() - free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) + torch.cuda.empty_cache() + free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) + + elif device == "xpu": + num_gpus = torch.xpu.device_count() + assert gpu_id < num_gpus + + if torch.xpu.current_device() != gpu_id: + print( + f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ", + "which may cause useless memory allocation for torch XPU context.", + ) + torch.xpu.empty_cache() + used_memory = torch.xpu.memory_allocated() + total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory + free_gpu_memory = total_gpu_memory - used_memory if distributed: tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( - torch.device("cuda", gpu_id) + torch.device(device, gpu_id) ) torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) free_gpu_memory = tensor.item()