Fix stop_profile does not wait for finishing (#4741)
This commit is contained in:
@@ -321,7 +321,8 @@ class Engine(EngineBase):
|
||||
loop.run_until_complete(self.tokenizer_manager.start_profile())
|
||||
|
||||
def stop_profile(self):
|
||||
self.tokenizer_manager.stop_profile()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.tokenizer_manager.stop_profile())
|
||||
|
||||
def get_server_info(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -355,7 +355,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
|
||||
@app.api_route("/stop_profile", methods=["GET", "POST"])
|
||||
async def stop_profile_async():
|
||||
"""Stop profiling."""
|
||||
_global_state.tokenizer_manager.stop_profile()
|
||||
await _global_state.tokenizer_manager.stop_profile()
|
||||
return Response(
|
||||
content="Stop profiling. This will take some time.\n",
|
||||
status_code=200,
|
||||
|
||||
@@ -1512,7 +1512,7 @@ class Scheduler(
|
||||
self.profiler_target_forward_ct
|
||||
and self.profiler_target_forward_ct <= self.forward_ct
|
||||
):
|
||||
self.stop_profile()
|
||||
self.send_to_tokenizer.send_pyobj(self.stop_profile())
|
||||
|
||||
if self.forward_sleep_time is not None:
|
||||
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
|
||||
@@ -2114,7 +2114,10 @@ class Scheduler(
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
if self.profiler_activities is None:
|
||||
return
|
||||
return ProfileReqOutput(
|
||||
success=False,
|
||||
message="Profiling is not in progress. Call /start_profile first.",
|
||||
)
|
||||
|
||||
logger.info("Stop profiling...")
|
||||
if self.torch_profiler is not None:
|
||||
@@ -2145,10 +2148,7 @@ class Scheduler(
|
||||
self.torch_profiler_output_dir = None
|
||||
self.profiler_activities = None
|
||||
|
||||
if self.profiler_target_forward_ct:
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
ProfileReqOutput(success=True, message="Succeeded.")
|
||||
)
|
||||
return ProfileReqOutput(success=True, message="Succeeded")
|
||||
|
||||
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
||||
if recv_req == ExpertDistributionReq.START_RECORD:
|
||||
|
||||
@@ -295,7 +295,7 @@ class TokenizerManager:
|
||||
self.flush_cache_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.start_profile_communicator = _Communicator(
|
||||
self.profile_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
||||
@@ -360,7 +360,7 @@ class TokenizerManager:
|
||||
),
|
||||
(
|
||||
ProfileReqOutput,
|
||||
self.start_profile_communicator.handle_recv,
|
||||
self.profile_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
GetInternalStateReqOutput,
|
||||
@@ -801,7 +801,14 @@ class TokenizerManager:
|
||||
record_shapes=record_shapes,
|
||||
profile_id=str(time.time()),
|
||||
)
|
||||
result = (await self.start_profile_communicator(req))[0]
|
||||
return await self._execute_profile(req)
|
||||
|
||||
async def stop_profile(self):
|
||||
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
||||
return await self._execute_profile(req)
|
||||
|
||||
async def _execute_profile(self, req: ProfileReq):
|
||||
result = (await self.profile_communicator(req))[0]
|
||||
if not result.success:
|
||||
raise RuntimeError(result.message)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user