From 90faf9018eb2088d4beeb3d6821799125a07e14a Mon Sep 17 00:00:00 2001
From: BearBiscuit <55008898+BearBiscuit05@users.noreply.github.com>
Date: Thu, 17 Apr 2025 10:56:57 +0800
Subject: [PATCH] [verl] Modify the update_weights func to align with verl's resharding (#5345)

Co-authored-by: Chayenne
---
 python/sglang/srt/entrypoints/verl_engine.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py
index e1ce84731..11cb4a246 100644
--- a/python/sglang/srt/entrypoints/verl_engine.py
+++ b/python/sglang/srt/entrypoints/verl_engine.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 import os
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -124,7 +124,7 @@ class VerlEngine:
 
     def update_weights_from_tensor(
         self,
-        named_tensors: List[Tuple[str, torch.Tensor]],
+        named_tensors: Iterable[Tuple[str, torch.Tensor]],
         load_format: Optional[str] = None,
     ):
         # Most naive implementation, can optimize a lot if it is bottleneck
@@ -153,9 +153,12 @@ class VerlEngine:
                         )
                     ],
                     load_format=load_format,
-                    flush_cache=tensor_index == len(named_tensors) - 1,
+                    flush_cache=False,
                 )
 
+        if self._tp_rank == 0:
+            self._engine.tokenizer_manager.flush_cache()
+
     def release_memory_occupation(self):
         if self._tp_rank == 0:
             self._engine.release_memory_occupation()