sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
test/srt/test_utils_update_weights.py (Normal file, 170 lines added)
@@ -0,0 +1,170 @@
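"""Unit tests for sglang.srt.weight_sync.utils.update_weights.

Loads a small HF model on CPU, builds constant-valued replacement tensors, and
pushes them into a running sglang engine across a torch device mesh.
"""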
import asyncio
import os
import unittest

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.weight_sync.utils import update_weights
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST


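# Thin wrapper around sglang's Engine exposing the tokenizer manager's
# update_weights_from_tensor coroutine so the test can await it directly.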
class AsyncEngine(Engine):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def update_weights_from_tensor(self, update_weights_request):
        return await self.tokenizer_manager.update_weights_from_tensor(
            update_weights_request, None
        )


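# RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT are the variables a launcher
# such as torchrun exports; if absent, we fall back to single-process defaults.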
def is_distributed_available():
    """Check if distributed training environment is available"""
    required_vars = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
    return all(var in os.environ for var in required_vars)


def setup_single_process_distributed():
    """Setup distributed environment for single-process testing"""
    if not is_distributed_available():
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12356"
        os.environ["LOCAL_RANK"] = "0"


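# The fixtures below assume a CUDA-capable host: setup_device_mesh skips the
# suite otherwise, and only rank 0 constructs an engine.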
class TestUtilsUpdateWeights(unittest.TestCase):
    """Test class for utils.update_weights function"""

    @classmethod
    def setUpClass(cls):
        """Setup distributed environment and test fixtures for the entire test class"""
        cls.setup_distributed()
        cls.setup_test_engine()
        cls.setup_test_model()
        cls.setup_device_mesh()

    @classmethod
    def tearDownClass(cls):
        """Cleanup after all tests"""
        if hasattr(cls, "engine") and cls.engine:
            cls.engine.shutdown()

        # Cleanup distributed
        if dist.is_initialized():
            dist.destroy_process_group()

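    # NCCL is used when CUDA is present; the gloo fallback lets process-group
    # init succeed on CPU-only hosts, where the suite is later skipped at mesh
    # creation.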
    @classmethod
    def setup_distributed(cls):
        """Setup distributed environment for testing"""
        setup_single_process_distributed()

        if not dist.is_initialized():
            try:
                dist.init_process_group(
                    backend="nccl" if torch.cuda.is_available() else "gloo"
                )
            except Exception as e:
                raise unittest.SkipTest(
                    f"Could not initialize distributed backend: {e}"
                )

        cls.rank = dist.get_rank()
        cls.world_size = dist.get_world_size()

        if torch.cuda.is_available():
            torch.cuda.set_device(cls.rank % torch.cuda.device_count())

        # Set up environment variables
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
        os.environ["NCCL_CUMEM_ENABLE"] = "0"
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
        os.environ["CUDA_MODULE_LOADING"] = "AUTO"

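    # Only rank 0 hosts an engine instance; other ranks keep cls.engine as
    # None and are expected to join the weight sync through the device mesh.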
    @classmethod
    def setup_test_engine(cls):
        """Setup test engine"""
        if cls.rank == 0:
            cls.engine = AsyncEngine(
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                dtype="bfloat16",
                mem_fraction_static=0.3,
                enable_memory_saver=True,
                tp_size=cls.world_size,
                disable_cuda_graph=False,
            )
        else:
            cls.engine = None

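    # A CPU-resident copy of the same checkpoint supplies named parameters
    # whose shapes and dtypes the replacement tensors must match.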
    @classmethod
    def setup_test_model(cls):
        """Load test model"""
        try:
            cls.model = AutoModelForCausalLM.from_pretrained(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                torch_dtype=(
                    torch.float16 if torch.cuda.is_available() else torch.float32
                ),
            )
        except Exception as e:
            raise unittest.SkipTest(f"Could not load test model: {e}")

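    # A one-dimensional "tp" mesh spanning all ranks mirrors the engine's
    # tp_size.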
    @classmethod
    def setup_device_mesh(cls):
        """Create device mesh for testing"""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available for device mesh")

        cls.device_mesh_key = "tp"
        cls.mesh = init_device_mesh(
            "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
        )

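    # Filling each tensor with the constant 1.5 makes it easy to tell updated
    # weights apart from the originals.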
    def create_test_params_batch(self, model, num_params=64):
        """Create a batch of test parameters from the model"""
        param_names = []
        test_tensors = []

        # Get first few parameters from the model for testing
        for i, (name, tensor) in enumerate(model.named_parameters()):
            if i >= num_params:
                break
            param_names.append(name)
            # Create test tensor with known values, matching original shape and dtype
            test_tensor = torch.full_like(tensor, 1.5, dtype=tensor.dtype).cuda()
            test_tensors.append(test_tensor)

        return list(zip(param_names, test_tensors))

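    # num_params=2 keeps the round trip fast; update_weights is expected to
    # return a result containing "Success" when the update completes.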
    def test_utils_update_weights(self):
        """Test basic functionality of utils.update_weights"""

        async def async_test():
            # Create test parameters batch
            params_batch = self.create_test_params_batch(self.model, num_params=2)

            # Test the utils.update_weights function
            result = await update_weights(
                engine=self.engine,
                params_batch=params_batch,
                device_mesh_key=self.device_mesh_key,
                device_mesh=self.mesh,
                load_format=None,
            )

            self.assertIn("Success", result)

        # Run the async test
        asyncio.run(async_test())


if __name__ == "__main__":
    unittest.main()