70 lines
2.1 KiB
Python
70 lines
2.1 KiB
Python
import time
|
|
|
|
from vllm.v1.executor.abstract import logger, Executor
|
|
|
|
|
|
def is_offloaded(self) -> bool:
    """Return the executor's offloaded flag.

    The flag lives in ``self._is_offloaded``. If the attribute does not exist
    yet, it is created with the value ``False`` before being returned.
    """
    try:
        return self._is_offloaded
    except AttributeError:
        # First query on this instance: create the flag in its default state.
        self._is_offloaded = False
        return self._is_offloaded
|
|
|
|
def is_yielded(self) -> bool:
    """Return the executor's yielded flag.

    The flag lives in ``self._is_yielded``. If the attribute does not exist
    yet, it is created with the value ``False`` before being returned.
    """
    try:
        return self._is_yielded
    except AttributeError:
        # First query on this instance: create the flag in its default state.
        self._is_yielded = False
        return self._is_yielded
|
|
|
|
def offload_vram(self, is_yield: bool = False):
    """Ask every worker to offload VRAM and record the new state.

    Does nothing except log a warning when the executor is already
    offloaded. Otherwise it issues the ``offload_vram`` collective RPC,
    marks the executor as offloaded and logs how long the RPC took.

    Args:
        is_yield: When true, ``self._is_yielded`` is also set to ``True``
            after the offload.
    """
    if self.is_offloaded():
        logger.warning("Executor is already offloaded.")
        return

    # Time only the collective RPC itself.
    time_before_offload = time.perf_counter()
    self.collective_rpc("offload_vram")
    time_after_offload = time.perf_counter()

    self._is_offloaded = True
    if is_yield:
        self._is_yielded = True

    message = (
        f"Offloading VRAM costs {time_after_offload - time_before_offload:.3f} seconds."
    )
    logger.info(message)
|
|
|
|
def reload_vram(self) -> bool:
    """Reload VRAM on all workers, retrying until every worker succeeds.

    Each attempt issues the ``try_reload_vram`` collective RPC. Every reply
    is indexed as a pair: element 0 says whether that worker succeeded and
    element 1 feeds the return value. While any worker fails, the executor
    registers as a waiter once (``vnpu_start_wait``), calls
    ``vnpu_unlock_gpu`` with ``keep_wait=True``, sleeps 1 ms and retries.

    Returns:
        ``True`` immediately (after a warning) when the executor is not
        offloaded. Otherwise ``True`` only if element 1 of every worker's
        reply is truthy -- presumably "the previous holder was this same
        instance"; confirm against the worker implementation.
    """
    if not self.is_offloaded():
        logger.warning("Executor is not offloaded.")
        return True

    waiting = False
    while True:
        # Only the final, successful attempt's duration ends up in the log.
        time_before_reload = time.perf_counter()
        replies = self.collective_rpc("try_reload_vram")
        time_after_reload = time.perf_counter()

        if not all(reply[0] for reply in replies):
            # Some workers did not get the lock.
            if not waiting:
                self.collective_rpc("vnpu_start_wait")
                waiting = True
            self.collective_rpc("vnpu_unlock_gpu", kwargs={"keep_wait": True})
            time.sleep(0.001)
            continue

        # Every worker reloaded: clear both state flags.
        self._is_offloaded = False
        self._is_yielded = False
        prev_is_self = all(reply[1] for reply in replies)
        if waiting:
            self.collective_rpc("vnpu_cancel_wait")
        logger.info(
            f"Reloading VRAM costs {time_after_reload - time_before_reload:.3f} seconds."
        )
        return prev_is_self
|
|
|
|
def vnpu_has_higher_priority_waiter(self) -> bool:
    """Return whether any worker reports a higher-priority waiter.

    Issues the ``vnpu_has_higher_priority_waiter`` collective RPC and
    returns ``True`` if at least one reply is truthy.
    """
    return any(self.collective_rpc("vnpu_has_higher_priority_waiter"))
|
|
|
|
|
|
# Attach the VRAM offload/reload helpers defined in this module to Executor
# so they can be called as methods on any executor instance.
Executor.is_offloaded = is_offloaded
# is_yielded reads the `_is_yielded` flag that offload_vram/reload_vram
# maintain; it must be attached too, otherwise the flag has no accessor.
Executor.is_yielded = is_yielded
Executor.offload_vram = offload_vram
Executor.reload_vram = reload_vram
Executor.vnpu_has_higher_priority_waiter = vnpu_has_higher_priority_waiter
|