Fix test_expert_distribution failure (#4752)
This commit is contained in:
@@ -346,7 +346,7 @@ async def stop_profile_async():
|
||||
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
|
||||
async def start_expert_distribution_record_async():
|
||||
"""Start recording the expert distribution. Clear the previous record if any."""
|
||||
_global_state.tokenizer_manager.start_expert_distribution_record()
|
||||
await _global_state.tokenizer_manager.start_expert_distribution_record()
|
||||
return Response(
|
||||
content="Start recording the expert distribution.\n",
|
||||
status_code=200,
|
||||
@@ -356,7 +356,7 @@ async def start_expert_distribution_record_async():
|
||||
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
|
||||
async def stop_expert_distribution_record_async():
|
||||
"""Stop recording the expert distribution."""
|
||||
_global_state.tokenizer_manager.stop_expert_distribution_record()
|
||||
await _global_state.tokenizer_manager.stop_expert_distribution_record()
|
||||
return Response(
|
||||
content="Stop recording the expert distribution.\n",
|
||||
status_code=200,
|
||||
@@ -366,7 +366,7 @@ async def stop_expert_distribution_record_async():
|
||||
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
|
||||
async def dump_expert_distribution_record_async():
|
||||
"""Dump expert distribution record."""
|
||||
_global_state.tokenizer_manager.dump_expert_distribution_record()
|
||||
await _global_state.tokenizer_manager.dump_expert_distribution_record()
|
||||
return Response(
|
||||
content="Dump expert distribution record.\n",
|
||||
status_code=200,
|
||||
|
||||
@@ -664,6 +664,11 @@ class ExpertDistributionReq(Enum):
|
||||
DUMP_RECORD = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertDistributionReqOutput:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReq:
|
||||
type: ProfileReqType
|
||||
|
||||
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
ExpertDistributionReq,
|
||||
ExpertDistributionReqOutput,
|
||||
FlushCacheReq,
|
||||
GetInternalStateReq,
|
||||
GetInternalStateReqOutput,
|
||||
@@ -1905,6 +1906,7 @@ class Scheduler(
|
||||
expert_distribution_recorder.dump_record()
|
||||
else:
|
||||
raise ValueError("Unrecognized ExpertDistributionReq value")
|
||||
return ExpertDistributionReqOutput()
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput):
|
||||
# handle error
|
||||
|
||||
@@ -61,6 +61,7 @@ from sglang.srt.managers.io_struct import (
|
||||
ConfigureLoggingReq,
|
||||
EmbeddingReqInput,
|
||||
ExpertDistributionReq,
|
||||
ExpertDistributionReqOutput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetInternalStateReq,
|
||||
@@ -264,6 +265,9 @@ class TokenizerManager:
|
||||
self.get_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.expert_distribution_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
|
||||
self._result_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
@@ -313,6 +317,10 @@ class TokenizerManager:
|
||||
GetInternalStateReqOutput,
|
||||
self.get_internal_state_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
ExpertDistributionReqOutput,
|
||||
self.expert_distribution_communicator.handle_recv,
|
||||
),
|
||||
(HealthCheckOutput, lambda x: None),
|
||||
]
|
||||
)
|
||||
@@ -639,17 +647,14 @@ class TokenizerManager:
|
||||
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def start_expert_distribution_record(self):
|
||||
req = ExpertDistributionReq.START_RECORD
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
async def start_expert_distribution_record(self):
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
||||
|
||||
def stop_expert_distribution_record(self):
|
||||
req = ExpertDistributionReq.STOP_RECORD
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
async def stop_expert_distribution_record(self):
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
||||
|
||||
def dump_expert_distribution_record(self):
|
||||
req = ExpertDistributionReq.DUMP_RECORD
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
async def dump_expert_distribution_record(self):
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
||||
|
||||
async def update_weights_from_disk(
|
||||
self,
|
||||
|
||||
@@ -28,9 +28,13 @@ class TestExpertDistribution(unittest.TestCase):
|
||||
def test_expert_distribution_record(self):
|
||||
"""Test expert distribution record endpoints"""
|
||||
process = popen_launch_server(
|
||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
|
||||
# The feature is only implemented in deepseek_v2.py
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -68,7 +72,9 @@ class TestExpertDistribution(unittest.TestCase):
|
||||
# Verify the dumped file exists and has correct format
|
||||
csv_files = glob.glob("expert_distribution_*.csv")
|
||||
self.assertEqual(
|
||||
len(csv_files), 1, "Expected exactly one expert distribution CSV file"
|
||||
len(csv_files),
|
||||
1,
|
||||
f"Expected exactly one expert distribution CSV file {csv_files=}",
|
||||
)
|
||||
|
||||
# Check CSV file format
|
||||
@@ -97,11 +103,17 @@ class TestExpertDistribution(unittest.TestCase):
|
||||
|
||||
# Verify data types
|
||||
layer_id, expert_id, count = row
|
||||
self.assertTrue(layer_id.isdigit(), "layer_id should be an integer")
|
||||
self.assertTrue(
|
||||
expert_id.isdigit(), "expert_id should be an integer"
|
||||
layer_id.isdigit(),
|
||||
f"layer_id should be an integer {row=} {rows=}",
|
||||
)
|
||||
self.assertTrue(
|
||||
expert_id.isdigit(),
|
||||
f"expert_id should be an integer {row=} {rows=}",
|
||||
)
|
||||
self.assertTrue(
|
||||
count.isdigit(), f"count should be an integer {row=} {rows=}"
|
||||
)
|
||||
self.assertTrue(count.isdigit(), "count should be an integer")
|
||||
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
Reference in New Issue
Block a user