[verl] Modify the update_weights func to align with verl's resharding (#5345)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
BearBiscuit
2025-04-17 10:56:57 +08:00
committed by GitHub
parent 177320a582
commit 90faf9018e

View File

@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -124,7 +124,7 @@ class VerlEngine:
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
named_tensors: Iterable[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
):
# Most naive implementation, can optimize a lot if it is bottleneck
@@ -153,9 +153,12 @@ class VerlEngine:
)
],
load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1,
flush_cache=False,
)
if self._tp_rank == 0:
self._engine.tokenizer_manager.flush_cache()
def release_memory_occupation(self):
if self._tp_rank == 0:
self._engine.release_memory_occupation()