Patch a PyTorch bug where cross-process tensor transfer puts the tensor on the wrong device (#4565)
This commit is contained in:
@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
|
||||
)
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
@@ -1082,8 +1083,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
|
||||
|
||||
def _unwrap_tensor(tensor, tp_rank):
    """Return ``tensor`` as a tensor placed on the current CUDA device.

    Args:
        tensor: Either a regular tensor or a ``LocalSerializedTensor``
            wrapper.
        tp_rank: Rank passed to ``LocalSerializedTensor.get`` to pick the
            payload for this worker (presumably the tensor-parallel rank;
            confirm against the caller).

    Returns:
        The unwrapped tensor after ``.to(torch.cuda.current_device())``.
    """
    if isinstance(tensor, LocalSerializedTensor):
        # Apply the torch reductions patch before unwrapping. Per the commit
        # title, this works around a PyTorch bug where a tensor transferred
        # across processes can end up on the wrong device.
        monkey_patch_torch_reductions()
        tensor = tensor.get(tp_rank)
    # Both branches end here so the result is always moved to the device
    # that is current in this process.
    return tensor.to(torch.cuda.current_device())
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user