diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index d0939ffca..e6d3c9a24 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server_args import ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed @@ -278,6 +279,8 @@ class TpModelWorker: return success, message def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + + monkey_patch_torch_reductions() success, message = self.model_runner.update_weights_from_tensor( named_tensors=MultiprocessingSerializer.deserialize( recv_req.serialized_named_tensors[self.tp_rank] diff --git a/python/sglang/srt/weight_sync/utils.py b/python/sglang/srt/weight_sync/utils.py new file mode 100644 index 000000000..edb7f6ea0 --- /dev/null +++ b/python/sglang/srt/weight_sync/utils.py @@ -0,0 +1,119 @@ +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput +from sglang.srt.model_executor.model_runner import LocalSerializedTensor +from sglang.srt.utils import MultiprocessingSerializer + + +async def update_weights( + engine: Engine, + params_batch: list[tuple[str, torch.Tensor]], + device_mesh_key: str, + device_mesh: DeviceMesh, + load_format: Optional[str] = None, +): + """ + Update weights for the inference engine. + This function is designed to be stateless, so that the caller process could keep the stateful engine. + Example Use Case: + - Multiple Producer Process will call this function in a SPMD style + + Args: + engine: The inference engine created by the caller process. + params_batch: A list of (name, tensor) tuples. We batched the tensors to avoid the overhead of cpu call. + device_mesh_key: The key of the device mesh. Typically "tp" or "infer_tp" + device_mesh: The device mesh. + load_format: The format of the weights. + """ + infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0] + infer_tp_rank = device_mesh[device_mesh_key].get_local_rank() + from sglang.srt.patch_torch import monkey_patch_torch_reductions + + monkey_patch_torch_reductions() + + # [ + # (name0, ipc_tensor0_tp0), + # (name1, ipc_tensor1_tp0), + # ] + named_tensors_batch = [ + ( + name, + MultiprocessingSerializer.serialize( + _preprocess_tensor_for_update_weights(tensor) + ), + ) + for name, tensor in params_batch + ] + + if infer_tp_rank == 0: + gathered_serialized_batches = [None for _ in range(infer_tp_size)] + else: + gathered_serialized_batches = None + + # [ + # [ (name0, ipc_tensor0_tp0), (name1, ipc_tensor1_tp0) ], + # [ (name0, ipc_tensor0_tp1), (name1, ipc_tensor1_tp1) ], + # ] + dist.gather_object( + obj=named_tensors_batch, + object_gather_list=gathered_serialized_batches, + dst=device_mesh[device_mesh_key].mesh.tolist()[0], + group=device_mesh[device_mesh_key].get_group(), + ) + + if infer_tp_rank == 0: + # Use zip(*) to "transpose" the data structure. + # After transpose, the data structure is like: + # [ + # ( (name0, ipc_tensor0_tp0), (name0, ipc_tensor0_tp1) ), + # ( (name1, ipc_tensor1_tp0), (name1, ipc_tensor1_tp1) ), + # ] + logical_tensors = zip(*gathered_serialized_batches, strict=True) + + named_tensors = [ + # [ + # (name0, LocalSerializedTensor(values=[ipc_tensor0_tp0, ipc_tensor0_tp1])), + # (name1, LocalSerializedTensor(values=[ipc_tensor1_tp0, ipc_tensor1_tp1])), + # ] + ( + tensor_group[0][0], + LocalSerializedTensor( + values=[rank_part[1] for rank_part in tensor_group] + ), + ) + for tensor_group in logical_tensors + ] + + update_weights_request = UpdateWeightsFromTensorReqInput( + serialized_named_tensors=[ + MultiprocessingSerializer.serialize(named_tensors) + for _ in range(infer_tp_size) + ], + load_format=load_format, + ) + + return await engine.update_weights_from_tensor(update_weights_request) + + +def _preprocess_tensor_for_update_weights(tensor: torch.Tensor): + """ + Preprocess the tensor for update weights. + Example Use Case: + - FSDP: we gather tensor by calling full_tensor in _preprocess_tensor_for_update_weights + - Megatron: we do nothing here, assuming it is gathered when feed into this func + + Args: + tensor: The tensor to be preprocessed. + + Returns: + The full tensor if it is a DTensor, otherwise the original tensor. + """ + if isinstance(tensor, DTensor): + return tensor.full_tensor() + return tensor diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 18dcd004f..19ff9d560 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -101,6 +101,7 @@ suites = { TestFile("test_triton_sliding_window.py", 250), TestFile("test_update_weights_from_disk.py", 114), TestFile("test_update_weights_from_tensor.py", 48), + TestFile("test_utils_update_weights.py", 48), TestFile("test_vertex_endpoint.py", 31), TestFile("test_vision_chunked_prefill.py", 175), TestFile("test_vlm_input_format.py", 300), diff --git a/test/srt/test_utils_update_weights.py b/test/srt/test_utils_update_weights.py new file mode 100644 index 000000000..afbef6d38 --- /dev/null +++ b/test/srt/test_utils_update_weights.py @@ -0,0 +1,173 @@ +import asyncio +import os + +import pytest +import torch +import torch.distributed as dist +from loguru import logger +from torch.distributed.device_mesh import init_device_mesh +from transformers import AutoModelForCausalLM + +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.weight_sync.utils import update_weights +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + +class AsyncEngine(Engine): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def update_weights_from_tensor(self, update_weights_request): + return await self.tokenizer_manager.update_weights_from_tensor( + update_weights_request, None + ) + + +def is_distributed_available(): + """Check if distributed training environment is available""" + required_vars = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"] + return all(var in os.environ for var in required_vars) + + +def setup_single_process_distributed(): + """Setup distributed environment for single process testing""" + if not is_distributed_available(): + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12356" + os.environ["LOCAL_RANK"] = "0" + + +class TestUtilsUpdateWeights: + """Test class for utils.update_weights function""" + + @pytest.fixture(scope="class") + def setup_distributed(self): + """Setup distributed environment for testing""" + setup_single_process_distributed() + + if not dist.is_initialized(): + try: + dist.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo" + ) + except Exception as e: + pytest.skip(f"Could not initialize distributed backend: {e}") + + rank = dist.get_rank() + world_size = dist.get_world_size() + + if torch.cuda.is_available(): + torch.cuda.set_device(rank % torch.cuda.device_count()) + + # Set up environment variables + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + os.environ["CUDA_MODULE_LOADING"] = "AUTO" + + yield rank, world_size + + # Cleanup + if dist.is_initialized(): + dist.destroy_process_group() + + @pytest.fixture(scope="class") + def test_engine(self, setup_distributed): + """Setup test engine""" + rank, world_size = setup_distributed + + if rank == 0: + os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" + engine = AsyncEngine( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + dtype="bfloat16", + mem_fraction_static=0.3, + enable_memory_saver=True, + tp_size=world_size, + disable_cuda_graph=True, + ) + yield engine + engine.shutdown() + + else: + yield None + + @pytest.fixture(scope="class") + def test_model(self): + """Load test model""" + try: + model = AutoModelForCausalLM.from_pretrained( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + device_map="cpu", + trust_remote_code=True, + low_cpu_mem_usage=True, + torch_dtype=( + torch.float16 if torch.cuda.is_available() else torch.float32 + ), + ) + return model + except Exception as e: + pytest.skip(f"Could not load test model: {e}") + + @pytest.fixture(scope="class") + def device_mesh(self, setup_distributed): + """Create device mesh for testing""" + rank, world_size = setup_distributed + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available for device mesh") + + device_mesh_key = "tp" + mesh = init_device_mesh( + "cuda", (world_size,), mesh_dim_names=(device_mesh_key,) + ) + + return device_mesh_key, mesh + + def create_test_params_batch(self, model, num_params=64): + """Create a batch of test parameters from the model""" + param_names = [] + test_tensors = [] + + # Get first few parameters from the model for testing + for i, (name, tensor) in enumerate(model.named_parameters()): + if i >= num_params: + break + param_names.append(name) + # Create test tensor with known values, matching original shape and dtype + test_tensor = torch.full_like(tensor, 1.5, dtype=tensor.dtype).cuda() + test_tensors.append(test_tensor) + + return list(zip(param_names, test_tensors)) + + @pytest.mark.asyncio + async def test_utils_update_weights( + self, setup_distributed, test_engine, test_model, device_mesh + ): + """Test basic functionality of utils.update_weights""" + rank, world_size = setup_distributed + device_mesh_key, mesh = device_mesh + + # Create test parameters batch + params_batch = self.create_test_params_batch(test_model, num_params=2) + + print( + f"Rank {rank} testing utils.update_weights with {len(params_batch)} parameters" + ) + # Test the utils.update_weights function + result = await update_weights( + engine=test_engine, + params_batch=params_batch, + device_mesh_key=device_mesh_key, + device_mesh=mesh, + load_format=None, + ) + + assert "Success" in result + + +if __name__ == "__main__": + pytest.main([__file__])