From c377923304f56feb800de931aab63ac9d7de3c61 Mon Sep 17 00:00:00 2001
From: yhyang201 <47235274+yhyang201@users.noreply.github.com>
Date: Thu, 28 Aug 2025 16:09:06 +0800
Subject: [PATCH] [feat] Reduce GPU memory overhead by using weakref (#9673)

---
NOTE(review): the orchestrator now holds only a weak reference to the batch,
so the `batch` property returns None once the ScheduleBatch has been
collected. `reqs()` (unchanged context below) dereferences `self.batch`
without a None check -- confirm callers keep the batch alive for as long as
the orchestrator is used.

 .../srt/sampling/penaltylib/orchestrator.py      | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py
index a75d5e9bb..1abd255cb 100644
--- a/python/sglang/srt/sampling/penaltylib/orchestrator.py
+++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Set, Type
+import weakref
+from typing import TYPE_CHECKING, Optional, Set, Type
 
 import torch
 
@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
         penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
-        self.batch = batch
+        self._batch_ref = weakref.ref(batch)
         self.device = batch.device
 
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
@@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required
 
+    @property
+    def batch(self) -> ScheduleBatch | None:
+        return self._batch_ref()
+
+    @batch.setter
+    def batch(self, value: Optional[ScheduleBatch]):
+        if value is None:
+            self._batch_ref = lambda: None
+        else:
+            self._batch_ref = weakref.ref(value)
+
     def reqs(self):
         return self.batch.reqs
 