From 4ca43b061c24b8ba85d1f85bed140f1bf10c0dc2 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Sat, 2 Aug 2025 00:41:05 -0700 Subject: [PATCH] Add tensor.detach() back to update weight util (#8691) --- python/sglang/srt/weight_sync/utils.py | 2 +- test/srt/test_utils_update_weights.py | 126 ++++++++++++------------- 2 files changed, 63 insertions(+), 65 deletions(-) diff --git a/python/sglang/srt/weight_sync/utils.py b/python/sglang/srt/weight_sync/utils.py index edb7f6ea0..8f3c8adb7 100644 --- a/python/sglang/srt/weight_sync/utils.py +++ b/python/sglang/srt/weight_sync/utils.py @@ -45,7 +45,7 @@ async def update_weights( ( name, MultiprocessingSerializer.serialize( - _preprocess_tensor_for_update_weights(tensor) + _preprocess_tensor_for_update_weights(tensor.detach()) ), ) for name, tensor in params_batch diff --git a/test/srt/test_utils_update_weights.py b/test/srt/test_utils_update_weights.py index afbef6d38..03262f10a 100644 --- a/test/srt/test_utils_update_weights.py +++ b/test/srt/test_utils_update_weights.py @@ -1,10 +1,9 @@ import asyncio import os +import unittest -import pytest import torch import torch.distributed as dist -from loguru import logger from torch.distributed.device_mesh import init_device_mesh from transformers import AutoModelForCausalLM @@ -39,11 +38,29 @@ def setup_single_process_distributed(): os.environ["LOCAL_RANK"] = "0" -class TestUtilsUpdateWeights: +class TestUtilsUpdateWeights(unittest.TestCase): """Test class for utils.update_weights function""" - @pytest.fixture(scope="class") - def setup_distributed(self): + @classmethod + def setUpClass(cls): + """Setup distributed environment and test fixtures for the entire test class""" + cls.setup_distributed() + cls.setup_test_engine() + cls.setup_test_model() + cls.setup_device_mesh() + + @classmethod + def tearDownClass(cls): + """Cleanup after all tests""" + if hasattr(cls, "engine") and cls.engine: + cls.engine.shutdown() + + # Cleanup distributed + if dist.is_initialized(): + dist.destroy_process_group() + + @classmethod + def setup_distributed(cls): """Setup distributed environment for testing""" setup_single_process_distributed() @@ -53,13 +70,15 @@ class TestUtilsUpdateWeights: backend="nccl" if torch.cuda.is_available() else "gloo" ) except Exception as e: - pytest.skip(f"Could not initialize distributed backend: {e}") + raise unittest.SkipTest( + f"Could not initialize distributed backend: {e}" + ) - rank = dist.get_rank() - world_size = dist.get_world_size() + cls.rank = dist.get_rank() + cls.world_size = dist.get_world_size() if torch.cuda.is_available(): - torch.cuda.set_device(rank % torch.cuda.device_count()) + torch.cuda.set_device(cls.rank % torch.cuda.device_count()) # Set up environment variables os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -68,38 +87,26 @@ class TestUtilsUpdateWeights: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" os.environ["CUDA_MODULE_LOADING"] = "AUTO" - yield rank, world_size - - # Cleanup - if dist.is_initialized(): - dist.destroy_process_group() - - @pytest.fixture(scope="class") - def test_engine(self, setup_distributed): + @classmethod + def setup_test_engine(cls): """Setup test engine""" - rank, world_size = setup_distributed - - if rank == 0: - os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - engine = AsyncEngine( + if cls.rank == 0: + cls.engine = AsyncEngine( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, dtype="bfloat16", mem_fraction_static=0.3, enable_memory_saver=True, - tp_size=world_size, - disable_cuda_graph=True, + tp_size=cls.world_size, + disable_cuda_graph=False, ) - yield engine - engine.shutdown() - else: - yield None + cls.engine = None - @pytest.fixture(scope="class") - def test_model(self): + @classmethod + def setup_test_model(cls): """Load test model""" try: - model = AutoModelForCausalLM.from_pretrained( + cls.model = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, device_map="cpu", trust_remote_code=True, @@ -108,25 +115,20 @@ class TestUtilsUpdateWeights: torch.float16 if torch.cuda.is_available() else torch.float32 ), ) - return model except Exception as e: - pytest.skip(f"Could not load test model: {e}") + raise unittest.SkipTest(f"Could not load test model: {e}") - @pytest.fixture(scope="class") - def device_mesh(self, setup_distributed): + @classmethod + def setup_device_mesh(cls): """Create device mesh for testing""" - rank, world_size = setup_distributed - if not torch.cuda.is_available(): - pytest.skip("CUDA not available for device mesh") + raise unittest.SkipTest("CUDA not available for device mesh") - device_mesh_key = "tp" - mesh = init_device_mesh( - "cuda", (world_size,), mesh_dim_names=(device_mesh_key,) + cls.device_mesh_key = "tp" + cls.mesh = init_device_mesh( + "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,) ) - return device_mesh_key, mesh - def create_test_params_batch(self, model, num_params=64): """Create a batch of test parameters from the model""" param_names = [] @@ -143,31 +145,27 @@ class TestUtilsUpdateWeights: return list(zip(param_names, test_tensors)) - @pytest.mark.asyncio - async def test_utils_update_weights( - self, setup_distributed, test_engine, test_model, device_mesh - ): + def test_utils_update_weights(self): """Test basic functionality of utils.update_weights""" - rank, world_size = setup_distributed - device_mesh_key, mesh = device_mesh - # Create test parameters batch - params_batch = self.create_test_params_batch(test_model, num_params=2) + async def async_test(): + # Create test parameters batch + params_batch = self.create_test_params_batch(self.model, num_params=2) - print( - f"Rank {rank} testing utils.update_weights with {len(params_batch)} parameters" - ) - # Test the utils.update_weights function - result = await update_weights( - engine=test_engine, - params_batch=params_batch, - device_mesh_key=device_mesh_key, - device_mesh=mesh, - load_format=None, - ) + # Test the utils.update_weights function + result = await update_weights( + engine=self.engine, + params_batch=params_batch, + device_mesh_key=self.device_mesh_key, + device_mesh=self.mesh, + load_format=None, + ) - assert "Success" in result + self.assertIn("Success", result) + + # Run the async test + asyncio.run(async_test()) if __name__ == "__main__": - pytest.main([__file__]) + unittest.main()