[RL] fix skip_server_warmup and rl health_generate logic (#8757)
@@ -1172,6 +1172,8 @@ def _wait_and_warmup(
             pipe_finish_writer,
         ):
             return
+    else:
+        _global_state.tokenizer_manager.server_status = ServerStatus.Up
 
     logger.info("The server is fired up and ready to roll!")
 
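The hunk only shows the tail of the warmup branch. Below is a hedged sketch of the surrounding control flow being fixed; the enclosing `if not server_args.skip_server_warmup:` and the `_execute_server_warmup(...)` call are assumptions inferred from the commit title and the context lines, not part of the hunk itself. The point of the new `else:` branch is that when warmup is skipped, nothing else flips the server status, so readiness probes would report the server as down indefinitely.

    # Hedged sketch of the control flow around this hunk in _wait_and_warmup.
    # skip_server_warmup and _execute_server_warmup are assumptions inferred
    # from the commit title and the context lines shown above.
    if not server_args.skip_server_warmup:
        if not _execute_server_warmup(
            server_args,
            pipe_finish_writer,
        ):
            return  # warmup failed; never report readiness
    else:
        # Warmup (which presumably sets the status on success) was skipped,
        # so mark the server ready explicitly or health checks never pass.
        _global_state.tokenizer_manager.server_status = ServerStatus.Up

    logger.info("The server is fired up and ready to roll!")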
@@ -473,6 +473,7 @@ class Scheduler(
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
+        self.offload_tags = set()
         self.init_profier()
 
         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
@@ -1040,7 +1041,9 @@ class Scheduler(
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or not self.running_batch.is_empty()
+                self.chunked_req is not None
+                or not self.running_batch.is_empty()
+                or len(self.offload_tags) > 0
             ):
                 self.return_health_check_ct += 1
                 continue
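As a reading aid, here is a self-contained restatement of the new gate; `should_skip_health_generate` is a hypothetical helper name, not something this commit introduces. The effect is that a health_generate probe is acknowledged (via `return_health_check_ct`) instead of executed whenever the scheduler is mid-request or any GPU memory type is currently offloaded, since running generation against released weights or KV cache would fail.

    # Hypothetical helper restating the condition above; not in the commit.
    def should_skip_health_generate(self, recv_req) -> bool:
        return is_health_check_generate_req(recv_req) and (
            self.chunked_req is not None           # chunked prefill in flight
            or not self.running_batch.is_empty()   # decode batch still running
            or len(self.offload_tags) > 0          # weights/KV cache offloaded
        )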
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+        for tag in tags:
+            self.offload_tags.add(tag)
+
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
             self.flush_cache()
@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+        for tag in tags:
+            self.offload_tags.remove(tag)
+
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
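Taken together, the two mixin hunks bracket the RL weight-sync window: the release path adds every tag to `offload_tags` before pausing memory, and the resume path removes them after memory is restored, so the health_generate gate above stays closed exactly while memory is offloaded. A hedged lifecycle sketch follows; the caller-side method names and call sequence are assumptions, and the bodies are abbreviations of the hunks above.

    # Assumed call sequence from an RL trainer driving the rollout engine.
    tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

    scheduler.release_memory_occupation(tags)  # offload_tags = {WEIGHTS, KV_CACHE}
    # ... trainer runs a step and syncs updated weights; any health_generate
    # probe arriving now is short-circuited instead of executed ...
    scheduler.resume_memory_occupation(tags)   # offload_tags = {} -> gate reopens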