From 983bfcf386861812aeaf1f0495371549a94b01c1 Mon Sep 17 00:00:00 2001 From: Chayenne Date: Sun, 1 Dec 2024 23:23:18 -0800 Subject: [PATCH] Online weight updates from torch.distributed (#2279) --- .github/workflows/pr-test.yml | 7 + python/sglang/srt/managers/io_struct.py | 35 + python/sglang/srt/managers/scheduler.py | 34 + .../sglang/srt/managers/tokenizer_manager.py | 60 ++ python/sglang/srt/managers/tp_worker.py | 21 + .../srt/managers/tp_worker_overlap_thread.py | 12 + .../sglang/srt/model_executor/model_runner.py | 86 ++- python/sglang/srt/models/llama.py | 2 + python/sglang/srt/server.py | 93 ++- python/sglang/srt/utils.py | 73 +++ test/srt/test_get_weights_by_name.py | 144 ++-- .../test_update_weights_from_distributed.py | 614 ++++++++++++++++++ 12 files changed, 1120 insertions(+), 61 deletions(-) create mode 100644 test/srt/test_update_weights_from_distributed.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index b10d62eec..59f0006e1 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -27,6 +27,7 @@ concurrency: cancel-in-progress: true jobs: + unit-test-frontend: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 1-gpu-runner @@ -98,6 +99,11 @@ jobs: python3 test_mla_fp8.py python3 test_dp_attention.py + - name: Test update weights from distributed + timeout-minutes: 10 + run: | + cd test/srt + python3 test_update_weights_from_distributed.py performance-test-1-gpu-part-1: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -245,6 +251,7 @@ jobs: cd test/srt python3 test_moe_eval_accuracy_large.py + finish: needs: [ unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 058e930ed..27bf5a4bd 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -365,6 +365,41 @@ class UpdateWeightFromDiskReqOutput: message: str +@dataclass +class UpdateWeightsFromDistributedReqInput: + name: str + dtype: str + shape: List[int] + + +@dataclass +class UpdateWeightsFromDistributedReqOutput: + success: bool + message: str + + +@dataclass +class InitWeightsUpdateGroupReqInput: + # The master address + master_address: str + # The master port + master_port: int + # The rank offset + rank_offset: int + # The world size + world_size: int + # The group name + group_name: str = "weight_update_group" + # The backend + backend: str = "nccl" + + +@dataclass +class InitWeightsUpdateGroupReqOutput: + success: bool + message: str + + @dataclass class GetWeightsByNameReqInput: name: str diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 41895e067..16e7691c1 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import ( FlushCacheReq, GetWeightsByNameReqInput, GetWeightsByNameReqOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, @@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import ( TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromDistributedReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -516,6 +520,19 @@ class Scheduler: elif isinstance(recv_req, 
GetWeightsByNameReqInput):
            parameter = self.get_weights_by_name(recv_req)
            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+        elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
+            success, message = self.init_weights_update_group(recv_req)
+            self.send_to_tokenizer.send_pyobj(
+                InitWeightsUpdateGroupReqOutput(success, message)
+            )
+        elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
+            success, message = self.update_weights_from_distributed(recv_req)
+            self.send_to_tokenizer.send_pyobj(
+                UpdateWeightsFromDistributedReqOutput(success, message)
+            )
+        elif isinstance(recv_req, GetWeightsByNameReqInput):
+            parameter = self.get_weights_by_name(recv_req)
+            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -1378,6 +1395,23 @@
             logger.error(message)
         return success, message

+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        """Initialize the online model parameter update group."""
+        success, message = self.tp_worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        """Update a model parameter online."""
+        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
+        if success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 630c5ec42..3ba5f210b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -48,6 +48,8 @@
 from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -55,6 +57,8 @@
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -456,6 +460,48 @@
         else:
             return False, "Another update is in progress. Please try again later."
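+
+    # A hedged client-side sketch of the two entry points added below (the
+    # URL, port, and payload values are illustrative; the field names mirror
+    # InitWeightsUpdateGroupReqInput / UpdateWeightsFromDistributedReqInput
+    # in io_struct.py):
+    #
+    #   import requests
+    #   requests.post("http://localhost:30000/init_weights_update_group",
+    #                 json={"master_address": "localhost", "master_port": 29500,
+    #                       "rank_offset": 1, "world_size": 2,
+    #                       "group_name": "weight_update_group",
+    #                       "backend": "nccl"})
+    #   requests.post("http://localhost:30000/update_weights_from_distributed",
+    #                 json={"name": "model.embed_tokens.weight",
+    #                       "dtype": "bfloat16",
+    #                       "shape": [128256, 2048]})  # shape is illustrative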
+
+    async def init_weights_update_group(
+        self,
+        obj: InitWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+        self.send_to_scheduler.send_pyobj(obj)
+
+        self.init_weights_update_group_result = asyncio.Future()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for init parameter update group"
+        result = await self.init_weights_update_group_result
+        return result.success, result.message
+
+    async def update_weights_from_distributed(
+        self,
+        obj: UpdateWeightsFromDistributedReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                self.send_to_scheduler.send_pyobj(obj)
+                self.parameter_update_result = asyncio.Future()
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                result = await self.parameter_update_result
+                return result.success, result.message
+        else:
+            logger.error(
+                "Another parameter update is in progress in tokenizer manager"
+            )
+            return (
+                False,
+                "Another parameter update is in progress. Please try again later.",
+            )
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -546,7 +592,9 @@
             BatchEmbeddingOut,
             BatchTokenIDOut,
             UpdateWeightFromDiskReqOutput,
+            UpdateWeightsFromDistributedReqOutput,
             GetWeightsByNameReqOutput,
+            InitWeightsUpdateGroupReqOutput,
         ] = await self.recv_from_detokenizer.recv_pyobj()

         if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
@@ -558,6 +606,12 @@
                 if len(self.model_update_tmp) == self.server_args.dp_size:
                     self.model_update_result.set_result(self.model_update_tmp)
                 continue
+            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.parameter_update_result.set_result(recv_obj)
+                continue
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 if self.server_args.dp_size == 1:
                     self.get_weights_by_name_result.set_result(recv_obj)
@@ -568,6 +622,12 @@
                         self.get_weights_by_name_tmp
                     )
                 continue
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
+                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index d79498c77..43d82c1a0 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -21,7 +21,9 @@
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -164,6 +166,25 @@
         )
         return success, message

+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.name, recv_req.dtype, recv_req.shape
+        )
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.model_runner.get_weights_by_name(
             recv_req.name, recv_req.truncate_size
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 1b0be30df..e4e20ad8f 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -25,7 +25,9 @@
 import torch

 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
@@ -211,6 +213,16 @@
         success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message

+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.worker.update_weights_from_distributed(recv_req)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e8ea6a163..5c4f5c81b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -20,10 +20,12 @@
 import inspect
 import json
 import logging
 import pkgutil
+import time
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Any, Optional, Type, Union

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
@@ -59,6 +62,7 @@
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    init_custom_process_group,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_model_config,
@@ -404,6 +408,86 @@
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

+    def init_weights_update_group(
+        self,
+        master_address,
+        master_port,
+        rank_offset,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        """Initialize the Torch process group for model parameter updates.
+
+        `_model_update_group` is used in the RLHF workflow, where rank 0
+        hosts the actor model of the training engine and the remaining
+        ranks host the inference engines used for rollout.
+
+        In the RLHF workflow, the training engine updates the model
+        weights/parameters online and broadcasts them to the inference
+        engine through the `_model_update_group` process group.
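+
+        A hedged usage sketch (the address, port, and world size are
+        illustrative, not fixed by this API): a trainer with one rank and
+        an inference engine with tp_size workers form a group of
+        1 + tp_size ranks. The trainer side calls
+        init_custom_process_group(backend="nccl",
+        init_method="tcp://localhost:29500", world_size=1 + tp_size,
+        rank=0, group_name="weight_update_group"), and each TP worker
+        joins here with rank = rank_offset + self.tp_rank.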
+ """ + assert ( + torch.distributed.is_initialized() + ), "Default torch process group must be initialized" + assert group_name != "", "Group name cannot be empty" + + rank = rank_offset + self.tp_rank + + logger.info( + f"init custom process group: master_address={master_address}, master_port={master_port}, " + f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}" + ) + + try: + self._model_update_group = init_custom_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + dist.barrier(group=self._model_update_group, device_ids=[rank]) + return True, "Succeeded to initialize custom process group." + except Exception as e: + message = f"Failed to initialize custom process group: {e}." + logger.error(message) + return False, message + + def update_weights_from_distributed(self, name, dtype, shape): + """ + Update specific parameter in the model weights online + through `_model_update_group` process group. + + Args: + name: the name of the parameter to be updated. + dtype: the data type of the parameter to be updated. + shape: the shape of the parameter to be updated. + """ + target_dtype = ( + dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + ) + current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype + + assert ( + self._model_update_group is not None + ), "model update group must be initialized" + + try: + weights = torch.empty(shape, dtype=target_dtype, device=self.device) + torch.distributed.broadcast(weights, src=0, group=self._model_update_group) + self.model.load_weights([(name, weights)]) + return True, f"Succeeded to update parameter {name} online." + + except Exception as e: + error_msg = ( + f"Failed to update parameter online: {e}. " + f"The full weights of the ModelRunner are partially updated. " + f"Please discard the whole weights." 
+            )
+            logger.error(error_msg)
+            return False, error_msg
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 68809c9c2..62ad0d2a0 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -307,6 +307,8 @@
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
+        # Llama 3.2 1B Instruct sets tie_word_embeddings to True;
+        # Llama 3.1 8B Instruct sets it to False.
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 71755654c..a750d90e2 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -53,8 +53,10 @@
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
     UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +82,7 @@
     assert_pkg_version,
     configure_logger,
     delete_directory,
+    init_custom_process_group,
     is_port_available,
     kill_process_tree,
     maybe_set_triton_cache_manager,
@@ -211,6 +214,34 @@
     )


+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+    obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+    """Initialize the parameter update group."""
+    success, message = await tokenizer_manager.init_weights_update_group(obj, request)
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+    obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update a model parameter online from the distributed group."""
+    success, message = await tokenizer_manager.update_weights_from_distributed(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
 @app.api_route("/get_weights_by_name", methods=["GET", "POST"])
 async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     """Get model parameter by name."""
@@ -288,18 +319,6 @@
     )


-@time_func_latency
-async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
-    """Handle a get parameter by name request."""
-    try:
-        ret = await tokenizer_manager.get_weights_by_name(obj, request)
-        return ret
-    except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
 @app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -970,7 +989,51 @@
     async def get_server_info(self):
         return await _get_server_info()

-    def get_weights_by_name(self, name, truncate_size=100):
-        obj =
GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) + def init_weights_update_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + ): + """Initialize parameter update group.""" + obj = InitWeightsUpdateGroupReqInput( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + ) + + async def _init_group(): + return await tokenizer_manager.init_weights_update_group(obj, None) + loop = asyncio.get_event_loop() - return loop.run_until_complete(get_weights_by_name_request(obj, None)) + return loop.run_until_complete(_init_group()) + + def update_weights_from_distributed(self, name, dtype, shape): + """Update weights from distributed source.""" + obj = UpdateWeightsFromDistributedReqInput( + name=name, + dtype=dtype, + shape=shape, + ) + + async def _update_weights(): + return await tokenizer_manager.update_weights_from_distributed(obj, None) + + loop = asyncio.get_event_loop() + return loop.run_until_complete(_update_weights()) + + def get_weights_by_name(self, name, truncate_size=100): + """Get weights by parameter name.""" + obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) + + async def _get_weights(): + return await tokenizer_manager.get_weights_by_name(obj, None) + + loop = asyncio.get_event_loop() + return loop.run_until_complete(_get_weights()) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 19ea78015..fef4c58a5 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -39,6 +39,7 @@ import numpy as np import psutil import requests import torch +import torch.distributed import torch.distributed as dist import triton import zmq @@ -962,6 +963,78 @@ def get_nvgpu_memory_capacity(): ) +# Copy from pytorch and OpenRLHF to allow creating multiple main groups. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py +# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py +def init_custom_process_group( + backend=None, + init_method=None, + timeout=None, + world_size=-1, + rank=-1, + store=None, + group_name=None, + pg_options=None, +): + from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, + ) + + assert (store is None) or ( + init_method is None + ), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. 
+ store = PrefixStore(group_name, store) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = ( + "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + ) + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg + + def crash_on_warnings(): # Crash on warning if we are running CI tests return get_bool_env_var("SGLANG_IS_IN_CI") diff --git a/test/srt/test_get_weights_by_name.py b/test/srt/test_get_weights_by_name.py index 6579646f4..1494483c7 100644 --- a/test/srt/test_get_weights_by_name.py +++ b/test/srt/test_get_weights_by_name.py @@ -8,47 +8,46 @@ from transformers import AutoModelForCausalLM import sglang as sgl from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + is_in_ci, popen_launch_server, ) from sglang.utils import terminate_process +def _process_return(ret): + if isinstance(ret, list) and len(ret) == 2: + print(f"running assert_allclose on data parallel") + np.testing.assert_allclose(ret[0], ret[1]) + return np.array(ret[0]) + return np.array(ret) + + class TestGetWeightsByName(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.hf_model = AutoModelForCausalLM.from_pretrained( - cls.model, torch_dtype="bfloat16" + + def init_hf_model(self, model_name, tie_word_embeddings): + self.hf_model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings ).to("cuda:0") - @classmethod - def tearDownClass(cls): - del cls.hf_model - gc.collect() - torch.cuda.empty_cache() - - def init_backend(self, backend, dp, tp): - self.engine = None - self.process = None + def init_backend(self, backend, dp, tp, model_name): self.backend = backend self.dp = dp self.tp = tp if backend == "Engine": self.engine = sgl.Engine( - model_path=self.model, + model_path=model_name, random_seed=42, - tp_size=self.tp, - dp_size=self.dp, - mem_fraction_static=0.85, + tp_size=tp, + dp_size=dp, ) else: self.process = popen_launch_server( - self.model, - self.base_url, + model_name, + DEFAULT_URL_FOR_TEST, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=( "--tp-size", @@ -58,12 +57,50 @@ class TestGetWeightsByName(unittest.TestCase): ), ) - def close_engine_and_server(self): - if self.engine: + def clean_up(self): + del self.hf_model + gc.collect() + torch.cuda.empty_cache() + if self.backend == "Engine": self.engine.shutdown() - if self.process: + else: terminate_process(self.process) + def assert_tie_word_embeddings(self, truncate_size): + print(f"assert_tie_word_embeddings") + if self.backend == "Engine": + backend_ret = _process_return( + self.engine.get_weights_by_name("lm_head.weight", truncate_size) + ) + else: + backend_ret = _process_return( + requests.get( + f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name", + json={"name": "lm_head.weight", "truncate_size": truncate_size}, + ).json() + ) + print(f"assert_tie_word_embeddings of hf and backend") + assert np.allclose( + 
self.hf_model.get_parameter("model.embed_tokens.weight") + .cpu() + .detach() + .float() + .numpy()[:truncate_size], + backend_ret, + ) + assert np.allclose( + self.hf_model.get_parameter("lm_head.weight") + .cpu() + .detach() + .float() + .numpy()[:truncate_size], + self.hf_model.get_parameter("model.embed_tokens.weight") + .cpu() + .detach() + .float() + .numpy()[:truncate_size], + ) + def assert_weights_all_close(self, param_name, truncate_size): print( f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}" @@ -73,34 +110,38 @@ class TestGetWeightsByName(unittest.TestCase): if self.backend == "Engine": engine_ret = self.engine.get_weights_by_name(param_name, truncate_size) - engine_ret = self._process_return(engine_ret) + engine_ret = _process_return(engine_ret) np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5) if self.backend == "Runtime": runtime_ret = requests.get( - f"{self.base_url}/get_weights_by_name", + f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name", json={"name": param_name, "truncate_size": truncate_size}, ).json() - runtime_ret = self._process_return(runtime_ret) + runtime_ret = _process_return(runtime_ret) np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5) - @staticmethod - def _process_return(ret): - if isinstance(ret, list) and len(ret) == 2: - print("running assert_allclose on data parallel") - np.testing.assert_allclose(ret[0], ret[1]) - return np.array(ret[0]) - return np.array(ret) + def test_get_weights_by_name(self): + if is_in_ci(): + test_suits = [ + ("Engine", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), + ] + else: + test_suits = [ + ("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), + ("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST), + ] + if torch.cuda.device_count() >= 2: + test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST)) + test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST)) - def test_get_parameters_by_name(self): - test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)] - - if torch.cuda.device_count() >= 2: - test_suits.append(("Engine", 1, 2)) - test_suits.append(("Runtime", 2, 1)) - - if torch.cuda.device_count() >= 4: - test_suits.extend([("Engine", 2, 2), ("Runtime", 2, 2)]) + if torch.cuda.device_count() >= 4: + test_suits.extend( + [ + ("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), + ("Runtime", 2, 2, DEFAULT_MODEL_NAME_FOR_TEST), + ] + ) parameters = [ "model.embed_tokens.weight", @@ -117,11 +158,24 @@ class TestGetWeightsByName(unittest.TestCase): "lm_head.weight", ] + truncate_size = 100 + for test_suit in test_suits: + if test_suit[-1] == DEFAULT_MODEL_NAME_FOR_TEST: + tie_word_embeddings = False + else: + tie_word_embeddings = True + + self.init_hf_model(test_suit[-1], tie_word_embeddings) self.init_backend(*test_suit) + for param_name in parameters: - self.assert_weights_all_close(param_name, 100) - self.close_engine_and_server() + self.assert_weights_all_close(param_name, truncate_size) + + if tie_word_embeddings: + self.assert_tie_word_embeddings(truncate_size) + + self.clean_up() if __name__ == "__main__": diff --git a/test/srt/test_update_weights_from_distributed.py b/test/srt/test_update_weights_from_distributed.py new file mode 100644 index 000000000..a4fe17813 --- /dev/null +++ b/test/srt/test_update_weights_from_distributed.py @@ -0,0 +1,614 @@ +"""Test distributed weight updates. + +This test suite simulates a distributed training environment to ensure +correct weight synchronization. 
On rank 0, the instruct model represents +pre-training weights, and the base model represents post-training weights. +The base model's weights are broadcasted to other ranks using the online +weight update API. + +On other ranks, an engine is initialized with the instruct model, and its +parameters are verified against the Hugging Face model. After updating +weights from the distributed system, post-training weights are loaded +and verified again to ensure consistency and accuracy across the +distributed setup. +""" + +import gc +import os +import time +import unittest + +import numpy as np +import requests +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from transformers import AutoModelForCausalLM + +import sglang as sgl +from sglang.srt.utils import init_custom_process_group +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + is_in_ci, + popen_launch_server, +) +from sglang.utils import terminate_process + +mp.set_start_method("spawn", force=True) + + +def verify_params_close(params1, params2, error_msg): + """Verify if two parameter arrays are close enough.""" + try: + assert np.allclose(np.array(params1), np.array(params2)), error_msg + except Exception as e: + print(f"Parameters not close for {error_msg}") + print("Params1:", np.array(params1)) + print("Params2:", np.array(params2)) + raise e + + +def verify_params_not_close(params1, params2, error_msg): + """Verify if two parameter arrays are different enough.""" + assert not np.allclose(np.array(params1), np.array(params2)), error_msg + + +def init_process( + rank, + world_size, + param_queue, + truncate_size, + state_dict_key_to_shape, + tp_size, + model_name, + backend, + checking_parameters, + tie_word_embeddings, +): + torch.cuda.set_device(rank) + + if rank == 0: + init_process_hf( + rank, + world_size, + param_queue, + truncate_size, + model_name, + checking_parameters, + tie_word_embeddings, + state_dict_key_to_shape, + ) + elif rank in [1, 2]: + init_process_sgl( + rank, + world_size, + param_queue, + truncate_size, + model_name, + checking_parameters, + tie_word_embeddings, + state_dict_key_to_shape, + backend, + tp_size, + ) + + +def init_process_hf( + rank, + world_size, + param_queue, + truncate_size, + model_name, + checking_parameters, + tie_word_embeddings, + state_dict_key_to_shape, +): + # These two environment variables are very important + # to avoid unexpected behaviors of CUDA and NCCL. 
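+    # (Hedged rationale: NCCL_CUMEM_ENABLE=0 steers NCCL away from the
+    # cuMem-based allocator, which has been reported to misbehave when
+    # several independently created communicators share a device, and
+    # NCCL_NVLS_ENABLE=0 disables NVLink SHARP, which can hang in the same
+    # multi-communicator setup.)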
+ os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + + # Load model and get parameters + hf_instruct_model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="bfloat16", + tie_word_embeddings=tie_word_embeddings, + ).to("cuda:0") + base_model_name = model_name.replace("-Instruct", "") + hf_base_model = AutoModelForCausalLM.from_pretrained( + base_model_name, + torch_dtype="bfloat16", + tie_word_embeddings=tie_word_embeddings, + ).to("cuda:0") + + hf_instruct_params = [] + hf_base_params = [] + + print(f"get parameter in hf instruct model and base model") + for parameter_name in checking_parameters: + hf_instruct_params.append( + hf_instruct_model.get_parameter(parameter_name)[:truncate_size] + .cpu() + .detach() + .float() + .numpy() + .tolist() + ) + hf_base_params.append( + hf_base_model.get_parameter(parameter_name)[:truncate_size] + .cpu() + .detach() + .float() + .numpy() + .tolist() + ) + + param_queue.put(("hf_instruct_params", hf_instruct_params)) + param_queue.put(("hf_base_params", hf_base_params)) + + # Init weight update group for rank 0 (the training engine in RLHF). + print(f"rank {rank} world_size: {world_size} init custom process group") + group = init_custom_process_group( + backend="nccl", + init_method="tcp://localhost:65500", + world_size=world_size, + rank=rank, + group_name="test_parameter_update_group", + ) + dist.barrier(group=group, device_ids=[rank]) + torch.cuda.synchronize() + time_begin_broadcast = time.time() + + # The last parameter is lm_head.weight, which is tied + # with embed_tokens.weight. Actually, we only need + # to broadcast embed_tokens.weight once. + broadcast_parameters = list(state_dict_key_to_shape.keys()) + if tie_word_embeddings: + broadcast_parameters.remove("lm_head.weight") + + # Broadcast all the weights from the training + # engine to other ranks (inference engine). + for parameter_name in broadcast_parameters: + torch.distributed.broadcast( + hf_base_model.get_parameter(parameter_name), + src=0, + group=group, + ) + torch.cuda.synchronize() + time_end_broadcast = time.time() + + # Measure the latency of broadcasting/weights update. + broadcast_time = time_end_broadcast - time_begin_broadcast + print(f"rank {rank} broadcast parameter time: {broadcast_time:.3f}s") + param_queue.put(("broadcast_time", broadcast_time)) + + # Delete the huggingface models to free up memory. + + del hf_instruct_model + del hf_base_model + gc.collect() + torch.cuda.empty_cache() + + +def init_process_sgl( + rank, + world_size, + param_queue, + truncate_size, + model_name, + checking_parameters, + tie_word_embeddings, + state_dict_key_to_shape, + backend, + tp_size, +): + torch.cuda.set_device(rank) + torch.cuda.synchronize() + base_gpu_id = 1 if rank == 1 else 1 + tp_size + if backend == "Engine": + engine = sgl.Engine( + model_path=model_name, + random_seed=42, + base_gpu_id=base_gpu_id, + tp_size=tp_size, + ) + else: + if rank == 1: + url = DEFAULT_URL_FOR_TEST + else: + url = DEFAULT_URL_FOR_TEST.replace("2157", "2159") + process = popen_launch_server( + model_name, + url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--base-gpu-id", + str(base_gpu_id), + "--tp-size", + str(tp_size), + ), + ) + torch.cuda.synchronize() + if backend == "Engine": + print(f"rank {rank} init engine") + else: + print(f"rank {rank} init server on url: {url}") + + # Get weights of instruct model, i.e. pre-training weights. 
+ + instruct_params = [] + for parameter_name in checking_parameters: + instruct_params.append( + engine.get_weights_by_name(parameter_name, truncate_size) + if backend == "Engine" + else requests.get( + f"{url}/get_weights_by_name", + json={"name": parameter_name, "truncate_size": truncate_size}, + ).json() + ) + + param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params)) + + # Init weight update group with the training engine. + + if backend == "Engine": + engine.init_weights_update_group( + master_address="localhost", + master_port="65500", + rank_offset=base_gpu_id, + world_size=world_size, + group_name="test_parameter_update_group", + backend="nccl", + ) + else: + requests.post( + f"{url}/init_weights_update_group", + json={ + "master_address": "localhost", + "master_port": "65500", + "rank_offset": base_gpu_id, + "world_size": world_size, + "group_name": "test_parameter_update_group", + "backend": "nccl", + }, + ) + + torch.cuda.synchronize() + time_begin_update = time.time() + + # The last parameter is lm_head.weight, which is tied + # with embed_tokens.weight. Actually, we only need + # to update embed_tokens.weight once. + + tie_word_embeddings = ( + True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False + ) + update_parameters = list(state_dict_key_to_shape.keys()) + if tie_word_embeddings: + update_parameters.remove("lm_head.weight") + + # Get weights from the training engine and update the inference engine. + + for parameter_name in update_parameters: + if backend == "Engine": + engine.update_weights_from_distributed( + parameter_name, + dtype=torch.bfloat16, + shape=state_dict_key_to_shape[parameter_name], + ) + else: + requests.post( + f"{url}/update_weights_from_distributed", + json={ + "name": parameter_name, + "dtype": "bfloat16", + "shape": state_dict_key_to_shape[parameter_name], + }, + ) + torch.cuda.synchronize() + time_end_update = time.time() + + # Measure the latency of broadcast/weights update. + + update_time = time_end_update - time_begin_update + print( + f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s" + ) + param_queue.put((f"update_sgl_dp_{rank}_time", update_time)) + + # Get the weights of post-training model after weights update for correctness check. + + base_params = [] + for parameter_name in checking_parameters: + if backend == "Engine": + base_params.append( + engine.get_weights_by_name(parameter_name, truncate_size) + ) + else: + base_params.append( + requests.get( + f"{url}/get_weights_by_name", + json={ + "name": parameter_name, + "truncate_size": truncate_size, + }, + ).json() + ) + param_queue.put((f"sgl_dp_{rank}_base_params", base_params)) + + # Shutdown the engine or terminate the server process. 
+ + if backend == "Engine": + engine.shutdown() + else: + terminate_process(process) + + +def assert_tied_weights(params_list, message, should_be_tied): + for params in params_list: + if should_be_tied: + assert np.allclose(params[0], params[-1]), message + else: + assert not np.allclose(params[0], params[-1]), message + + +def test_update_weights_from_distributed( + tp_size, + dp_size, + model_name, + backend, + state_dict_key_to_shape, + truncate_size, + checking_parameters, +): + tie_word_embeddings = ( + True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False + ) + + print( + f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backend}" + ) + param_queue = mp.Queue() + results = {} + + context = mp.spawn( + init_process, + args=( + 1 + tp_size * dp_size, + param_queue, + truncate_size, + state_dict_key_to_shape, + tp_size, + model_name, + backend, + checking_parameters, + tie_word_embeddings, + ), + nprocs=1 + dp_size, + join=False, + ) + + while len(results) < 3 * (1 + dp_size): + try: + key, value = param_queue.get(timeout=5) + results[key] = value + except Exception as e: + if all(not p.is_alive() for p in context.processes): + break + + context.join() + + if len(results) != 3 * (1 + dp_size): + raise RuntimeError( + f"Expected {3 * (1 + dp_size)} parameters but got {len(results)}" + ) + + params = { + "hf_instruct": results.get("hf_instruct_params"), + "hf_base": results.get("hf_base_params"), + "sgl_dp_1_instruct": results.get("sgl_dp_1_instruct_params"), + "sgl_dp_1_base": results.get("sgl_dp_1_base_params"), + "broadcast_time": results.get("broadcast_time"), + "update_sgl_dp_1_time": results.get("update_sgl_dp_1_time"), + } + + if dp_size == 2: + dp2_params = { + "sgl_dp_2_instruct": results.get("sgl_dp_2_instruct_params"), + "sgl_dp_2_base": results.get("sgl_dp_2_base_params"), + "update_sgl_dp_2_time": results.get("update_sgl_dp_2_time"), + } + assert all(v is not None for v in dp2_params.values()) + params.update(dp2_params) + + # Check the correctness of weights update by verifying + # the weights of instruct model and base model. + + for i in range(len(params["hf_instruct"])): + verify_params_close( + params["hf_instruct"][i], + params["sgl_dp_1_instruct"][i], + f"sgl_dp_1_instruct_params rank {i}", + ) + + verify_params_close( + params["hf_base"][i], + params["sgl_dp_1_base"][i], + f"sgl_dp_1_base_params rank {i}", + ) + + verify_params_not_close( + params["hf_instruct"][i], + params["hf_base"][i], + f"hf_instruct_params rank {i}", + ) + + if dp_size == 2: + verify_params_close( + params["hf_base"][i], + params["sgl_dp_2_base"][i], + f"sgl_dp_2_base_params rank {i}", + ) + verify_params_close( + params["hf_instruct"][i], + params["sgl_dp_2_instruct"][i], + f"sgl_dp_2_instruct_params rank {i}", + ) + + assert len(params["hf_instruct"]) == len( + params["hf_base"] + ), "hf_instruct_params and hf_base_params have different lengths" + + # Check if the weights of lm_head are tied with embed_tokens. 
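+    # In checking_parameters, embed_tokens.weight is the first entry and
+    # lm_head.weight the last, so assert_tied_weights compares params[0]
+    # against params[-1]: they must match when tie_word_embeddings is True
+    # and differ when it is False.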
+
+    params_to_check = [
+        params["hf_instruct"],
+        params["hf_base"],
+        params["sgl_dp_1_instruct"],
+        params["sgl_dp_1_base"],
+    ]
+
+    if dp_size == 2:
+        params_to_check.extend(
+            [
+                params["sgl_dp_2_instruct"],
+                params["sgl_dp_2_base"],
+            ]
+        )
+
+    assert_tied_weights(
+        params_to_check,
+        (
+            "lm_head.weight is not tied with embed_tokens.weight"
+            if tie_word_embeddings
+            else "lm_head.weight is tied with embed_tokens.weight"
+        ),
+        tie_word_embeddings,
+    )
+
+    # Time limit (seconds) for both broadcast and update: 3 for the small
+    # model and 6 for the large one on CI; roughly 1 and 2 on a local H100.
+    time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
+
+    assert (
+        params["broadcast_time"] < time_limit
+    ), f"broadcast_time exceeds time limit {time_limit}s"
+
+    assert (
+        params["update_sgl_dp_1_time"] < time_limit
+    ), f"update_sgl_dp_1_time exceeds time limit {time_limit}s"
+
+    if dp_size == 2:
+        assert (
+            params["update_sgl_dp_2_time"] < time_limit
+        ), f"update_sgl_dp_2_time exceeds time limit {time_limit}s"
+
+    # Delete the context and close the parameter queue.
+    del context
+    param_queue.close()
+    param_queue.join_thread()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+class TestUpdateWeightsFromDistributed(unittest.TestCase):
+
+    def test_update_weights_from_distributed(self):
+        assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
+        # test_suits: (tp, dp, model_name, backend)
+        if is_in_ci():
+            test_suits = [
+                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
+            ]
+        else:
+            test_suits = [
+                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
+                (1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
+            ]
+
+        if torch.cuda.device_count() >= 4:
+            test_suits.extend(
+                [
+                    (2, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
+                    (1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
+                ]
+            )
+
+        if torch.cuda.device_count() >= 5:
+            test_suits.extend(
+                [
+                    (2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
+                    (2, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
+                ]
+            )
+
+        model_state_dict_shapes = {}
+        test_models = [test_suit[2] for test_suit in test_suits]
+
+        for model_name in test_models:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name, torch_dtype="bfloat16"
+            ).to("cuda:0")
+            state_dict = model.state_dict()
+            state_dict_keys = list(state_dict.keys())
+            model_state_dict_shapes[model_name] = {
+                key: state_dict[key].shape for key in state_dict_keys
+            }
+            del model
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        truncate_size = 10
+        checking_parameters = [
+            "model.embed_tokens.weight",
+            "model.layers.0.input_layernorm.weight",
+            "model.layers.1.self_attn.q_proj.weight",
+            "model.layers.2.self_attn.k_proj.weight",
+            "model.layers.3.self_attn.v_proj.weight",
+            "model.layers.4.self_attn.o_proj.weight",
+            "model.layers.5.mlp.gate_proj.weight",
+            "model.layers.6.mlp.up_proj.weight",
+            "model.layers.7.mlp.down_proj.weight",
+            "model.layers.8.post_attention_layernorm.weight",
+            "model.norm.weight",
+            "lm_head.weight",
+        ]
+
+        for tp_size, dp_size, model_name, backend in test_suits:
+            test_update_weights_from_distributed(
+                tp_size,
+                dp_size,
+                model_name,
+                backend,
model_state_dict_shapes[model_name], + truncate_size, + checking_parameters, + ) + + +if __name__ == "__main__": + unittest.main()
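
For reviewers, a minimal end-to-end sketch of the Engine API this patch adds
(hedged: the model path, port, and tensor shape are illustrative, and the
training-side broadcast is shown only as a comment; in a real RLHF setup the
process at group rank 0 performs it):

    import sglang as sgl

    engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct", base_gpu_id=1)
    engine.init_weights_update_group(
        master_address="localhost",
        master_port=29500,
        rank_offset=1,
        world_size=2,
        group_name="weight_update_group",
        backend="nccl",
    )
    # The training process (group rank 0) must broadcast the same tensor:
    #   torch.distributed.broadcast(new_weight, src=0, group=group)
    engine.update_weights_from_distributed(
        "model.embed_tokens.weight", dtype="bfloat16", shape=[128256, 2048]
    )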