Extract update_weights from RL Engine to SGLang to keep simplicity and fix torch reduce (#8267)
Co-authored-by: CuiBo 82354186+SuperCB@users.noreply.github.com Co-authored-by: GeLee 865038696@qq.com Co-authored-by: 杨睿 yangruipis@163.com
This commit is contained in:
@@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
|||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
||||||
|
|
||||||
@@ -278,6 +279,8 @@ class TpModelWorker:
|
|||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||||
|
|
||||||
|
monkey_patch_torch_reductions()
|
||||||
success, message = self.model_runner.update_weights_from_tensor(
|
success, message = self.model_runner.update_weights_from_tensor(
|
||||||
named_tensors=MultiprocessingSerializer.deserialize(
|
named_tensors=MultiprocessingSerializer.deserialize(
|
||||||
recv_req.serialized_named_tensors[self.tp_rank]
|
recv_req.serialized_named_tensors[self.tp_rank]
|
||||||
|
|||||||
119
python/sglang/srt/weight_sync/utils.py
Normal file
119
python/sglang/srt/weight_sync/utils.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed.device_mesh import DeviceMesh
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
|
from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
|
||||||
|
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||||
|
from sglang.srt.utils import MultiprocessingSerializer
|
||||||
|
|
||||||
|
|
||||||
|
async def update_weights(
|
||||||
|
engine: Engine,
|
||||||
|
params_batch: list[tuple[str, torch.Tensor]],
|
||||||
|
device_mesh_key: str,
|
||||||
|
device_mesh: DeviceMesh,
|
||||||
|
load_format: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update weights for the inference engine.
|
||||||
|
This function is designed to be stateless, so that the caller process could keep the stateful engine.
|
||||||
|
Example Use Case:
|
||||||
|
- Multiple Producer Process will call this function in a SPMD style
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine: The inference engine created by the caller process.
|
||||||
|
params_batch: A list of (name, tensor) tuples. We batched the tensors to avoid the overhead of cpu call.
|
||||||
|
device_mesh_key: The key of the device mesh. Typically "tp" or "infer_tp"
|
||||||
|
device_mesh: The device mesh.
|
||||||
|
load_format: The format of the weights.
|
||||||
|
"""
|
||||||
|
infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
|
||||||
|
infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
|
||||||
|
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||||
|
|
||||||
|
monkey_patch_torch_reductions()
|
||||||
|
|
||||||
|
# [
|
||||||
|
# (name0, ipc_tensor0_tp0),
|
||||||
|
# (name1, ipc_tensor1_tp0),
|
||||||
|
# ]
|
||||||
|
named_tensors_batch = [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
MultiprocessingSerializer.serialize(
|
||||||
|
_preprocess_tensor_for_update_weights(tensor)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for name, tensor in params_batch
|
||||||
|
]
|
||||||
|
|
||||||
|
if infer_tp_rank == 0:
|
||||||
|
gathered_serialized_batches = [None for _ in range(infer_tp_size)]
|
||||||
|
else:
|
||||||
|
gathered_serialized_batches = None
|
||||||
|
|
||||||
|
# [
|
||||||
|
# [ (name0, ipc_tensor0_tp0), (name1, ipc_tensor1_tp0) ],
|
||||||
|
# [ (name0, ipc_tensor0_tp1), (name1, ipc_tensor1_tp1) ],
|
||||||
|
# ]
|
||||||
|
dist.gather_object(
|
||||||
|
obj=named_tensors_batch,
|
||||||
|
object_gather_list=gathered_serialized_batches,
|
||||||
|
dst=device_mesh[device_mesh_key].mesh.tolist()[0],
|
||||||
|
group=device_mesh[device_mesh_key].get_group(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if infer_tp_rank == 0:
|
||||||
|
# Use zip(*) to "transpose" the data structure.
|
||||||
|
# After transpose, the data structure is like:
|
||||||
|
# [
|
||||||
|
# ( (name0, ipc_tensor0_tp0), (name0, ipc_tensor0_tp1) ),
|
||||||
|
# ( (name1, ipc_tensor1_tp0), (name1, ipc_tensor1_tp1) ),
|
||||||
|
# ]
|
||||||
|
logical_tensors = zip(*gathered_serialized_batches, strict=True)
|
||||||
|
|
||||||
|
named_tensors = [
|
||||||
|
# [
|
||||||
|
# (name0, LocalSerializedTensor(values=[ipc_tensor0_tp0, ipc_tensor0_tp1])),
|
||||||
|
# (name1, LocalSerializedTensor(values=[ipc_tensor1_tp0, ipc_tensor1_tp1])),
|
||||||
|
# ]
|
||||||
|
(
|
||||||
|
tensor_group[0][0],
|
||||||
|
LocalSerializedTensor(
|
||||||
|
values=[rank_part[1] for rank_part in tensor_group]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tensor_group in logical_tensors
|
||||||
|
]
|
||||||
|
|
||||||
|
update_weights_request = UpdateWeightsFromTensorReqInput(
|
||||||
|
serialized_named_tensors=[
|
||||||
|
MultiprocessingSerializer.serialize(named_tensors)
|
||||||
|
for _ in range(infer_tp_size)
|
||||||
|
],
|
||||||
|
load_format=load_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await engine.update_weights_from_tensor(update_weights_request)
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Preprocess the tensor for update weights.
|
||||||
|
Example Use Case:
|
||||||
|
- FSDP: we gather tensor by calling full_tensor in _preprocess_tensor_for_update_weights
|
||||||
|
- Megatron: we do nothing here, assuming it is gathered when feed into this func
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: The tensor to be preprocessed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The full tensor if it is a DTensor, otherwise the original tensor.
|
||||||
|
"""
|
||||||
|
if isinstance(tensor, DTensor):
|
||||||
|
return tensor.full_tensor()
|
||||||
|
return tensor
|
||||||
@@ -101,6 +101,7 @@ suites = {
|
|||||||
TestFile("test_triton_sliding_window.py", 250),
|
TestFile("test_triton_sliding_window.py", 250),
|
||||||
TestFile("test_update_weights_from_disk.py", 114),
|
TestFile("test_update_weights_from_disk.py", 114),
|
||||||
TestFile("test_update_weights_from_tensor.py", 48),
|
TestFile("test_update_weights_from_tensor.py", 48),
|
||||||
|
TestFile("test_utils_update_weights.py", 48),
|
||||||
TestFile("test_vertex_endpoint.py", 31),
|
TestFile("test_vertex_endpoint.py", 31),
|
||||||
TestFile("test_vision_chunked_prefill.py", 175),
|
TestFile("test_vision_chunked_prefill.py", 175),
|
||||||
TestFile("test_vlm_input_format.py", 300),
|
TestFile("test_vlm_input_format.py", 300),
|
||||||
|
|||||||
173
test/srt/test_utils_update_weights.py
Normal file
173
test/srt/test_utils_update_weights.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from loguru import logger
|
||||||
|
from torch.distributed.device_mesh import init_device_mesh
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
|
from sglang.srt.weight_sync.utils import update_weights
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncEngine(Engine):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def update_weights_from_tensor(self, update_weights_request):
|
||||||
|
return await self.tokenizer_manager.update_weights_from_tensor(
|
||||||
|
update_weights_request, None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_distributed_available():
|
||||||
|
"""Check if distributed training environment is available"""
|
||||||
|
required_vars = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
|
||||||
|
return all(var in os.environ for var in required_vars)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_single_process_distributed():
|
||||||
|
"""Setup distributed environment for single process testing"""
|
||||||
|
if not is_distributed_available():
|
||||||
|
os.environ["RANK"] = "0"
|
||||||
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12356"
|
||||||
|
os.environ["LOCAL_RANK"] = "0"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtilsUpdateWeights:
|
||||||
|
"""Test class for utils.update_weights function"""
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def setup_distributed(self):
|
||||||
|
"""Setup distributed environment for testing"""
|
||||||
|
setup_single_process_distributed()
|
||||||
|
|
||||||
|
if not dist.is_initialized():
|
||||||
|
try:
|
||||||
|
dist.init_process_group(
|
||||||
|
backend="nccl" if torch.cuda.is_available() else "gloo"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Could not initialize distributed backend: {e}")
|
||||||
|
|
||||||
|
rank = dist.get_rank()
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.set_device(rank % torch.cuda.device_count())
|
||||||
|
|
||||||
|
# Set up environment variables
|
||||||
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||||
|
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||||
|
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
||||||
|
|
||||||
|
yield rank, world_size
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def test_engine(self, setup_distributed):
|
||||||
|
"""Setup test engine"""
|
||||||
|
rank, world_size = setup_distributed
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
|
||||||
|
engine = AsyncEngine(
|
||||||
|
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
dtype="bfloat16",
|
||||||
|
mem_fraction_static=0.3,
|
||||||
|
enable_memory_saver=True,
|
||||||
|
tp_size=world_size,
|
||||||
|
disable_cuda_graph=True,
|
||||||
|
)
|
||||||
|
yield engine
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
else:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def test_model(self):
|
||||||
|
"""Load test model"""
|
||||||
|
try:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
device_map="cpu",
|
||||||
|
trust_remote_code=True,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
torch_dtype=(
|
||||||
|
torch.float16 if torch.cuda.is_available() else torch.float32
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Could not load test model: {e}")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def device_mesh(self, setup_distributed):
|
||||||
|
"""Create device mesh for testing"""
|
||||||
|
rank, world_size = setup_distributed
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
pytest.skip("CUDA not available for device mesh")
|
||||||
|
|
||||||
|
device_mesh_key = "tp"
|
||||||
|
mesh = init_device_mesh(
|
||||||
|
"cuda", (world_size,), mesh_dim_names=(device_mesh_key,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return device_mesh_key, mesh
|
||||||
|
|
||||||
|
def create_test_params_batch(self, model, num_params=64):
|
||||||
|
"""Create a batch of test parameters from the model"""
|
||||||
|
param_names = []
|
||||||
|
test_tensors = []
|
||||||
|
|
||||||
|
# Get first few parameters from the model for testing
|
||||||
|
for i, (name, tensor) in enumerate(model.named_parameters()):
|
||||||
|
if i >= num_params:
|
||||||
|
break
|
||||||
|
param_names.append(name)
|
||||||
|
# Create test tensor with known values, matching original shape and dtype
|
||||||
|
test_tensor = torch.full_like(tensor, 1.5, dtype=tensor.dtype).cuda()
|
||||||
|
test_tensors.append(test_tensor)
|
||||||
|
|
||||||
|
return list(zip(param_names, test_tensors))
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_utils_update_weights(
|
||||||
|
self, setup_distributed, test_engine, test_model, device_mesh
|
||||||
|
):
|
||||||
|
"""Test basic functionality of utils.update_weights"""
|
||||||
|
rank, world_size = setup_distributed
|
||||||
|
device_mesh_key, mesh = device_mesh
|
||||||
|
|
||||||
|
# Create test parameters batch
|
||||||
|
params_batch = self.create_test_params_batch(test_model, num_params=2)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Rank {rank} testing utils.update_weights with {len(params_batch)} parameters"
|
||||||
|
)
|
||||||
|
# Test the utils.update_weights function
|
||||||
|
result = await update_weights(
|
||||||
|
engine=test_engine,
|
||||||
|
params_batch=params_batch,
|
||||||
|
device_mesh_key=device_mesh_key,
|
||||||
|
device_mesh=mesh,
|
||||||
|
load_format=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Success" in result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user