Add tensor.detach() back to update weight util (#8691)
@@ -45,7 +45,7 @@ async def update_weights(
         (
             name,
             MultiprocessingSerializer.serialize(
-                _preprocess_tensor_for_update_weights(tensor)
+                _preprocess_tensor_for_update_weights(tensor.detach())
             ),
         )
         for name, tensor in params_batch
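The fix matters because parameters taken from a live training module typically have `requires_grad=True`, and serializing such a tensor for another process can drag autograd state along with it. `tensor.detach()` returns a view that shares the same storage but carries no autograd history, so no data is copied. A minimal sketch of the `.detach()` semantics the change relies on (plain PyTorch, independent of the serializer used here):

```python
import torch

param = torch.nn.Parameter(torch.randn(4, 4))   # requires_grad=True, like a model weight

detached = param.detach()                       # same storage, no autograd history
assert detached.requires_grad is False
assert detached.data_ptr() == param.data_ptr()  # no copy is made

# Serializing the detached view avoids carrying autograd state across processes.
```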
@@ -1,10 +1,9 @@
 import asyncio
 import os
+import unittest

-import pytest
 import torch
 import torch.distributed as dist
-from loguru import logger
 from torch.distributed.device_mesh import init_device_mesh
 from transformers import AutoModelForCausalLM

@@ -39,11 +38,29 @@ def setup_single_process_distributed():
     os.environ["LOCAL_RANK"] = "0"


-class TestUtilsUpdateWeights:
+class TestUtilsUpdateWeights(unittest.TestCase):
     """Test class for utils.update_weights function"""

-    @pytest.fixture(scope="class")
-    def setup_distributed(self):
+    @classmethod
+    def setUpClass(cls):
+        """Setup distributed environment and test fixtures for the entire test class"""
+        cls.setup_distributed()
+        cls.setup_test_engine()
+        cls.setup_test_model()
+        cls.setup_device_mesh()
+
+    @classmethod
+    def tearDownClass(cls):
+        """Cleanup after all tests"""
+        if hasattr(cls, "engine") and cls.engine:
+            cls.engine.shutdown()
+
+        # Cleanup distributed
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+    @classmethod
+    def setup_distributed(cls):
         """Setup distributed environment for testing"""
         setup_single_process_distributed()

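The refactor swaps pytest's class-scoped fixtures for unittest's class-level hooks: `setUpClass` runs once before the first test in the class and `tearDownClass` once after the last, so the engine, model, and device mesh are built a single time and shared across all tests. A minimal sketch of that lifecycle, using a hypothetical stand-in resource rather than the test's real engine:

```python
import unittest

class TestExpensiveResource(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once, before any test in this class.
        cls.resource = object()  # stand-in for an engine/model/mesh

    @classmethod
    def tearDownClass(cls):
        # Runs once, after every test in this class has finished.
        cls.resource = None

    def test_uses_shared_resource(self):
        self.assertIsNotNone(self.resource)
```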
@@ -53,13 +70,15 @@ class TestUtilsUpdateWeights:
                 backend="nccl" if torch.cuda.is_available() else "gloo"
             )
         except Exception as e:
-            pytest.skip(f"Could not initialize distributed backend: {e}")
+            raise unittest.SkipTest(
+                f"Could not initialize distributed backend: {e}"
+            )

-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
+        cls.rank = dist.get_rank()
+        cls.world_size = dist.get_world_size()

         if torch.cuda.is_available():
-            torch.cuda.set_device(rank % torch.cuda.device_count())
+            torch.cuda.set_device(cls.rank % torch.cuda.device_count())

         # Set up environment variables
         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -68,38 +87,26 @@ class TestUtilsUpdateWeights:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
         os.environ["CUDA_MODULE_LOADING"] = "AUTO"

-        yield rank, world_size
-
-        # Cleanup
-        if dist.is_initialized():
-            dist.destroy_process_group()
-
-    @pytest.fixture(scope="class")
-    def test_engine(self, setup_distributed):
+    @classmethod
+    def setup_test_engine(cls):
         """Setup test engine"""
-        rank, world_size = setup_distributed
-
-        if rank == 0:
-            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
-
-            engine = AsyncEngine(
+        if cls.rank == 0:
+            cls.engine = AsyncEngine(
                 model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                 dtype="bfloat16",
                 mem_fraction_static=0.3,
                 enable_memory_saver=True,
-                tp_size=world_size,
-                disable_cuda_graph=True,
+                tp_size=cls.world_size,
+                disable_cuda_graph=False,
             )
-            yield engine
-            engine.shutdown()

         else:
-            yield None
+            cls.engine = None

-    @pytest.fixture(scope="class")
-    def test_model(self):
+    @classmethod
+    def setup_test_model(cls):
         """Load test model"""
         try:
-            model = AutoModelForCausalLM.from_pretrained(
+            cls.model = AutoModelForCausalLM.from_pretrained(
                 DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                 device_map="cpu",
                 trust_remote_code=True,
@@ -108,25 +115,20 @@ class TestUtilsUpdateWeights:
                     torch.float16 if torch.cuda.is_available() else torch.float32
                 ),
             )
-            return model
         except Exception as e:
-            pytest.skip(f"Could not load test model: {e}")
+            raise unittest.SkipTest(f"Could not load test model: {e}")

-    @pytest.fixture(scope="class")
-    def device_mesh(self, setup_distributed):
+    @classmethod
+    def setup_device_mesh(cls):
         """Create device mesh for testing"""
-        rank, world_size = setup_distributed
-
         if not torch.cuda.is_available():
-            pytest.skip("CUDA not available for device mesh")
+            raise unittest.SkipTest("CUDA not available for device mesh")

-        device_mesh_key = "tp"
-        mesh = init_device_mesh(
-            "cuda", (world_size,), mesh_dim_names=(device_mesh_key,)
+        cls.device_mesh_key = "tp"
+        cls.mesh = init_device_mesh(
+            "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
         )

-        return device_mesh_key, mesh
-
     def create_test_params_batch(self, model, num_params=64):
         """Create a batch of test parameters from the model"""
         param_names = []
@@ -143,31 +145,27 @@ class TestUtilsUpdateWeights:

         return list(zip(param_names, test_tensors))

-    @pytest.mark.asyncio
-    async def test_utils_update_weights(
-        self, setup_distributed, test_engine, test_model, device_mesh
-    ):
+    def test_utils_update_weights(self):
         """Test basic functionality of utils.update_weights"""
-        rank, world_size = setup_distributed
-        device_mesh_key, mesh = device_mesh

-        # Create test parameters batch
-        params_batch = self.create_test_params_batch(test_model, num_params=2)
+        async def async_test():
+            # Create test parameters batch
+            params_batch = self.create_test_params_batch(self.model, num_params=2)

-        print(
-            f"Rank {rank} testing utils.update_weights with {len(params_batch)} parameters"
-        )
-        # Test the utils.update_weights function
-        result = await update_weights(
-            engine=test_engine,
-            params_batch=params_batch,
-            device_mesh_key=device_mesh_key,
-            device_mesh=mesh,
-            load_format=None,
-        )
+            # Test the utils.update_weights function
+            result = await update_weights(
+                engine=self.engine,
+                params_batch=params_batch,
+                device_mesh_key=self.device_mesh_key,
+                device_mesh=self.mesh,
+                load_format=None,
+            )

-        assert "Success" in result
+            self.assertIn("Success", result)
+
+        # Run the async test
+        asyncio.run(async_test())


 if __name__ == "__main__":
-    pytest.main([__file__])
+    unittest.main()
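Because plain `unittest.TestCase` methods must be synchronous, the `@pytest.mark.asyncio` coroutine becomes a sync method that defines an inner coroutine and drives it with `asyncio.run`. A minimal sketch of the pattern with a hypothetical coroutine standing in for `update_weights` (`unittest.IsolatedAsyncioTestCase` would be an alternative on Python 3.8+):

```python
import asyncio
import unittest

async def fetch_status() -> str:
    # Hypothetical async operation standing in for update_weights().
    await asyncio.sleep(0)
    return "Success"

class TestAsyncInSyncTest(unittest.TestCase):
    def test_async_operation(self):
        async def async_test():
            result = await fetch_status()
            self.assertIn("Success", result)

        # asyncio.run creates an event loop, runs the coroutine, and closes the loop.
        asyncio.run(async_test())

if __name__ == "__main__":
    unittest.main()
```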