Add tensor.detach() back to update weight util (#8691)

This commit is contained in:
Stefan He
2025-08-02 00:41:05 -07:00
committed by GitHub
parent ea93079b30
commit 4ca43b061c
2 changed files with 63 additions and 65 deletions

View File

@@ -45,7 +45,7 @@ async def update_weights(
( (
name, name,
MultiprocessingSerializer.serialize( MultiprocessingSerializer.serialize(
_preprocess_tensor_for_update_weights(tensor) _preprocess_tensor_for_update_weights(tensor.detach())
), ),
) )
for name, tensor in params_batch for name, tensor in params_batch

View File

@@ -1,10 +1,9 @@
import asyncio import asyncio
import os import os
import unittest
import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from loguru import logger
from torch.distributed.device_mesh import init_device_mesh from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
@@ -39,11 +38,29 @@ def setup_single_process_distributed():
os.environ["LOCAL_RANK"] = "0" os.environ["LOCAL_RANK"] = "0"
class TestUtilsUpdateWeights: class TestUtilsUpdateWeights(unittest.TestCase):
"""Test class for utils.update_weights function""" """Test class for utils.update_weights function"""
@pytest.fixture(scope="class") @classmethod
def setup_distributed(self): def setUpClass(cls):
"""Setup distributed environment and test fixtures for the entire test class"""
cls.setup_distributed()
cls.setup_test_engine()
cls.setup_test_model()
cls.setup_device_mesh()
@classmethod
def tearDownClass(cls):
"""Cleanup after all tests"""
if hasattr(cls, "engine") and cls.engine:
cls.engine.shutdown()
# Cleanup distributed
if dist.is_initialized():
dist.destroy_process_group()
@classmethod
def setup_distributed(cls):
"""Setup distributed environment for testing""" """Setup distributed environment for testing"""
setup_single_process_distributed() setup_single_process_distributed()
@@ -53,13 +70,15 @@ class TestUtilsUpdateWeights:
backend="nccl" if torch.cuda.is_available() else "gloo" backend="nccl" if torch.cuda.is_available() else "gloo"
) )
except Exception as e: except Exception as e:
pytest.skip(f"Could not initialize distributed backend: {e}") raise unittest.SkipTest(
f"Could not initialize distributed backend: {e}"
)
rank = dist.get_rank() cls.rank = dist.get_rank()
world_size = dist.get_world_size() cls.world_size = dist.get_world_size()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(rank % torch.cuda.device_count()) torch.cuda.set_device(cls.rank % torch.cuda.device_count())
# Set up environment variables # Set up environment variables
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -68,38 +87,26 @@ class TestUtilsUpdateWeights:
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
os.environ["CUDA_MODULE_LOADING"] = "AUTO" os.environ["CUDA_MODULE_LOADING"] = "AUTO"
yield rank, world_size @classmethod
def setup_test_engine(cls):
# Cleanup
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(scope="class")
def test_engine(self, setup_distributed):
"""Setup test engine""" """Setup test engine"""
rank, world_size = setup_distributed if cls.rank == 0:
cls.engine = AsyncEngine(
if rank == 0:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
engine = AsyncEngine(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
dtype="bfloat16", dtype="bfloat16",
mem_fraction_static=0.3, mem_fraction_static=0.3,
enable_memory_saver=True, enable_memory_saver=True,
tp_size=world_size, tp_size=cls.world_size,
disable_cuda_graph=True, disable_cuda_graph=False,
) )
yield engine
engine.shutdown()
else: else:
yield None cls.engine = None
@pytest.fixture(scope="class") @classmethod
def test_model(self): def setup_test_model(cls):
"""Load test model""" """Load test model"""
try: try:
model = AutoModelForCausalLM.from_pretrained( cls.model = AutoModelForCausalLM.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
device_map="cpu", device_map="cpu",
trust_remote_code=True, trust_remote_code=True,
@@ -108,25 +115,20 @@ class TestUtilsUpdateWeights:
torch.float16 if torch.cuda.is_available() else torch.float32 torch.float16 if torch.cuda.is_available() else torch.float32
), ),
) )
return model
except Exception as e: except Exception as e:
pytest.skip(f"Could not load test model: {e}") raise unittest.SkipTest(f"Could not load test model: {e}")
@pytest.fixture(scope="class") @classmethod
def device_mesh(self, setup_distributed): def setup_device_mesh(cls):
"""Create device mesh for testing""" """Create device mesh for testing"""
rank, world_size = setup_distributed
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip("CUDA not available for device mesh") raise unittest.SkipTest("CUDA not available for device mesh")
device_mesh_key = "tp" cls.device_mesh_key = "tp"
mesh = init_device_mesh( cls.mesh = init_device_mesh(
"cuda", (world_size,), mesh_dim_names=(device_mesh_key,) "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
) )
return device_mesh_key, mesh
def create_test_params_batch(self, model, num_params=64): def create_test_params_batch(self, model, num_params=64):
"""Create a batch of test parameters from the model""" """Create a batch of test parameters from the model"""
param_names = [] param_names = []
@@ -143,31 +145,27 @@ class TestUtilsUpdateWeights:
return list(zip(param_names, test_tensors)) return list(zip(param_names, test_tensors))
@pytest.mark.asyncio def test_utils_update_weights(self):
async def test_utils_update_weights(
self, setup_distributed, test_engine, test_model, device_mesh
):
"""Test basic functionality of utils.update_weights""" """Test basic functionality of utils.update_weights"""
rank, world_size = setup_distributed
device_mesh_key, mesh = device_mesh
# Create test parameters batch async def async_test():
params_batch = self.create_test_params_batch(test_model, num_params=2) # Create test parameters batch
params_batch = self.create_test_params_batch(self.model, num_params=2)
print( # Test the utils.update_weights function
f"Rank {rank} testing utils.update_weights with {len(params_batch)} parameters" result = await update_weights(
) engine=self.engine,
# Test the utils.update_weights function params_batch=params_batch,
result = await update_weights( device_mesh_key=self.device_mesh_key,
engine=test_engine, device_mesh=self.mesh,
params_batch=params_batch, load_format=None,
device_mesh_key=device_mesh_key, )
device_mesh=mesh,
load_format=None,
)
assert "Success" in result self.assertIn("Success", result)
# Run the async test
asyncio.run(async_test())
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) unittest.main()