Patch PyTorch's bug that cross-process tensor transfer will lead to wrong device (#4565)
This commit is contained in:
71
python/sglang/srt/patch_torch.py
Normal file
71
python/sglang/srt/patch_torch.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import reductions
|
||||
|
||||
|
||||
def monkey_patch_torch_reductions():
|
||||
"""Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
|
||||
|
||||
if hasattr(reductions, "_reduce_tensor_original"):
|
||||
return
|
||||
|
||||
reductions._reduce_tensor_original = reductions.reduce_tensor
|
||||
reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor
|
||||
|
||||
reductions.reduce_tensor = _reduce_tensor_modified
|
||||
reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified
|
||||
|
||||
reductions.init_reductions()
|
||||
|
||||
|
||||
# The signature has not been changed for years, and we will not need this when the next version is released,
|
||||
# so it looks safe to use a constant.
|
||||
_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6
|
||||
|
||||
|
||||
def _reduce_tensor_modified(*args, **kwargs):
|
||||
output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs)
|
||||
output_args = _modify_tuple(
|
||||
output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid
|
||||
)
|
||||
return output_fn, output_args
|
||||
|
||||
|
||||
def _rebuild_cuda_tensor_modified(*args):
|
||||
args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid)
|
||||
return reductions._rebuild_cuda_tensor_original(*args)
|
||||
|
||||
|
||||
def _device_to_uuid(device: int) -> str:
|
||||
return str(torch.cuda.get_device_properties(device).uuid)
|
||||
|
||||
|
||||
def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
|
||||
if isinstance(device_maybe_uuid, int):
|
||||
return device_maybe_uuid
|
||||
|
||||
if isinstance(device_maybe_uuid, str):
|
||||
for device in range(torch.cuda.device_count()):
|
||||
if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid:
|
||||
return device
|
||||
raise Exception("Invalid device_uuid=" + device_maybe_uuid)
|
||||
|
||||
raise Exception(f"Unknown type: {device_maybe_uuid=}")
|
||||
|
||||
|
||||
def _modify_tuple(t, index: int, modifier: Callable):
|
||||
return *t[:index], modifier(t[index]), *t[index + 1 :]
|
||||
Reference in New Issue
Block a user