Online weight updates from torch.distributed (#2279)

This commit is contained in:
Chayenne
2024-12-01 23:23:18 -08:00
committed by GitHub
parent 28bc60dcab
commit 983bfcf386
12 changed files with 1120 additions and 61 deletions

View File

@@ -20,10 +20,13 @@ import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from typing import Optional, Type
from tokenize import tabsize
from typing import Any, Optional, Type, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
@@ -59,6 +62,7 @@ from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
is_hip,
monkey_patch_vllm_gguf_config,
monkey_patch_vllm_model_config,
@@ -404,6 +408,86 @@ class ModelRunner:
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
def init_weights_update_group(
    self,
    master_address,
    master_port,
    rank_offset,
    world_size,
    group_name,
    backend="nccl",
):
    """Set up the custom Torch process group used for online weight updates.

    In the RLHF workflow, group rank 0 is the actor model inside the
    training engine; the remaining ranks are inference (rollout) engines.
    The training engine broadcasts updated parameters to the inference
    engines over the `_model_update_group` created here.

    Args:
        master_address: host for the TCP rendezvous of the new group.
        master_port: port for the TCP rendezvous of the new group.
        rank_offset: added to this worker's TP rank to get its rank in
            the new group.
        world_size: total number of participants in the new group.
        group_name: non-empty identifier for the new group.
        backend: torch.distributed backend to use (default "nccl").

    Returns:
        Tuple of (success flag, human-readable message).
    """
    # Guard clauses: the default process group must already exist, and the
    # group needs a usable name.
    assert (
        torch.distributed.is_initialized()
    ), "Default torch process group must be initialized"
    assert group_name != "", "Group name cannot be empty"

    group_rank = rank_offset + self.tp_rank
    logger.info(
        f"init custom process group: master_address={master_address}, master_port={master_port}, "
        f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
    )

    try:
        self._model_update_group = init_custom_process_group(
            backend=backend,
            init_method=f"tcp://{master_address}:{master_port}",
            world_size=world_size,
            rank=group_rank,
            group_name=group_name,
        )
        # NOTE(review): `device_ids` expects a local CUDA device index;
        # passing the group rank assumes rank == local device, which looks
        # wrong when rank_offset > 0 — confirm against the deployment setup.
        dist.barrier(group=self._model_update_group, device_ids=[group_rank])
        return True, "Succeeded to initialize custom process group."
    except Exception as exc:
        message = f"Failed to initialize custom process group: {exc}."
        logger.error(message)
        return False, message
def update_weights_from_distributed(self, name, dtype, shape):
    """Update one model parameter online via the `_model_update_group`.

    The source rank (rank 0 of the update group, i.e. the training
    engine) broadcasts the new parameter value; this worker receives it
    into a freshly allocated buffer and loads it into the model.

    Args:
        name: name of the parameter to update.
        dtype: target dtype — either a `torch.dtype` or its string name
            (e.g. "bfloat16", resolved via `getattr(torch, dtype)`).
        shape: shape of the parameter tensor.

    Returns:
        Tuple of (success flag, message). On failure the model weights
        may be left partially updated and should be discarded.
    """
    # Accept both torch.dtype objects and string dtype names.
    # (Removed a dead no-op line that assigned self.dtype to itself via a
    # ternary whose branches were identical and whose result was unused.)
    target_dtype = (
        dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
    )

    assert (
        self._model_update_group is not None
    ), "model update group must be initialized"

    try:
        weights = torch.empty(shape, dtype=target_dtype, device=self.device)
        # Rank 0 of the update group (the training engine) is the source.
        torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
        self.model.load_weights([(name, weights)])
        return True, f"Succeeded to update parameter {name} online."
    except Exception as e:
        error_msg = (
            f"Failed to update parameter online: {e}. "
            f"The full weights of the ModelRunner are partially updated. "
            f"Please discard the whole weights."
        )
        logger.error(error_msg)
        return False, error_msg
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]: