[Feature] SPMD for SGLang + Verl (#3852)

This commit is contained in:
fzyzcjy
2025-03-01 01:53:10 +08:00
committed by GitHub
parent bac414ab53
commit e3e0bc50a9
19 changed files with 890 additions and 202 deletions

View File

@@ -17,7 +17,8 @@ import gc
import json
import logging
import time
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
@@ -514,8 +517,21 @@ class ModelRunner:
logger.error(error_msg)
return False, error_msg
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
self.model.load_weights(named_tensors)
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
load_format: Optional[str] = None,
):
named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
for name, tensor in named_tensors
]
if load_format == "direct":
_model_load_weights_direct(self.model, named_tensors)
elif load_format is None:
self.model.load_weights(named_tensors)
else:
raise NotImplementedError(f"Unknown load_format={load_format}")
return True, "Success"
def get_weights_by_name(
@@ -836,3 +852,26 @@ class ModelRunner:
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
for name, tensor in named_tensors:
default_weight_loader(params_dict[name], tensor)
def _unwrap_tensor(tensor, tp_rank):
if isinstance(tensor, LocalSerializedTensor):
return tensor.get(tp_rank)
return tensor
@dataclass
class LocalSerializedTensor:
"""torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
The i-th element in the list corresponds to i-th rank's GPU."""
values: List[bytes]
def get(self, rank: int):
return MultiprocessingSerializer.deserialize(self.values[rank])