[Feature] SPMD for SGLang + Verl (#3852)

This commit is contained in:
fzyzcjy
2025-03-01 01:53:10 +08:00
committed by GitHub
parent bac414ab53
commit e3e0bc50a9
19 changed files with 890 additions and 202 deletions

View File

@@ -121,7 +121,7 @@ class DataParallelController:
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
)
threads.append(thread)
base_gpu_id += server_args.tp_size
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
# Free all sockets before starting the threads to launch TP workers
for sock in sockets:
@@ -177,7 +177,11 @@ class DataParallelController:
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),

View File

@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request payload for updating model weights from in-memory tensors."""

    # Serialized name->tensor mapping; the receiving worker decodes it with
    # MultiprocessingSerializer.deserialize before loading.
    serialized_named_tensors: bytes  # indeed Dict[str, torch.Tensor]
    # Weight-loading format forwarded to the model runner; presumably None
    # selects the runner's default — confirm against model_runner.
    load_format: Optional[str]
    # Whether the scheduler should flush the cache after a successful update.
    flush_cache: bool
@dataclass

View File

@@ -1760,8 +1760,9 @@ class Scheduler:
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
if recv_req.flush_cache:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromTensorReqOutput(success, message)

View File

@@ -205,7 +205,10 @@ class TpModelWorker:
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Load the tensors carried by *recv_req* into the model.

    The stripped diff rendering left both the old positional call and the new
    keyword-argument call in place, which is not valid Python; this keeps the
    post-change form, which also forwards the new ``load_format`` field.

    Returns:
        ``(success, message)`` as reported by the model runner.
    """
    # Tensors arrive pre-serialized (see UpdateWeightsFromTensorReqInput);
    # deserialize them in this process before handing off to the runner.
    success, message = self.model_runner.update_weights_from_tensor(
        named_tensors=MultiprocessingSerializer.deserialize(
            recv_req.serialized_named_tensors
        ),
        load_format=recv_req.load_format,
    )
    return success, message