[feat] Reduce GPU memory overhead by using weakref (#9673)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Set, Type
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Optional, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
|
||||
penalizers: Set[Type["_BatchedPenalizer"]],
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.batch = batch
|
||||
self._batch_ref = weakref.ref(batch)
|
||||
self.device = batch.device
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
|
||||
|
||||
@@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator:
|
||||
is_required |= pen_is_required
|
||||
self.is_required = is_required
|
||||
|
||||
@property
def batch(self) -> ScheduleBatch | None:
    """Dereference the weakly-held ScheduleBatch.

    Returns None once the batch has been garbage-collected (the
    orchestrator deliberately does not keep it alive — see commit
    intent: reduce GPU memory overhead via weakref).
    """
    batch_or_none = self._batch_ref()
    return batch_or_none
|
||||
|
||||
@batch.setter
def batch(self, value: Optional[ScheduleBatch]):
    """Replace the tracked batch, storing only a weak reference.

    A weak reference lets the (potentially GPU-memory-heavy) batch be
    freed as soon as no strong owner remains.
    """
    # weakref.ref(None) raises TypeError, so a dead-reference stand-in
    # (a callable returning None) emulates "no batch".
    self._batch_ref = (lambda: None) if value is None else weakref.ref(value)
|
||||
|
||||
def reqs(self):
|
||||
return self.batch.reqs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user