################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import threading

import torch

from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.utils import current_stream
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

# Per-thread registries for the active micro-batch contexts. The methods below
# reference these via `global`, so they must exist at module level; these
# definitions mirror the registries in vllm.v1.worker.ubatching.
_THREAD_ID_TO_CONTEXT: dict = {}
_CURRENT_CONTEXTS: list = [None, None]


class SUPAUBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.

    See the illustrative call pattern after this class definition.
    """

    def __init__(self,
                 id: int,
                 comm_stream: torch.supa.Stream,
                 compute_stream: torch.supa.Stream,
                 forward_context: ForwardContext,
                 ready_barrier: threading.Barrier,
                 cpu_wait_event: threading.Event,
                 cpu_signal_event: threading.Event,
                 gpu_comm_done_event: torch.supa.Event,
                 gpu_compute_done_event: torch.supa.Event,
                 schedule: str = "default"):
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        # Optional callback run via maybe_run_recv_hook(); cleared after use.
        self.recv_hook = None

    def __enter__(self):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
        _CURRENT_CONTEXTS[self.id] = self
        self.ready_barrier.wait()

        # Sleep until the peer thread (or the coordinator) schedules this
        # micro-batch.
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _CURRENT_CONTEXTS[self.id] = None
        del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        # Wake the other micro-batch thread before finishing.
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        return False

    def _restore_context(self):
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.supa.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep.
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        # Record compute-done before sleeping so the comm stream can wait on
        # it once this thread is rescheduled.
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
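

# Illustrative call pattern (an assumption, not from the original source):
# code running under one of these contexts overlaps its communication with
# the other micro-batch's compute by yielding at each compute/comm boundary.
# `mlp`, `dispatch`, `hidden`, and `ctx` below are hypothetical stand-ins:
#
#     hidden = mlp(hidden)                          # runs on compute_stream
#     ctx.yield_and_switch_from_compute_to_comm()   # peer thread runs here
#     hidden = dispatch(hidden)                     # runs on comm_stream
#     ctx.yield_and_switch_from_comm_to_compute()   # peer thread runs here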


def supa_make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.supa.Stream,
    comm_stream: torch.supa.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create the per-micro-batch context managers for micro-batching
    synchronization. See the wiring sketch at the end of this module.
    """
    assert num_micro_batches == 2, "only tested with 2 micro-batches"
    assert len(forward_contexts) == num_micro_batches

    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]
    gpu_compute_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]

    ctxs = []
    for i in range(num_micro_batches):
        # Each context waits on its own event and signals its peer's, so the
        # two threads ping-pong control of the CPU.
        ctx = SUPAUBatchContext(
            id=i,
            compute_stream=compute_stream,
            comm_stream=comm_stream,
            forward_context=forward_contexts[i],
            ready_barrier=ready_barrier,
            cpu_wait_event=cpu_events[i],
            cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
            gpu_comm_done_event=gpu_comm_done_events[i],
            gpu_compute_done_event=gpu_compute_done_events[i],
            schedule=schedule)
        ctxs.append(ctx)

    return ctxs


# Rebind the names imported from vllm.v1.worker.ubatching so that code going
# through this module picks up the SUPA implementations.
UBatchContext = SUPAUBatchContext
make_ubatch_contexts = supa_make_ubatch_contexts
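
# Illustrative wiring sketch (assumptions, not from the original source): one
# worker thread per micro-batch enters its context, and a coordinator thread
# releases the barrier and schedules micro-batch 0 first. `run_model` and
# `forward_contexts` are hypothetical, and the barrier party count assumes a
# single coordinating thread. Requires a SUPA-enabled torch build.
#
#     def run_ubatch(ctx):
#         with ctx:            # registers this thread, waits to be scheduled
#             run_model()      # yields at each compute/comm boundary
#
#     compute_stream = torch.supa.Stream()
#     comm_stream = torch.supa.Stream()
#     ready_barrier = threading.Barrier(2 + 1)  # two workers + coordinator
#     ctxs = supa_make_ubatch_contexts(2, compute_stream, comm_stream,
#                                      forward_contexts, ready_barrier)
#     threads = [threading.Thread(target=run_ubatch, args=(ctx,))
#                for ctx in ctxs]
#     for t in threads:
#         t.start()
#     ready_barrier.wait()          # wait until both workers have registered
#     ctxs[0].cpu_wait_event.set()  # kick off micro-batch 0
#     for t in threads:
#         t.join()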