[RL] Fix illegal memory access in _import_static_state (#7733)
Co-authored-by: nanjiangwill <willjiang2018@gmail.com>
This commit is contained in:
@@ -2346,6 +2346,7 @@ class Scheduler(
self.stashed_model_static_state = _export_static_state(
    self.tp_worker.worker.model_runner.model
)
torch.distributed.barrier(self.tp_cpu_group)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

return ReleaseMemoryOccupationReqOutput()
@@ -2357,6 +2358,7 @@ class Scheduler(

if GPU_MEMORY_TYPE_WEIGHTS in tags:
    self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
    torch.distributed.barrier(self.tp_cpu_group)
    _import_static_state(
        self.tp_worker.worker.model_runner.model,
        self.stashed_model_static_state,
Reference in New Issue
Block a user