[Feature] SPMD for SGLang + Verl (#3852)
This commit is contained in:
@@ -17,7 +17,8 @@ import gc
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
@@ -514,8 +517,21 @@ class ModelRunner:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def update_weights_from_tensor(
    self,
    named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
    load_format: Optional[str] = None,
):
    """Update model weights in place from in-memory tensors.

    Args:
        named_tensors: ``(name, tensor)`` pairs. A tensor may be a plain
            ``torch.Tensor`` or a ``LocalSerializedTensor`` that holds one
            serialized tensor per TP rank; the latter is resolved to this
            rank's local tensor before loading.
        load_format: ``None`` to route through the model's own
            ``load_weights``; ``"direct"`` to write each tensor straight
            into the matching named parameter via ``default_weight_loader``.

    Returns:
        Tuple[bool, str]: success flag and a status message, matching the
        ``(ok, message)`` contract of the other weight-update entry points.

    Raises:
        NotImplementedError: if ``load_format`` is neither ``None`` nor
            ``"direct"``.
    """
    # Resolve any per-rank serialized tensors to this TP rank's tensor.
    named_tensors = [
        (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
        for name, tensor in named_tensors
    ]
    if load_format == "direct":
        _model_load_weights_direct(self.model, named_tensors)
    elif load_format is None:
        self.model.load_weights(named_tensors)
    else:
        raise NotImplementedError(f"Unknown load_format={load_format}")
    return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
@@ -836,3 +852,26 @@ class ModelRunner:
|
||||
if rope_scaling is None:
|
||||
return False
|
||||
return rope_scaling.get("type", None) == "mrope"
|
||||
|
||||
|
||||
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    """Write tensors directly into the model's parameters by name.

    Bypasses the model's custom ``load_weights`` logic and applies
    ``default_weight_loader`` to each parameter looked up from
    ``model.named_parameters()``.

    Args:
        model: the model whose parameters are overwritten in place.
        named_tensors: ``(parameter_name, tensor)`` pairs to load.

    Raises:
        KeyError: if a tensor name has no matching model parameter.
    """
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        # Fail with a descriptive message instead of a bare KeyError so a
        # name mismatch is easy to diagnose.
        if name not in params_dict:
            raise KeyError(
                f"Parameter {name!r} not found in model; cannot load weight directly."
            )
        default_weight_loader(params_dict[name], tensor)
|
||||
|
||||
|
||||
def _unwrap_tensor(tensor, tp_rank):
    """Return the local tensor for the given TP rank.

    A ``LocalSerializedTensor`` carries one serialized payload per rank and
    is resolved through its ``get``; any other value passes through
    unchanged.
    """
    if not isinstance(tensor, LocalSerializedTensor):
        return tensor
    return tensor.get(tp_rank)
|
||||
|
||||
|
||||
@dataclass
class LocalSerializedTensor:
    """Per-rank handle for a tensor serialized by MultiprocessingSerializer.

    The serializer keeps only a pointer, not the tensor data; the i-th
    element of ``values`` is the serialized form belonging to the i-th
    rank's GPU.
    """

    # One serialized payload per rank; index by rank to recover the tensor.
    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the tensor belonging to *rank*."""
        payload = self.values[rank]
        return MultiprocessingSerializer.deserialize(payload)
|
||||
|
||||
Reference in New Issue
Block a user