[feat] Reduce GPU memory overhead by using weakref (#9673)
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from typing import TYPE_CHECKING, Set, Type
|
import weakref
|
||||||
|
from typing import TYPE_CHECKING, Optional, Set, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
|
|||||||
penalizers: Set[Type["_BatchedPenalizer"]],
|
penalizers: Set[Type["_BatchedPenalizer"]],
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.batch = batch
|
self._batch_ref = weakref.ref(batch)
|
||||||
self.device = batch.device
|
self.device = batch.device
|
||||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
|
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
|
||||||
|
|
||||||
@@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator:
|
|||||||
is_required |= pen_is_required
|
is_required |= pen_is_required
|
||||||
self.is_required = is_required
|
self.is_required = is_required
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch(self) -> ScheduleBatch | None:
|
||||||
|
return self._batch_ref()
|
||||||
|
|
||||||
|
@batch.setter
|
||||||
|
def batch(self, value: Optional[ScheduleBatch]):
|
||||||
|
if value is None:
|
||||||
|
self._batch_ref = lambda: None
|
||||||
|
else:
|
||||||
|
self._batch_ref = weakref.ref(value)
|
||||||
|
|
||||||
def reqs(self):
|
def reqs(self):
|
||||||
return self.batch.reqs
|
return self.batch.reqs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user