Online weight updates from torch.distributed (#2279)
This commit is contained in:
@@ -20,10 +20,13 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Type
|
||||
from tokenize import tabsize
|
||||
from typing import Any, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
@@ -59,6 +62,7 @@ from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
is_hip,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
monkey_patch_vllm_model_config,
|
||||
@@ -404,6 +408,86 @@ class ModelRunner:
|
||||
logger.info("Update weights end.")
|
||||
return True, "Succeeded to update model weights."
|
||||
|
||||
def init_weights_update_group(
|
||||
self,
|
||||
master_address,
|
||||
master_port,
|
||||
rank_offset,
|
||||
world_size,
|
||||
group_name,
|
||||
backend="nccl",
|
||||
):
|
||||
"""Initialize the Torch process group for model parameter updates.
|
||||
|
||||
`_model_update_group` is used in the RLHF workflow, where rank
|
||||
0 is the actor model in the training engine, and the other ranks are
|
||||
the inference engine, which is used for rollout.
|
||||
|
||||
In the RLHF workflow, the training engine updates the model
|
||||
weights/parameters online, and broadcasts them to the inference
|
||||
engine through the `_model_update_group` process group.
|
||||
"""
|
||||
assert (
|
||||
torch.distributed.is_initialized()
|
||||
), "Default torch process group must be initialized"
|
||||
assert group_name != "", "Group name cannot be empty"
|
||||
|
||||
rank = rank_offset + self.tp_rank
|
||||
|
||||
logger.info(
|
||||
f"init custom process group: master_address={master_address}, master_port={master_port}, "
|
||||
f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
|
||||
)
|
||||
|
||||
try:
|
||||
self._model_update_group = init_custom_process_group(
|
||||
backend=backend,
|
||||
init_method=f"tcp://{master_address}:{master_port}",
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
group_name=group_name,
|
||||
)
|
||||
dist.barrier(group=self._model_update_group, device_ids=[rank])
|
||||
return True, "Succeeded to initialize custom process group."
|
||||
except Exception as e:
|
||||
message = f"Failed to initialize custom process group: {e}."
|
||||
logger.error(message)
|
||||
return False, message
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""
|
||||
Update specific parameter in the model weights online
|
||||
through `_model_update_group` process group.
|
||||
|
||||
Args:
|
||||
name: the name of the parameter to be updated.
|
||||
dtype: the data type of the parameter to be updated.
|
||||
shape: the shape of the parameter to be updated.
|
||||
"""
|
||||
target_dtype = (
|
||||
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||
)
|
||||
current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
|
||||
|
||||
assert (
|
||||
self._model_update_group is not None
|
||||
), "model update group must be initialized"
|
||||
|
||||
try:
|
||||
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
|
||||
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
|
||||
self.model.load_weights([(name, weights)])
|
||||
return True, f"Succeeded to update parameter {name} online."
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Failed to update parameter online: {e}. "
|
||||
f"The full weights of the ModelRunner are partially updated. "
|
||||
f"Please discard the whole weights."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user