[verl] Modify the update_weights_from_tensor function to align with verl's resharding (#5345)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -12,7 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -124,7 +124,7 @@ class VerlEngine:
|
|||||||
|
|
||||||
def update_weights_from_tensor(
|
def update_weights_from_tensor(
|
||||||
self,
|
self,
|
||||||
named_tensors: List[Tuple[str, torch.Tensor]],
|
named_tensors: Iterable[Tuple[str, torch.Tensor]],
|
||||||
load_format: Optional[str] = None,
|
load_format: Optional[str] = None,
|
||||||
):
|
):
|
||||||
# Most naive implementation, can optimize a lot if it is bottleneck
|
# Most naive implementation, can optimize a lot if it is bottleneck
|
||||||
@@ -153,9 +153,12 @@ class VerlEngine:
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
load_format=load_format,
|
load_format=load_format,
|
||||||
flush_cache=tensor_index == len(named_tensors) - 1,
|
flush_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._tp_rank == 0:
|
||||||
|
self._engine.tokenizer_manager.flush_cache()
|
||||||
|
|
||||||
def release_memory_occupation(self):
|
def release_memory_occupation(self):
|
||||||
if self._tp_rank == 0:
|
if self._tp_rank == 0:
|
||||||
self._engine.release_memory_occupation()
|
self._engine.release_memory_occupation()
|
||||||
|
|||||||
Reference in New Issue
Block a user