[Feature] SPMD for SGLang + Verl (#3852)
This commit is contained in:
@@ -121,7 +121,7 @@ class DataParallelController:
|
||||
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
|
||||
)
|
||||
threads.append(thread)
|
||||
base_gpu_id += server_args.tp_size
|
||||
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
|
||||
|
||||
# Free all sockets before starting the threads to launch TP workers
|
||||
for sock in sockets:
|
||||
@@ -177,7 +177,11 @@ class DataParallelController:
|
||||
rank_port_args.nccl_port = port_args.nccl_port
|
||||
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
|
||||
gpu_id = (
|
||||
server_args.base_gpu_id
|
||||
+ base_gpu_id
|
||||
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||
)
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
|
||||
|
||||
@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
|
||||
@dataclass
|
||||
class UpdateWeightsFromTensorReqInput:
|
||||
serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
|
||||
load_format: Optional[str]
|
||||
flush_cache: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1760,8 +1760,9 @@ class Scheduler:
|
||||
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
|
||||
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
|
||||
if success:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
if recv_req.flush_cache:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||
|
||||
@@ -205,7 +205,10 @@ class TpModelWorker:
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
success, message = self.model_runner.update_weights_from_tensor(
|
||||
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
|
||||
named_tensors=MultiprocessingSerializer.deserialize(
|
||||
recv_req.serialized_named_tensors
|
||||
),
|
||||
load_format=recv_req.load_format,
|
||||
)
|
||||
return success, message
|
||||
|
||||
|
||||
Reference in New Issue
Block a user