[RL] add pause and continue generation for async rl training (#7419)
This commit is contained in:
@@ -712,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
|
|||||||
return ORJSONResponse(content=response_data, status_code=200)
|
return ORJSONResponse(content=response_data, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/pause_generation")
async def pause_generation(request: Request):
    """Pause generation.

    Delegates to ``TokenizerManager.pause_generation`` and reports success.
    """
    await _global_state.tokenizer_manager.pause_generation()
    payload = {"message": "Generation paused successfully.", "status": "ok"}
    return ORJSONResponse(content=payload, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/continue_generation")
async def continue_generation(request: Request):
    """Continue generation.

    Delegates to ``TokenizerManager.continue_generation`` and reports success.
    """
    await _global_state.tokenizer_manager.continue_generation()
    payload = {"message": "Generation continued successfully.", "status": "ok"}
    return ORJSONResponse(content=payload, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
##### OpenAI-compatible API endpoints #####
|
##### OpenAI-compatible API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -203,6 +203,8 @@ class TokenizerManager:
|
|||||||
self.is_image_gen = self.model_config.is_image_gen
|
self.is_image_gen = self.model_config.is_image_gen
|
||||||
self.context_len = self.model_config.context_len
|
self.context_len = self.model_config.context_len
|
||||||
self.image_token_id = self.model_config.image_token_id
|
self.image_token_id = self.model_config.image_token_id
|
||||||
|
self._updating = False
|
||||||
|
self._cond = asyncio.Condition()
|
||||||
|
|
||||||
if self.model_config.is_multimodal:
|
if self.model_config.is_multimodal:
|
||||||
import_processors()
|
import_processors()
|
||||||
@@ -421,6 +423,9 @@ class TokenizerManager:
|
|||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
created_time = time.time()
|
created_time = time.time()
|
||||||
|
async with self._cond:
|
||||||
|
await self._cond.wait_for(lambda: not self._updating)
|
||||||
|
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
obj.normalize_batch_and_arguments()
|
obj.normalize_batch_and_arguments()
|
||||||
|
|
||||||
@@ -902,6 +907,16 @@ class TokenizerManager:
|
|||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
||||||
|
|
||||||
|
async def pause_generation(self):
    """Pause generation for this manager.

    Sets ``self._updating`` under ``self._cond`` so that new generate
    requests (which wait on the condition until the flag clears) block,
    then aborts every in-flight request.
    """
    async with self._cond:
        self._updating = True
        # Abort everything currently running; waiters stay blocked until
        # continue_generation() clears the flag and notifies.
        self.abort_request(abort_all=True)
|
||||||
|
|
||||||
|
async def continue_generation(self):
    """Resume generation for this manager.

    Clears ``self._updating`` under ``self._cond`` and wakes every task
    waiting on the condition so blocked generate requests can proceed.
    """
    async with self._cond:
        self._updating = False
        self._cond.notify_all()
|
||||||
|
|
||||||
async def update_weights_from_disk(
|
async def update_weights_from_disk(
|
||||||
self,
|
self,
|
||||||
obj: UpdateWeightFromDiskReqInput,
|
obj: UpdateWeightFromDiskReqInput,
|
||||||
|
|||||||
Reference in New Issue
Block a user