[verl] Modify the update_weights func to align with verl's resharding (#5345)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import os
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -124,7 +124,7 @@ class VerlEngine:
|
||||
|
||||
def update_weights_from_tensor(
|
||||
self,
|
||||
named_tensors: List[Tuple[str, torch.Tensor]],
|
||||
named_tensors: Iterable[Tuple[str, torch.Tensor]],
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
# Most naive implementation; it can be optimized a lot if it becomes a bottleneck
|
||||
@@ -153,9 +153,12 @@ class VerlEngine:
|
||||
)
|
||||
],
|
||||
load_format=load_format,
|
||||
flush_cache=tensor_index == len(named_tensors) - 1,
|
||||
flush_cache=False,
|
||||
)
|
||||
|
||||
if self._tp_rank == 0:
|
||||
self._engine.tokenizer_manager.flush_cache()
|
||||
|
||||
def release_memory_occupation(self):
|
||||
if self._tp_rank == 0:
|
||||
self._engine.release_memory_occupation()
|
||||
|
||||
Reference in New Issue
Block a user