Online weight updates from torch.distributed (#2279)

This commit is contained in:
Chayenne
2024-12-01 23:23:18 -08:00
committed by GitHub
parent 28bc60dcab
commit 983bfcf386
12 changed files with 1120 additions and 61 deletions

View File

@@ -365,6 +365,41 @@ class UpdateWeightFromDiskReqOutput:
message: str
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to receive one parameter broadcast over the weight-update process group."""

    # Parameter name as expected by model.load_weights.
    name: str
    # Torch dtype name of the incoming tensor (resolved via getattr(torch, dtype)).
    dtype: str
    # Shape of the incoming tensor.
    shape: List[int]
@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of an online (distributed) weight update."""

    # Whether the update succeeded.
    success: bool
    # Human-readable status or error detail.
    message: str
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to create the custom process group used for online weight updates."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset; each TP rank joins the group at rank_offset + tp_rank
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of initializing the weight-update process group."""

    # Whether group creation succeeded.
    success: bool
    # Human-readable status or error detail.
    message: str
@dataclass
class GetWeightsByNameReqInput:
name: str

View File

@@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -516,6 +520,19 @@ class Scheduler:
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
success, message = self.init_weights_update_group(recv_req)
self.send_to_tokenizer.send_pyobj(
InitWeightsUpdateGroupReqOutput(success, message)
)
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
success, message = self.update_weights_from_distributed(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightsFromDistributedReqOutput(success, message)
)
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -1378,6 +1395,23 @@ class Scheduler:
logger.error(message)
return success, message
def init_weights_update_group(self, recv_req: "InitWeightsUpdateGroupReqInput"):
    """Initialize the online model parameter update group.

    Delegates to the TP worker and returns its (success, message) result.
    """
    return self.tp_worker.init_weights_update_group(recv_req)
def update_weights_from_distributed(
    self, recv_req: "UpdateWeightsFromDistributedReqInput"
):
    """Update the online model parameter.

    On success the radix cache must be flushed, because cached KV entries
    were produced by the old weights.
    """
    success, message = self.tp_worker.update_weights_from_distributed(recv_req)
    if not success:
        logger.error(message)
        return success, message
    flush_cache_success = self.flush_cache()
    assert flush_cache_success, "Cache flush failed after updating weights"
    return success, message
def get_weights_by_name(self, recv_req: "GetWeightsByNameReqInput"):
    """Fetch a named parameter from the TP worker and return it."""
    return self.tp_worker.get_weights_by_name(recv_req)

View File

@@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -55,6 +57,8 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -456,6 +460,48 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."
async def init_weights_update_group(
    self,
    obj: InitWeightsUpdateGroupReqInput,
    request: Optional[fastapi.Request] = None,
) -> "Tuple[bool, str]":
    """Ask the scheduler to create the weight-update process group.

    Returns:
        (success, message) from the scheduler's reply.
    """
    if self.to_create_loop:
        self.create_handle_loop()
    # Validate before sending anything to the scheduler: the reply path
    # below only supports a single data-parallel replica.
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for init parameter update group"
    # Create the future before sending so the reply handler can never
    # observe a missing/stale future if it runs first.
    self.init_weights_update_group_result = asyncio.Future()
    self.send_to_scheduler.send_pyobj(obj)
    result = await self.init_weights_update_group_result
    return result.success, result.message
async def update_weights_from_distributed(
    self,
    obj: UpdateWeightsFromDistributedReqInput,
    request: Optional[fastapi.Request] = None,
) -> "Tuple[bool, str]":
    """Forward one distributed weight update to the scheduler.

    Only one update may be in flight at a time; concurrent calls are
    rejected rather than queued.

    Returns:
        (success, message) describing the outcome.
    """
    if self.to_create_loop:
        self.create_handle_loop()
    if self.model_update_lock.locked():
        # NOTE: best-effort rejection; a concurrent caller may still race
        # between this check and acquiring the lock below.
        logger.error("Another parameter update is in progress in tokenizer manager")
        return (
            False,
            "Another parameter update is in progress. Please try again later.",
        )
    async with self.model_update_lock:
        # Validate before sending; the reply path assumes one DP replica.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from distributed"
        # Create the future before sending so the reply handler can never
        # observe a missing/stale future if it runs first.
        self.parameter_update_result = asyncio.Future()
        self.send_to_scheduler.send_pyobj(obj)
        result = await self.parameter_update_result
        return result.success, result.message
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
@@ -546,7 +592,9 @@ class TokenizerManager:
BatchEmbeddingOut,
BatchTokenIDOut,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqOutput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
@@ -558,6 +606,12 @@ class TokenizerManager:
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
@@ -568,6 +622,12 @@ class TokenizerManager:
self.get_weights_by_name_tmp
)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id

View File

@@ -21,7 +21,9 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -164,6 +166,25 @@ class TpModelWorker:
)
return success, message
def init_weights_update_group(self, recv_req: "InitWeightsUpdateGroupReqInput"):
    """Unpack the request and forward process-group creation to the model runner."""
    return self.model_runner.init_weights_update_group(
        recv_req.master_address,
        recv_req.master_port,
        recv_req.rank_offset,
        recv_req.world_size,
        recv_req.group_name,
        recv_req.backend,
    )
def update_weights_from_distributed(
    self, recv_req: "UpdateWeightsFromDistributedReqInput"
):
    """Unpack the request and forward the online weight update to the model runner."""
    return self.model_runner.update_weights_from_distributed(
        recv_req.name, recv_req.dtype, recv_req.shape
    )
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size

View File

@@ -25,7 +25,9 @@ import torch
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
@@ -211,6 +213,16 @@ class TpModelWorkerClient:
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def init_weights_update_group(self, recv_req: "InitWeightsUpdateGroupReqInput"):
    """Proxy the group-init request to the wrapped synchronous worker."""
    return self.worker.init_weights_update_group(recv_req)
def update_weights_from_distributed(
    self, recv_req: "UpdateWeightsFromDistributedReqInput"
):
    """Proxy the online weight-update request to the wrapped synchronous worker."""
    return self.worker.update_weights_from_distributed(recv_req)
def get_weights_by_name(self, recv_req: "GetWeightsByNameReqInput"):
    """Proxy a parameter lookup to the wrapped synchronous worker."""
    parameter = self.worker.get_weights_by_name(recv_req)
    return parameter

View File

@@ -20,10 +20,13 @@ import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from typing import Optional, Type
from tokenize import tabsize
from typing import Any, Optional, Type, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
@@ -59,6 +62,7 @@ from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
is_hip,
monkey_patch_vllm_gguf_config,
monkey_patch_vllm_model_config,
@@ -404,6 +408,86 @@ class ModelRunner:
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
def init_weights_update_group(
    self,
    master_address,
    master_port,
    rank_offset,
    world_size,
    group_name,
    backend="nccl",
):
    """Initialize the Torch process group for model parameter updates.

    `_model_update_group` is used in the RLHF workflow, where rank
    0 is the actor model in the training engine, and the other ranks are
    the inference engine, which is used for rollout.

    In the RLHF workflow, the training engine updates the model
    weights/parameters online, and broadcasts them to the inference
    engine through the `_model_update_group` process group.
    """
    # The default process group must already exist; the custom group is an
    # additional group layered on top of it.
    assert (
        torch.distributed.is_initialized()
    ), "Default torch process group must be initialized"
    assert group_name != "", "Group name cannot be empty"
    # This TP rank's position inside the new group: rank 0 is reserved for
    # the training engine, so inference ranks start at rank_offset.
    rank = rank_offset + self.tp_rank
    logger.info(
        f"init custom process group: master_address={master_address}, master_port={master_port}, "
        f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
    )
    try:
        self._model_update_group = init_custom_process_group(
            backend=backend,
            init_method=f"tcp://{master_address}:{master_port}",
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )
        # Synchronize all members before reporting success.
        # NOTE(review): device_ids receives the *group* rank here, but
        # torch's barrier expects a local device index — confirm these
        # coincide in the deployments this supports.
        dist.barrier(group=self._model_update_group, device_ids=[rank])
        return True, "Succeeded to initialize custom process group."
    except Exception as e:
        # Report failure to the caller instead of raising, so the scheduler
        # can return an error response.
        message = f"Failed to initialize custom process group: {e}."
        logger.error(message)
        return False, message
def update_weights_from_distributed(self, name, dtype, shape):
    """
    Update specific parameter in the model weights online
    through `_model_update_group` process group.

    Args:
        name: the name of the parameter to be updated.
        dtype: the data type of the parameter to be updated; either a
            torch.dtype or its string name (e.g. "bfloat16").
        shape: the shape of the parameter to be updated.

    Returns:
        (success, message) tuple describing the outcome.
    """
    # Accept either a torch.dtype or its string name.
    # (Removed a dead `current_dtype = self.dtype if ... else self.dtype`
    # assignment whose two branches were identical and whose result was
    # never used.)
    target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)

    # Use getattr so a missing attribute fails the assert with a clear
    # message instead of raising AttributeError.
    assert (
        getattr(self, "_model_update_group", None) is not None
    ), "model update group must be initialized"

    try:
        # Rank 0 (the training engine) broadcasts the new parameter value.
        weights = torch.empty(shape, dtype=target_dtype, device=self.device)
        torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
        self.model.load_weights([(name, weights)])
        return True, f"Succeeded to update parameter {name} online."
    except Exception as e:
        error_msg = (
            f"Failed to update parameter online: {e}. "
            f"The full weights of the ModelRunner are partially updated. "
            f"Please discard the whole weights."
        )
        logger.error(error_msg)
        return False, error_msg
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:

View File

@@ -307,6 +307,8 @@ class LlamaForCausalLM(nn.Module):
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
# Llama 3.2 1B Instruct sets tie_word_embeddings to True
# Llama 3.1 8B Instruct sets tie_word_embeddings to False
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:

View File

@@ -53,8 +53,10 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +82,7 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
delete_directory,
init_custom_process_group,
is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
@@ -211,6 +214,34 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
)
@app.post("/init_weights_update_group")
async def init_weights_update_group(
    obj: InitWeightsUpdateGroupReqInput, request: Request
):
    """Initialize the parameter update group."""
    success, message = await tokenizer_manager.init_weights_update_group(obj, request)
    content = {"success": success, "message": message}
    status = 200 if success else HTTPStatus.BAD_REQUEST
    return ORJSONResponse(content, status_code=status)
@app.post("/update_weights_from_distributed")
async def update_weights_from_distributed(
    obj: UpdateWeightsFromDistributedReqInput, request: Request
):
    """Update model parameter from distributed online."""
    success, message = await tokenizer_manager.update_weights_from_distributed(
        obj, request
    )
    content = {"success": success, "message": message}
    status = 200 if success else HTTPStatus.BAD_REQUEST
    return ORJSONResponse(content, status_code=status)
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name."""
@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request):
)
@time_func_latency
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
    """Handle a get parameter by name request."""
    try:
        ret = await tokenizer_manager.get_weights_by_name(obj, request)
        return ret
    except ValueError as e:
        # Surface lookup errors (e.g. unknown parameter name) as HTTP 400
        # instead of an unhandled 500.
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -970,7 +989,51 @@ class Engine:
async def get_server_info(self):
return await _get_server_info()
def init_weights_update_group(
    self,
    master_address: str,
    master_port: int,
    rank_offset: int,
    world_size: int,
    group_name: str,
    backend: str = "nccl",
):
    """Initialize parameter update group.

    Blocks the caller on the event loop until the scheduler reports the
    result; returns (success, message).
    """
    # NOTE: this region previously contained stray leftover lines from an
    # older `get_weights_by_name` (its header and a `return
    # loop.run_until_complete(get_weights_by_name_request(obj, None))`)
    # interleaved into this method; they have been removed.
    obj = InitWeightsUpdateGroupReqInput(
        master_address=master_address,
        master_port=master_port,
        rank_offset=rank_offset,
        world_size=world_size,
        group_name=group_name,
        backend=backend,
    )

    async def _init_group():
        return await tokenizer_manager.init_weights_update_group(obj, None)

    loop = asyncio.get_event_loop()
    return loop.run_until_complete(_init_group())
def update_weights_from_distributed(self, name, dtype, shape):
    """Update weights from distributed source."""
    req = UpdateWeightsFromDistributedReqInput(name=name, dtype=dtype, shape=shape)

    async def _run():
        return await tokenizer_manager.update_weights_from_distributed(req, None)

    return asyncio.get_event_loop().run_until_complete(_run())
def get_weights_by_name(self, name, truncate_size=100):
    """Get weights by parameter name."""
    req = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

    async def _run():
        return await tokenizer_manager.get_weights_by_name(req, None)

    return asyncio.get_event_loop().run_until_complete(_run())

View File

@@ -39,6 +39,7 @@ import numpy as np
import psutil
import requests
import torch
import torch.distributed
import torch.distributed as dist
import triton
import zmq
@@ -962,6 +963,78 @@ def get_nvgpu_memory_capacity():
)
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    """Create an additional "main" torch.distributed process group.

    Copied from PyTorch and OpenRLHF to allow creating multiple main groups:
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
    https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py

    Args mirror torch.distributed.init_process_group; exactly one of
    `init_method` and `store` may be given. Returns the new process group.
    """
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # Compare (major, minor) numerically: the previous string comparison
    # `str(torch.__version__) >= "2.6"` misclassifies versions such as
    # "2.10" (lexicographically smaller than "2.6").
    version_digits = str(torch.__version__).split("+", 1)[0].split(".")
    try:
        torch_major_minor = (int(version_digits[0]), int(version_digits[1]))
    except (IndexError, ValueError):
        # Unparseable version string: fall back to the legacy kwarg name.
        torch_major_minor = (0, 0)
    pg_options_param_name = (
        "backend_options" if torch_major_minor >= (2, 6) else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    # Register the group so torch's collectives can map global <-> group ranks.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
def crash_on_warnings():
    """Whether warnings should crash the process (set when running CI tests)."""
    return get_bool_env_var("SGLANG_IS_IN_CI")