[RL] Fix illegal memory access in _import_static_state (#7733)
Co-authored-by: nanjiangwill <willjiang2018@gmail.com>
This commit is contained in:
@@ -2346,6 +2346,7 @@ class Scheduler(
|
|||||||
self.stashed_model_static_state = _export_static_state(
|
self.stashed_model_static_state = _export_static_state(
|
||||||
self.tp_worker.worker.model_runner.model
|
self.tp_worker.worker.model_runner.model
|
||||||
)
|
)
|
||||||
|
torch.distributed.barrier(self.tp_cpu_group)
|
||||||
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
|
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
|
||||||
|
|
||||||
return ReleaseMemoryOccupationReqOutput()
|
return ReleaseMemoryOccupationReqOutput()
|
||||||
@@ -2357,6 +2358,7 @@ class Scheduler(
|
|||||||
|
|
||||||
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||||
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
|
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
|
||||||
|
torch.distributed.barrier(self.tp_cpu_group)
|
||||||
_import_static_state(
|
_import_static_state(
|
||||||
self.tp_worker.worker.model_runner.model,
|
self.tp_worker.worker.model_runner.model,
|
||||||
self.stashed_model_static_state,
|
self.stashed_model_static_state,
|
||||||
|
|||||||
Reference in New Issue
Block a user