Support updating expert locations dynamically (#6388)
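At a high level, the new test spawns one worker process per simulated GPU, repeatedly draws a random physical-to-logical expert placement, and checks that expert_location_updater.update_expert_weights_single_layer rearranges every rank's routed-expert weight shards to match the new placement. As a rough orientation only (this sketch is not part of the commit), the call pattern the test exercises looks like the following; the keyword arguments are taken from the test itself, while the helper name apply_new_placement and its inputs are purely illustrative and assume a torch.distributed process group has already been initialized, as _run_subprocess does below.

# Sketch only, not part of this commit: the updater call as exercised by the
# test below. Assumes torch.distributed is already initialized (gloo on CPU,
# the default backend on CUDA) and that `weights` is the list of this rank's
# routed-expert weight tensors.
from sglang.srt.model_executor import expert_location_updater


def apply_new_placement(
    weights, old_map, new_map, num_local_physical_experts, num_gpu_per_node, rank
):
    # Rearranges `weights` in place so that the local shards match `new_map`;
    # returns debug logs, which the test prints when it detects a mismatch.
    return expert_location_updater.update_expert_weights_single_layer(
        routed_experts_weights=weights,
        temp_buffers=expert_location_updater.create_temp_buffers(weights),
        old_physical_to_logical_map=old_map,
        new_physical_to_logical_map=new_map,
        num_local_physical_experts=num_local_physical_experts,
        num_gpu_per_node=num_gpu_per_node,
        rank=rank,
        debug=True,
    )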
test/srt/test_expert_location_updater.py (new file, +255 lines)

@@ -0,0 +1,255 @@
import os
import traceback
import unittest
from dataclasses import dataclass
from typing import List

import torch
import torch.distributed
import torch.multiprocessing as mp
from torch.multiprocessing import Process

from sglang.srt.model_executor import expert_location_updater
from sglang.test.test_utils import CustomTestCase, find_available_port
from sglang.utils import is_in_ci


@dataclass
class _TestInfo:
    nnodes: int
    num_logical_experts: int
    num_physical_experts: int
    num_repeat: int = 5000


class TestExpertLocationUpdater(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def test_cpu(self):
        self._test_common(device="cpu")
        self._test_core(
            num_gpus=32,
            device="cpu",
            infos=[
                _TestInfo(
                    nnodes=4,
                    num_logical_experts=256,
                    num_physical_experts=288,
                    num_repeat=10000,
                )
            ],
        )

    def test_cpu_slow(self):
        if is_in_ci():
            return
        self._test_core(
            num_gpus=144,
            device="cpu",
            infos=[
                _TestInfo(
                    nnodes=18,
                    num_logical_experts=256,
                    num_physical_experts=288,
                    num_repeat=10000,
                )
            ],
        )

    def test_gpu(self):
        if is_in_ci():
            return
        self._test_common(device="cuda")

    def _test_common(self, device):
        infos = []

        for nnodes in [1, 2, 4]:
            for num_logical_experts in [2, 5, 20, 256]:
                for num_physical_experts in [8, 16, 256, 288]:
                    if num_logical_experts > num_physical_experts:
                        continue
                    infos.append(
                        _TestInfo(
                            nnodes=nnodes,
                            num_logical_experts=num_logical_experts,
                            num_physical_experts=num_physical_experts,
                        )
                    )

        self._test_core(num_gpus=8, device=device, infos=infos)

    def _test_core(
        self,
        num_gpus: int,
        device: str,
        infos: List[_TestInfo],
    ):
        master_port = find_available_port(23456)

        processes = []
        output_reader, output_writer = mp.Pipe(duplex=False)
        for rank in range(num_gpus):
            p = Process(
                target=_run_subprocess,
                kwargs=dict(
                    rank=rank,
                    num_gpus=num_gpus,
                    output_writer=output_writer,
                    master_port=master_port,
                    device=device,
                    infos=infos,
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(num_gpus):
            self.assertTrue(
                output_reader.recv(), "Subprocess has an error, please see logs above."
            )

        for p in processes:
            p.join()


def _run_subprocess(
    rank: int,
    num_gpus: int,
    master_port: int,
    device: str,
    infos: List[_TestInfo],
    output_writer,
):
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)

        torch.random.manual_seed(42)
        torch.distributed.init_process_group(
            rank=rank,
            world_size=num_gpus,
            backend={"cpu": "gloo", "cuda": None}[device],
        )
        if device == "cuda":
            torch.cuda.set_device(f"cuda:{rank}")

        for info in infos:
            _execute_test(info, rank=rank, num_gpus=num_gpus, device=device)

        execution_ok = True
    except Exception as e:
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()


def _execute_test(info: _TestInfo, rank: int, num_gpus: int, device: str):
    if rank == 0:
        print(f"Test: {num_gpus=} {info=}", flush=True)

    assert info.num_physical_experts % num_gpus == 0
    num_local_physical_experts = info.num_physical_experts // num_gpus
    assert num_gpus % info.nnodes == 0
    num_gpu_per_node = num_gpus // info.nnodes

    def _create_routed_experts_weights(physical_to_logical_map):
        local_logical_expert_ids = physical_to_logical_map[
            rank * num_local_physical_experts : (rank + 1) * num_local_physical_experts
        ].cpu()
        return [
            local_logical_expert_ids.to(device).clone(),
            torch.tensor(
                [
                    [local_logical_expert_id * 10, local_logical_expert_id * 100]
                    for local_logical_expert_id in local_logical_expert_ids.tolist()
                ],
                device=device,
            ),
        ]

    def _create_physical_to_logical_map():
        if rank == 0:
            ans = torch.concat(
                [
                    torch.arange(0, info.num_logical_experts),
                    torch.randint(
                        0,
                        info.num_logical_experts,
                        (info.num_physical_experts - info.num_logical_experts,),
                    ),
                ]
            )
            ans = ans[torch.randperm(ans.shape[0])]
        else:
            ans = torch.empty((info.num_physical_experts,), dtype=torch.int64)

        assert ans.dtype == torch.int64 and ans.shape == (info.num_physical_experts,)
        ans = ans.to(device)
        torch.distributed.broadcast(ans, src=0)

        return ans.cpu()

    physical_to_logical_map = _create_physical_to_logical_map()
    routed_experts_weights = _create_routed_experts_weights(physical_to_logical_map)

    for i in range(info.num_repeat):
        if rank == 0 and ((i % 500 == 0) or (i == info.num_repeat - 1)):
            print(f"Step {i}/{info.num_repeat}", flush=True)

        new_physical_to_logical_map = _create_physical_to_logical_map()
        expect_new_weights = _create_routed_experts_weights(new_physical_to_logical_map)

        output_logs = expert_location_updater.update_expert_weights_single_layer(
            routed_experts_weights=routed_experts_weights,
            temp_buffers=expert_location_updater.create_temp_buffers(
                routed_experts_weights
            ),
            old_physical_to_logical_map=physical_to_logical_map,
            new_physical_to_logical_map=new_physical_to_logical_map,
            num_local_physical_experts=num_local_physical_experts,
            num_gpu_per_node=num_gpu_per_node,
            rank=rank,
            debug=True,
        )

        local_has_error = not all(
            torch.all(x == y)
            for x, y in zip(routed_experts_weights, expect_new_weights, strict=True)
        )
        global_has_error = torch.tensor(local_has_error, device=device)
        torch.distributed.all_reduce(
            global_has_error, op=torch.distributed.ReduceOp.MAX
        )

        if global_has_error.cpu().item():
            output_logs_str = "\n".join(output_logs)
            local_message = (
                f"===================== rank {rank} ============================\n"
                f"{num_gpus=} {info=}\n"
                f"{routed_experts_weights[0].tolist()=}\n"
                f"{expect_new_weights[0].tolist()=}\n"
                f"{physical_to_logical_map.tolist()=}\n"
                f"{new_physical_to_logical_map.tolist()=}\n"
                f"===logs===\n"
                f"{output_logs_str}\n"
                f"==============================================================\n"
            )

            global_messages = ([None] * num_gpus) if rank == 0 else None
            torch.distributed.gather_object(local_message, global_messages, dst=0)

            if rank == 0:
                print("\n\n".join(global_messages), flush=True)
                raise AssertionError("Error happened, see logs above")

        physical_to_logical_map = new_physical_to_logical_map


if __name__ == "__main__":
    unittest.main()