Fix test_expert_distribution failure (#4752)
This commit is contained in:
@@ -346,7 +346,7 @@ async def stop_profile_async():
 @app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
 async def start_expert_distribution_record_async():
     """Start recording the expert distribution. Clear the previous record if any."""
-    _global_state.tokenizer_manager.start_expert_distribution_record()
+    await _global_state.tokenizer_manager.start_expert_distribution_record()
     return Response(
         content="Start recording the expert distribution.\n",
         status_code=200,
@@ -356,7 +356,7 @@ async def start_expert_distribution_record_async():
 @app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
 async def stop_expert_distribution_record_async():
     """Stop recording the expert distribution."""
-    _global_state.tokenizer_manager.stop_expert_distribution_record()
+    await _global_state.tokenizer_manager.stop_expert_distribution_record()
     return Response(
         content="Stop recording the expert distribution.\n",
         status_code=200,
@@ -366,7 +366,7 @@ async def stop_expert_distribution_record_async():
 @app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
 async def dump_expert_distribution_record_async():
     """Dump expert distribution record."""
-    _global_state.tokenizer_manager.dump_expert_distribution_record()
+    await _global_state.tokenizer_manager.dump_expert_distribution_record()
     return Response(
         content="Dump expert distribution record.\n",
         status_code=200,
|
||||
@@ -664,6 +664,11 @@ class ExpertDistributionReq(Enum):
     DUMP_RECORD = 3
 
 
+@dataclass
+class ExpertDistributionReqOutput:
+    pass
+
+
 @dataclass
 class ProfileReq:
     type: ProfileReqType
|
||||
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
     ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -1905,6 +1906,7 @@ class Scheduler(
             expert_distribution_recorder.dump_record()
         else:
             raise ValueError("Unrecognized ExpertDistributionReq value")
+        return ExpertDistributionReqOutput()
 
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
|
||||
@@ -61,6 +61,7 @@ from sglang.srt.managers.io_struct import (
     ConfigureLoggingReq,
     EmbeddingReqInput,
     ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GenerateReqInput,
     GetInternalStateReq,
@@ -264,6 +265,9 @@ class TokenizerManager:
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.expert_distribution_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -313,6 +317,10 @@ class TokenizerManager:
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
+                (
+                    ExpertDistributionReqOutput,
+                    self.expert_distribution_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
@@ -639,17 +647,14 @@ class TokenizerManager:
         req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)
 
-    def start_expert_distribution_record(self):
-        req = ExpertDistributionReq.START_RECORD
-        self.send_to_scheduler.send_pyobj(req)
+    async def start_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
 
-    def stop_expert_distribution_record(self):
-        req = ExpertDistributionReq.STOP_RECORD
-        self.send_to_scheduler.send_pyobj(req)
+    async def stop_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
 
-    def dump_expert_distribution_record(self):
-        req = ExpertDistributionReq.DUMP_RECORD
-        self.send_to_scheduler.send_pyobj(req)
+    async def dump_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
     async def update_weights_from_disk(
         self,
|
||||
Reference in New Issue
Block a user