[RL] fix skip_server_warmup and rl health_generate logic (#8757)
This commit is contained in:
@@ -1172,6 +1172,8 @@ def _wait_and_warmup(
|
||||
pipe_finish_writer,
|
||||
):
|
||||
return
|
||||
else:
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||
|
||||
logger.info("The server is fired up and ready to roll!")
|
||||
|
||||
|
||||
@@ -473,6 +473,7 @@ class Scheduler(
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
self.offload_tags = set()
|
||||
self.init_profier()
|
||||
|
||||
self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
|
||||
@@ -1040,7 +1041,9 @@ class Scheduler(
|
||||
for recv_req in recv_reqs:
|
||||
# If it is a health check generation request and there are running requests, ignore it.
|
||||
if is_health_check_generate_req(recv_req) and (
|
||||
self.chunked_req is not None or not self.running_batch.is_empty()
|
||||
self.chunked_req is not None
|
||||
or not self.running_batch.is_empty()
|
||||
or len(self.offload_tags) > 0
|
||||
):
|
||||
self.return_health_check_ct += 1
|
||||
continue
|
||||
|
||||
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
|
||||
if tags is None or len(tags) == 0:
|
||||
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||
|
||||
for tag in tags:
|
||||
self.offload_tags.add(tag)
|
||||
|
||||
if GPU_MEMORY_TYPE_KV_CACHE in tags:
|
||||
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
|
||||
self.flush_cache()
|
||||
@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
|
||||
if tags is None or len(tags) == 0:
|
||||
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||
|
||||
for tag in tags:
|
||||
self.offload_tags.remove(tag)
|
||||
|
||||
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
|
||||
torch.distributed.barrier(self.tp_cpu_group)
|
||||
|
||||
Reference in New Issue
Block a user