diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md
index 3f660318e..3f6d19d21 100644
--- a/docs/references/environment_variables.md
+++ b/docs/references/environment_variables.md
@@ -74,6 +74,7 @@ SGLang supports various environment variables that can be used to configure its
 | `SGLANG_BLOCK_NONZERO_RANK_CHILDREN` | Control blocking of non-zero rank children processes | `1` |
 | `SGL_IS_FIRST_RANK_ON_NODE` | Indicates if the current process is the first rank on its node | `"true"` |
 | `SGLANG_PP_LAYER_PARTITION` | Pipeline parallel layer partition specification | Not set |
+| `SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS` | Launch each scheduler process with only its own GPU visible (`CUDA_VISIBLE_DEVICES` is narrowed so each process addresses its GPU as device 0) | `false` |
 
 ## Testing & Debugging (Internal/CI)
 
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index d52da3a6e..c8a0c222a 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -61,6 +61,7 @@ import torch.distributed as dist
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
@@ -75,6 +76,7 @@ from sglang.srt.utils import (
     is_cuda_alike,
     is_xpu,
     kill_process_tree,
+    maybe_reindex_device_id,
     require_mlp_sync,
     require_mlp_tp_gather,
     set_gpu_proc_affinity,
@@ -159,7 +161,7 @@ class BenchArgs:
     )
 
 
-def load_model(server_args, port_args, tp_rank):
+def load_model(server_args, port_args, gpu_id, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
     moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
@@ -168,7 +170,7 @@ def load_model(server_args, port_args, tp_rank):
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,
-        gpu_id=tp_rank,
+        gpu_id=gpu_id,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
         moe_ep_rank=moe_ep_rank,
@@ -350,6 +352,7 @@ def correctness_test(
     server_args,
     port_args,
     bench_args,
+    gpu_id,
     tp_rank,
 ):
     # Configure the logger
@@ -357,7 +360,7 @@ def correctness_test(
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
 
     # Prepare inputs
     custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
@@ -517,6 +520,7 @@ def latency_test(
     server_args,
     port_args,
     bench_args,
+    gpu_id,
     tp_rank,
 ):
     initialize_moe_config(server_args)
@@ -532,7 +536,7 @@ def latency_test(
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
 
     # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
@@ -634,21 +638,23 @@ def main(server_args, bench_args):
     port_args = PortArgs.init_new(server_args)
 
     if server_args.tp_size == 1:
-        work_func(server_args, port_args, bench_args, 0)
+        work_func(server_args, port_args, bench_args, 0, 0)
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
-            proc = multiprocessing.Process(
-                target=work_func,
-                args=(
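A note on the parameter threading above: `gpu_id` is now carried separately from `tp_rank` because, once `maybe_reindex_device_id` narrows `CUDA_VISIBLE_DEVICES`, every child process addresses its GPU as device 0 while `tp_rank` keeps its global meaning. The toy below is not SGLang code; `work_func` and the 4-GPU count are made up for illustration, and it only prints what each child would see:

```python
# Toy illustration of launching one process per rank with a narrowed
# CUDA_VISIBLE_DEVICES, so each child sees its own physical GPU as device 0.
import multiprocessing
import os


def work_func(gpu_id: int, tp_rank: int) -> None:
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    print(f"rank {tp_rank}: CUDA_VISIBLE_DEVICES={visible}, gpu_id={gpu_id}")


if __name__ == "__main__":
    workers = []
    for tp_rank in range(4):  # assume 4 GPUs
        # Narrow the visible devices before forking; the child inherits it.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(tp_rank)
        proc = multiprocessing.Process(target=work_func, args=(0, tp_rank))
        proc.start()
        workers.append(proc)
    for proc in workers:
        proc.join()
```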
-                    server_args,
-                    port_args,
-                    bench_args,
-                    tp_rank,
-                ),
-            )
-            proc.start()
-            workers.append(proc)
+            with maybe_reindex_device_id(tp_rank) as gpu_id:
+                proc = multiprocessing.Process(
+                    target=work_func,
+                    args=(
+                        server_args,
+                        port_args,
+                        bench_args,
+                        gpu_id,
+                        tp_rank,
+                    ),
+                )
+                proc.start()
+                workers.append(proc)
 
         for proc in workers:
             proc.join()
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 7e18d06db..013b3f785 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -39,6 +39,7 @@ import torch
 import torch.distributed
 from torch.distributed import Backend, ProcessGroup
 
+from sglang.srt.environ import envs
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -56,8 +57,6 @@ _is_npu = is_npu()
 _is_cpu = is_cpu()
 _supports_custom_op = supports_custom_op()
 
-IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
-
 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
@@ -277,11 +276,13 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None
 
-        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
         if is_cuda_alike():
+            device_id = (
+                0 if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() else local_rank
+            )
             self.device = torch.device(f"cuda:{device_id}")
         elif _is_npu:
-            self.device = torch.device(f"npu:{device_id}")
+            self.device = torch.device(f"npu:{local_rank}")
         else:
             self.device = torch.device("cpu")
         self.device_module = torch.get_device_module(self.device)
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 61e9de1fe..7107a611e 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -75,6 +75,7 @@ from sglang.srt.utils import (
     is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
+    maybe_reindex_device_id,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
@@ -782,22 +783,24 @@ def _launch_subprocesses(
                     + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                 )
                 moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
-                proc = mp.Process(
-                    target=run_scheduler_process,
-                    args=(
-                        server_args,
-                        port_args,
-                        gpu_id,
-                        tp_rank,
-                        moe_ep_rank,
-                        pp_rank,
-                        None,
-                        writer,
-                    ),
-                )
-                with memory_saver_adapter.configure_subprocess():
-                    proc.start()
+                with maybe_reindex_device_id(gpu_id) as gpu_id:
+                    proc = mp.Process(
+                        target=run_scheduler_process,
+                        args=(
+                            server_args,
+                            port_args,
+                            gpu_id,
+                            tp_rank,
+                            moe_ep_rank,
+                            pp_rank,
+                            None,
+                            writer,
+                        ),
+                    )
+                    with memory_saver_adapter.configure_subprocess():
+                        proc.start()
+
 
                 scheduler_procs.append(proc)
                 scheduler_pipe_readers.append(reader)
             else:
diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index 113e49f6c..acc8b0e68 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -142,6 +142,7 @@ class Envs:
 
     # Model Parallel
     SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
+    SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS = EnvBool(False)
 
     # Constrained Decoding
     SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(True)
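For reviewers tracing the `GroupCoordinator` change: when the new flag is on, `CUDA_VISIBLE_DEVICES` has already been narrowed before the scheduler process started, so the only valid CUDA index inside the process is 0, while `local_rank` remains correct for the NPU and default paths. Below is a minimal standalone sketch of that selection rule; `select_device` is a hypothetical helper, and it reads the flag via `os.environ` rather than the `envs` registry used in the diff:

```python
# Standalone sketch of the device-selection rule encoded by the hunk above.
import os

import torch


def select_device(local_rank: int) -> torch.device:
    one_visible = os.environ.get(
        "SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS", "false"
    ).lower() in ("1", "true")
    if torch.cuda.is_available():
        # With one visible device per process, index 0 is the only valid
        # device inside this process; otherwise fall back to local_rank.
        return torch.device(f"cuda:{0 if one_visible else local_rank}")
    return torch.device("cpu")
```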
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 56a87516d..cbf14149e 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -46,6 +46,7 @@ from sglang.srt.utils import (
     configure_logger,
     get_zmq_socket,
     kill_itself_when_parent_died,
+    maybe_reindex_device_id,
 )
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
@@ -139,6 +140,9 @@ class DataParallelController:
         # Load balance budget
         self.dp_budget = DPBudget()
 
+        # Protect temporary changes to CUDA_VISIBLE_DEVICES while workers are launched.
+        self.env_lock = threading.Lock()
+
         # Launch data parallel workers
         self.scheduler_procs = []
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size
@@ -399,21 +403,22 @@ class DataParallelController:
                     + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                 )
                 moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
-                proc = mp.Process(
-                    target=run_scheduler_process,
-                    args=(
-                        server_args,
-                        rank_port_args,
-                        gpu_id,
-                        tp_rank,
-                        moe_ep_rank,
-                        pp_rank,
-                        dp_rank,
-                        writer,
-                    ),
-                )
-                with memory_saver_adapter.configure_subprocess():
-                    proc.start()
+                with self.env_lock, maybe_reindex_device_id(gpu_id) as gpu_id:
+                    proc = mp.Process(
+                        target=run_scheduler_process,
+                        args=(
+                            server_args,
+                            rank_port_args,
+                            gpu_id,
+                            tp_rank,
+                            moe_ep_rank,
+                            pp_rank,
+                            dp_rank,
+                            writer,
+                        ),
+                    )
+                    with memory_saver_adapter.configure_subprocess():
+                        proc.start()
 
                 self.scheduler_procs.append(proc)
                 scheduler_pipe_readers.append(reader)
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index e8a4256c9..e898b8923 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -88,6 +88,7 @@ from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
 from typing_extensions import Literal
 
+from sglang.srt.environ import envs
 from sglang.srt.metrics.func_timer import enable_func_timer
 
 logger = logging.getLogger(__name__)
@@ -3273,7 +3274,7 @@ def json_list_type(value):
 
 @contextmanager
 def maybe_reindex_device_id(gpu_id: int):
-    if not is_cuda_alike():
+    if not envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() or not is_cuda_alike():
         yield gpu_id
         return
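The final hunk shows only the early-return branch of `maybe_reindex_device_id`; the branch that actually reindexes devices is outside this diff. As a reading aid, here is a hypothetical sketch of what that remaining branch plausibly does, inferred from the call sites above (the `env_lock` guard, the rebinding of `gpu_id` to the yielded value, and the `device_id = 0` path in `GroupCoordinator`); the real implementation in `python/sglang/srt/utils/common.py` may differ:

```python
# Hypothetical sketch, not the actual SGLang implementation.
import os
from contextlib import contextmanager

import torch


@contextmanager
def reindex_device_id_sketch(gpu_id: int):
    # Narrow CUDA_VISIBLE_DEVICES to the one physical device this rank owns,
    # so the process forked inside this context sees a single GPU as cuda:0.
    old = os.environ.get("CUDA_VISIBLE_DEVICES")
    if old:
        visible = old.split(",")
    else:
        visible = [str(i) for i in range(torch.cuda.device_count())]
    os.environ["CUDA_VISIBLE_DEVICES"] = visible[gpu_id]
    try:
        yield 0  # the reindexed device id the child process should use
    finally:
        # Restore the parent's view so later launches see all devices again.
        if old is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = old
```

Because the parent mutates and then restores a process-wide environment variable, concurrent launches would race; serializing launches with `self.env_lock` in `DataParallelController` is consistent with this reading.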