Support updating expert locations dynamically (#6388)

fzyzcjy
2025-05-22 12:59:33 +08:00
committed by GitHub
parent 121f92c583
commit fc992a09f9
5 changed files with 723 additions and 0 deletions


@@ -0,0 +1,255 @@
import os
import traceback
import unittest
from dataclasses import dataclass
from typing import List

import torch
import torch.distributed
import torch.multiprocessing as mp
from torch.multiprocessing import Process

from sglang.srt.model_executor import expert_location_updater
from sglang.test.test_utils import CustomTestCase, find_available_port
from sglang.utils import is_in_ci


@dataclass
class _TestInfo:
    nnodes: int
    num_logical_experts: int
    num_physical_experts: int
    num_repeat: int = 5000
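
# Spawns one subprocess per simulated rank and stress-tests dynamic expert
# relocation across a range of node counts and expert counts; the GPU and
# slow CPU variants are skipped in CI.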
class TestExpertLocationUpdater(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def test_cpu(self):
        self._test_common(device="cpu")
        self._test_core(
            num_gpus=32,
            device="cpu",
            infos=[
                _TestInfo(
                    nnodes=4,
                    num_logical_experts=256,
                    num_physical_experts=288,
                    num_repeat=10000,
                )
            ],
        )

    def test_cpu_slow(self):
        if is_in_ci():
            return
        self._test_core(
            num_gpus=144,
            device="cpu",
            infos=[
                _TestInfo(
                    nnodes=18,
                    num_logical_experts=256,
                    num_physical_experts=288,
                    num_repeat=10000,
                )
            ],
        )

    def test_gpu(self):
        if is_in_ci():
            return
        self._test_common(device="cuda")

    def _test_common(self, device):
        infos = []
        for nnodes in [1, 2, 4]:
            for num_logical_experts in [2, 5, 20, 256]:
                for num_physical_experts in [8, 16, 256, 288]:
                    if num_logical_experts > num_physical_experts:
                        continue
                    infos.append(
                        _TestInfo(
                            nnodes=nnodes,
                            num_logical_experts=num_logical_experts,
                            num_physical_experts=num_physical_experts,
                        )
                    )
        self._test_core(num_gpus=8, device=device, infos=infos)
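
    # Launches one subprocess per rank, waits for each to report success or
    # failure over a pipe, then joins them.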
    def _test_core(
        self,
        num_gpus: int,
        device: str,
        infos: List[_TestInfo],
    ):
        master_port = find_available_port(23456)
        processes = []
        output_reader, output_writer = mp.Pipe(duplex=False)
        for rank in range(num_gpus):
            p = Process(
                target=_run_subprocess,
                kwargs=dict(
                    rank=rank,
                    num_gpus=num_gpus,
                    output_writer=output_writer,
                    master_port=master_port,
                    device=device,
                    infos=infos,
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(num_gpus):
            self.assertTrue(
                output_reader.recv(), f"Subprocess has error, please see logs above."
            )
        for p in processes:
            p.join()
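

# Per-rank worker: initializes torch.distributed, runs every test config, and
# reports success or failure back to the parent through the pipe.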
def _run_subprocess(
    rank: int,
    num_gpus: int,
    master_port: int,
    device: str,
    infos: List[_TestInfo],
    output_writer,
):
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        torch.random.manual_seed(42)
        torch.distributed.init_process_group(
            rank=rank,
            world_size=num_gpus,
            backend={"cpu": "gloo", "cuda": None}[device],
        )
        if device == "cuda":
            torch.cuda.set_device(f"cuda:{rank}")

        for info in infos:
            _execute_test(info, rank=rank, num_gpus=num_gpus, device=device)

        execution_ok = True
    except Exception as e:
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()
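

# Repeatedly re-shuffles the physical-to-logical expert map and checks that
# update_expert_weights_single_layer rearranges the local expert weights to
# match a reference built directly from the new map.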
def _execute_test(info: _TestInfo, rank: int, num_gpus: int, device: str):
    if rank == 0:
        print(f"Test: {num_gpus=} {info=}", flush=True)

    assert info.num_physical_experts % num_gpus == 0
    num_local_physical_experts = info.num_physical_experts // num_gpus
    assert num_gpus % info.nnodes == 0
    num_gpu_per_node = num_gpus // info.nnodes
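
    # Synthetic weights for the physical experts owned by this rank: the first
    # tensor stores each slot's logical expert id, the second stores values
    # derived from that id (x10 and x100), so correctness is checkable by value.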
    def _create_routed_experts_weights(physical_to_logical_map):
        local_logical_expert_ids = physical_to_logical_map[
            rank * num_local_physical_experts : (rank + 1) * num_local_physical_experts
        ].cpu()
        return [
            local_logical_expert_ids.to(device).clone(),
            torch.tensor(
                [
                    [local_logical_expert_id * 10, local_logical_expert_id * 100]
                    for local_logical_expert_id in local_logical_expert_ids.tolist()
                ],
                device=device,
            ),
        ]
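
    # Rank 0 draws a random map that covers every logical expert at least once,
    # then broadcasts it so all ranks agree on the same placement.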
    def _create_physical_to_logical_map():
        if rank == 0:
            ans = torch.concat(
                [
                    torch.arange(0, info.num_logical_experts),
                    torch.randint(
                        0,
                        info.num_logical_experts,
                        (info.num_physical_experts - info.num_logical_experts,),
                    ),
                ]
            )
            ans = ans[torch.randperm(ans.shape[0])]
        else:
            ans = torch.empty((info.num_physical_experts,), dtype=torch.int64)
        assert ans.dtype == torch.int64 and ans.shape == (info.num_physical_experts,)
        ans = ans.to(device)
        torch.distributed.broadcast(ans, src=0)
        return ans.cpu()

    physical_to_logical_map = _create_physical_to_logical_map()
    routed_experts_weights = _create_routed_experts_weights(physical_to_logical_map)
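
    # Each iteration: draw a new random placement, run the in-place update, and
    # compare against weights rebuilt from scratch for the new placement. Any
    # mismatch on any rank aborts with per-rank debug logs gathered on rank 0.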
    for i in range(info.num_repeat):
        if rank == 0 and ((i % 500 == 0) or (i == info.num_repeat - 1)):
            print(f"Step {i}/{info.num_repeat}", flush=True)

        new_physical_to_logical_map = _create_physical_to_logical_map()
        expect_new_weights = _create_routed_experts_weights(new_physical_to_logical_map)
        output_logs = expert_location_updater.update_expert_weights_single_layer(
            routed_experts_weights=routed_experts_weights,
            temp_buffers=expert_location_updater.create_temp_buffers(
                routed_experts_weights
            ),
            old_physical_to_logical_map=physical_to_logical_map,
            new_physical_to_logical_map=new_physical_to_logical_map,
            num_local_physical_experts=num_local_physical_experts,
            num_gpu_per_node=num_gpu_per_node,
            rank=rank,
            debug=True,
        )

        local_has_error = not all(
            torch.all(x == y)
            for x, y in zip(routed_experts_weights, expect_new_weights, strict=True)
        )
        global_has_error = torch.tensor(local_has_error, device=device)
        torch.distributed.all_reduce(
            global_has_error, op=torch.distributed.ReduceOp.MAX
        )
        if global_has_error.cpu().item():
            output_logs_str = "\n".join(output_logs)
            local_message = (
                f"===================== rank {rank} ============================\n"
                f"{num_gpus=} {info=}\n"
                f"{routed_experts_weights[0].tolist()=}\n"
                f"{expect_new_weights[0].tolist()=}\n"
                f"{physical_to_logical_map.tolist()=}\n"
                f"{new_physical_to_logical_map.tolist()=}\n"
                f"===logs===\n"
                f"{output_logs_str}\n"
                f"==============================================================\n"
            )
            global_messages = ([None] * num_gpus) if rank == 0 else None
            torch.distributed.gather_object(local_message, global_messages, dst=0)
            if rank == 0:
                print("\n\n".join(global_messages), flush=True)
            raise AssertionError(f"Error happens, see logs above")

        physical_to_logical_map = new_physical_to_logical_map


if __name__ == "__main__":
    unittest.main()