Expert distribution recording without overhead for EPLB (#4957)
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
@@ -16,108 +17,86 @@ from sglang.test.test_utils import (
|
||||
|
||||
|
||||
class TestExpertDistribution(CustomTestCase):
|
||||
def setUp(self):
|
||||
# Clean up any existing expert distribution files before each test
|
||||
for f in glob.glob("expert_distribution_*.csv"):
|
||||
os.remove(f)
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up any expert distribution files after each test
|
||||
for f in glob.glob("expert_distribution_*.csv"):
|
||||
os.remove(f)
|
||||
|
||||
def test_expert_distribution_record(self):
|
||||
# TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
|
||||
for info in [
|
||||
dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
|
||||
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
|
||||
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
|
||||
# TODO enable in next PR
|
||||
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
|
||||
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
|
||||
]:
|
||||
with self.subTest(info=info):
|
||||
self._execute_core(**info)
|
||||
|
||||
def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
|
||||
"""Test expert distribution record endpoints"""
|
||||
process = popen_launch_server(
|
||||
# The feature is only implemented in deepseek_v2.py
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
],
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
|
||||
|
||||
try:
|
||||
# Start recording
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Make some requests to generate expert distribution data
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Stop recording
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Dump the recorded data
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Verify the dumped file exists and has correct format
|
||||
csv_files = glob.glob("expert_distribution_*.csv")
|
||||
self.assertEqual(
|
||||
len(csv_files),
|
||||
1,
|
||||
f"Expected exactly one expert distribution CSV file {csv_files=}",
|
||||
process = popen_launch_server(
|
||||
model_path,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp-size",
|
||||
str(tp_size),
|
||||
"--expert-distribution-recorder-mode",
|
||||
mode,
|
||||
"--disable-cuda-graph",
|
||||
"--disable-overlap-schedule",
|
||||
],
|
||||
)
|
||||
|
||||
# Check CSV file format
|
||||
with open(csv_files[0], "r") as f:
|
||||
csv_reader = csv.reader(f)
|
||||
|
||||
# Check header
|
||||
header = next(csv_reader)
|
||||
self.assertEqual(
|
||||
header,
|
||||
["layer_id", "expert_id", "count"],
|
||||
"CSV header should be 'layer_id,expert_id,count'",
|
||||
try:
|
||||
# Start recording
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Make some requests to generate expert distribution data
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Stop recording
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Dump the recorded data
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Check data rows
|
||||
rows = list(csv_reader)
|
||||
self.assertGreater(len(rows), 0, "CSV file should contain data rows")
|
||||
data = torch.load(
|
||||
list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True
|
||||
)
|
||||
print(f"{data=}")
|
||||
|
||||
for row in rows:
|
||||
# Verify each row has 3 columns
|
||||
self.assertEqual(
|
||||
len(row),
|
||||
3,
|
||||
"Each row should have layer_id, expert_id and count",
|
||||
)
|
||||
if mode in ["per_pass", "per_token"]:
|
||||
self.assertGreater(len(data), 0, "Should contain data rows")
|
||||
else:
|
||||
logical_count = data["logical_count"]
|
||||
print(f"{logical_count.sum()=} {logical_count=}")
|
||||
self.assertTrue(logical_count.sum() > 0)
|
||||
|
||||
# Verify data types
|
||||
layer_id, expert_id, count = row
|
||||
self.assertTrue(
|
||||
layer_id.isdigit(),
|
||||
f"layer_id should be an integer {row=} {rows=}",
|
||||
)
|
||||
self.assertTrue(
|
||||
expert_id.isdigit(),
|
||||
f"expert_id should be an integer {row=} {rows=}",
|
||||
)
|
||||
self.assertTrue(
|
||||
count.isdigit(), f"count should be an integer {row=} {rows=}"
|
||||
)
|
||||
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user