[verl] Modify the update_weights func to align with verl's resharding (#5345)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
BearBiscuit
2025-04-17 10:56:57 +08:00
committed by GitHub
parent 177320a582
commit 90faf9018e

View File

@@ -12,7 +12,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import os import os
from typing import Dict, List, Literal, Optional, Tuple, Union from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -124,7 +124,7 @@ class VerlEngine:
def update_weights_from_tensor( def update_weights_from_tensor(
self, self,
named_tensors: List[Tuple[str, torch.Tensor]], named_tensors: Iterable[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None, load_format: Optional[str] = None,
): ):
# Most naive implementation, can optimize a lot if it is bottleneck # Most naive implementation, can optimize a lot if it is bottleneck
@@ -153,9 +153,12 @@ class VerlEngine:
) )
], ],
load_format=load_format, load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1, flush_cache=False,
) )
if self._tp_rank == 0:
self._engine.tokenizer_manager.flush_cache()
def release_memory_occupation(self): def release_memory_occupation(self):
if self._tp_rank == 0: if self._tp_rank == 0:
self._engine.release_memory_occupation() self._engine.release_memory_occupation()