diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index 76927a745..13f60451e 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -19,6 +19,7 @@ import torch.distributed as dist from torch.distributed.tensor import DeviceMesh, DTensor from sglang.srt.model_executor.model_runner import LocalSerializedTensor +from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server import Engine from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj @@ -30,6 +31,7 @@ class VerlEngine: nnodes: int = 1, **kwargs, ): + monkey_patch_torch_reductions() self._device_mesh_cpu = device_mesh_cpu self._tp_rank = device_mesh_cpu.get_local_rank() self._tp_size = device_mesh_cpu.size() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e6d74998e..716b61e22 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import ( ) from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -1082,8 +1083,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso def _unwrap_tensor(tensor, tp_rank): if isinstance(tensor, LocalSerializedTensor): - return tensor.get(tp_rank) - return tensor + monkey_patch_torch_reductions() + tensor = tensor.get(tp_rank) + return tensor.to(torch.cuda.current_device()) @dataclass diff --git a/python/sglang/srt/patch_torch.py b/python/sglang/srt/patch_torch.py new file mode 100644 index 000000000..32034b704 --- /dev/null +++ b/python/sglang/srt/patch_torch.py @@ -0,0 +1,71 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import Callable, Union + +import torch +from torch.multiprocessing import reductions + + +def monkey_patch_torch_reductions(): + """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" + + if hasattr(reductions, "_reduce_tensor_original"): + return + + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + + reductions.init_reductions() + + +# The signature has not been changed for years, and we will not need this when the next version is released, +# so it looks safe to use a constant. +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple( + output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid + ) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise Exception("Invalid device_uuid=" + device_maybe_uuid) + + raise Exception(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return *t[:index], modifier(t[index]), *t[index + 1 :] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index fa1a7c376..6095ecf2a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -46,6 +46,7 @@ suites = { TestFile("test_openai_server.py", 124), TestFile("test_penalty.py", 41), TestFile("test_page_size.py", 60), + TestFile("test_patch_torch.py", 60), TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_radix_attention.py", 167), TestFile("test_reasoning_content.py", 89), diff --git a/test/srt/test_patch_torch.py b/test/srt/test_patch_torch.py new file mode 100644 index 000000000..a2c04509e --- /dev/null +++ b/test/srt/test_patch_torch.py @@ -0,0 +1,133 @@ +import os +import traceback +import unittest +from typing import Dict, List + +import torch +import torch.multiprocessing as mp + +from sglang.srt.patch_torch import monkey_patch_torch_reductions + + +class TestReleaseMemoryOccupation(unittest.TestCase): + def test_monkey_patch_torch_reductions(self): + mp.set_start_method("spawn", force=True) + + for enable_patch in [False, True]: + for params in [ + # Same visible devices + dict( + sender_info=dict( + visible_devices=[0, 1], + tensor_device=1, + ), + receiver_info=dict( + visible_devices=[0, 1], + tensor_device=1, + ), + ), + # Different visible devices + dict( + sender_info=dict( + visible_devices=[0, 1], + tensor_device=1, + ), + receiver_info=dict( + visible_devices=[1, 0], + # If enable patch, this should be fixed, and cuda:1 becomes cuda:0 + tensor_device=0 if enable_patch else 1, + ), + ), + ]: + with self.subTest(f"{enable_patch=} {params=}"): + self._test_monkey_patch_torch_reductions_core( + enable_patch=enable_patch, **params + ) + + def _test_monkey_patch_torch_reductions_core( + self, + sender_info: Dict, + receiver_info: Dict, + enable_patch: bool, + ): + print( + f'test_monkey_patch_torch_reductions_core {os.environ.get("CUDA_VISIBLE_DEVICES")=}' + ) + cuda_visible_devices_list: List[int] = [ + int(x) + for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split( + "," + ) + ] + + processes = [] + output_reader, output_writer = mp.Pipe(duplex=False) + queue = mp.Queue() + for role, info in [ + ("sender", sender_info), + ("receiver", receiver_info), + ]: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + str(cuda_visible_devices_list[device]) + for device in info["visible_devices"] + ) + p = mp.Process( + target=_run_subprocess, + kwargs=dict( + role=role, + queue=queue, + output_writer=output_writer, + tensor_device=info["tensor_device"], + enable_patch=enable_patch, + ), + ) + p.start() + processes.append(p) + + for _ in range(len(processes)): + self.assertTrue( + output_reader.recv(), f"Subprocess has error, please see logs above." + ) + + for p in processes: + p.join() + + +def _run_subprocess( + role: str, queue: mp.Queue, output_writer, tensor_device: int, enable_patch: bool +): + print( + f'subprocess[{role}] start {os.environ.get("CUDA_VISIBLE_DEVICES")=}', + flush=True, + ) + + if enable_patch: + print(f"subprocess[{role}] execute monkey_patch_torch_reductions", flush=True) + monkey_patch_torch_reductions() + + try: + if role == "sender": + tensor = torch.tensor([1.0, 2.0], device=f"cuda:{tensor_device}") + print(f"sender queue.put {tensor=} {tensor.device=}") + queue.put(tensor) + assert queue.get() == "done" + elif role == "receiver": + tensor = queue.get() + print(f"receiver queue.get {tensor=} {tensor.device=}") + assert str(tensor.device) == f"cuda:{tensor_device}" + queue.put("done") + else: + raise NotImplementedError + + execution_ok = True + except Exception as e: + print(f"subprocess[{role}] has error: {e}", flush=True) + traceback.print_exc() + execution_ok = False + + output_writer.send(execution_ok) + output_writer.close() + + +if __name__ == "__main__": + unittest.main()