diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index a75d5e9bb..1abd255cb 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -1,7 +1,8 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Set, Type +import weakref +from typing import TYPE_CHECKING, Optional, Set, Type import torch @@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator: penalizers: Set[Type["_BatchedPenalizer"]], ): self.vocab_size = vocab_size - self.batch = batch + self._batch_ref = weakref.ref(batch) self.device = batch.device self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers} @@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator: is_required |= pen_is_required self.is_required = is_required + @property + def batch(self) -> ScheduleBatch | None: + return self._batch_ref() + + @batch.setter + def batch(self, value: Optional[ScheduleBatch]): + if value is None: + self._batch_ref = lambda: None + else: + self._batch_ref = weakref.ref(value) + def reqs(self): return self.batch.reqs