Patch a PyTorch bug where cross-process tensor transfer places tensors on the wrong device (#4565)

This commit is contained in:
fzyzcjy
2025-03-27 15:22:33 +08:00
committed by GitHub
parent 6f5cc5eb05
commit 92bb49a7f9
5 changed files with 211 additions and 2 deletions

View File

@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -1082,8 +1083,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
def _unwrap_tensor(tensor, tp_rank):
    """Resolve a possibly-serialized tensor and place it on the current CUDA device.

    Args:
        tensor: Either a plain ``torch.Tensor`` or a ``LocalSerializedTensor``
            wrapper produced by a cross-process transfer.
        tp_rank: Tensor-parallel rank used to select this process's shard
            when unwrapping a serialized tensor.

    Returns:
        A ``torch.Tensor`` moved to ``torch.cuda.current_device()``.
    """
    if isinstance(tensor, LocalSerializedTensor):
        # Apply the torch multiprocessing-reductions patch before rebuilding:
        # without it, tensors deserialized from another process can be
        # reconstructed on the wrong device (see the commit this hunk is from).
        monkey_patch_torch_reductions()
        # NOTE(review): presumably .get(tp_rank) returns this rank's shard of
        # the serialized tensor — confirm against LocalSerializedTensor.
        tensor = tensor.get(tp_rank)
    # Explicitly pin the result to this process's current CUDA device so a
    # tensor rebuilt with a stale device index ends up in the right place.
    return tensor.to(torch.cuda.current_device())
@dataclass