import torch

from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend


@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
    """Replace placeholder (negative) token ids in ``input_ids`` in-place.

    A future token id is encoded as ``-(slot)`` where ``slot`` indexes into
    ``future_token_ids_map``; non-negative entries are already concrete and
    must be left untouched.  ``torch.clamp(-input_ids, min=0)`` keeps the
    gather index valid for the already-concrete (non-negative) entries —
    the value gathered at index 0 for those entries is discarded by
    ``torch.where``.
    """
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )


class FutureMap:
    """Circular buffer mapping "future" placeholder ids to real token ids.

    The overlap scheduler hands out negative placeholder token ids before
    the forward pass that produces the real next-token ids has finished.
    Once sampling completes, the real ids are written here
    (:meth:`store_to_map`) and later substituted into a batch's
    ``input_ids`` (:meth:`resolve_future`).
    """

    def __init__(
        self,
        max_running_requests: int,
        device: torch.device,
    ):
        # Write pointer; advances by batch size, modulo ``future_limit``.
        self.future_ct = 0
        # A factor of 3 is used to avoid collision in the circular buffer:
        # slots are not reused until two full generations of batches later.
        self.future_limit = max_running_requests * 3
        # A factor of 5 is used to ensure the buffer is large enough:
        # a write starting at pointer ``future_limit - 1`` may extend up to
        # ``future_limit - 1 + max_running_requests`` without wrapping.
        self.future_buffer_len = max_running_requests * 5
        self.device = device

        self.token_ids_buf = torch.empty(
            (self.future_buffer_len,), dtype=torch.int64, device=self.device
        )

    def update_ct(self, bs: int) -> int:
        """Advance the circular-buffer pointer by ``bs``.

        Returns the pointer's value *before* the advance — the base slot
        that this batch's ``bs`` future ids are allocated from.
        """
        cur_future_ct = self.future_ct
        self.future_ct = (cur_future_ct + bs) % self.future_limit
        return cur_future_ct

    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
        """Rewrite negative placeholder ids in ``model_worker_batch.input_ids``
        (in-place) with the real token ids previously stored via
        :meth:`store_to_map`."""
        input_ids = model_worker_batch.input_ids
        _resolve_future_token_ids(input_ids, self.token_ids_buf)

    def update_next_future(self, future_ct: int, bs: int) -> torch.Tensor:
        """Allocate ``bs`` placeholder ids for the batch based at ``future_ct``.

        NOTE: despite the name, this mutates nothing — it only returns the
        negative encodings ``-(future_ct + 1) .. -(future_ct + bs)`` whose
        slots :meth:`store_to_map` will fill with the sampled ids.
        """
        return torch.arange(
            -(future_ct + 1),
            -(future_ct + 1 + bs),
            -1,
            dtype=torch.int64,
            device=self.device,
        )

    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
        """Store the sampled next-token ids into slots
        ``future_ct + 1 .. future_ct + bs`` of the circular buffer, matching
        the encodings handed out by :meth:`update_next_future`."""
        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -36,10 +36,11 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, ) +from sglang.srt.managers.overlap_utils import FutureMap from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import DynamicGradMode, get_compiler_backend +from sglang.srt.utils import DynamicGradMode from sglang.utils import get_exception_traceback if TYPE_CHECKING: @@ -48,15 +49,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -@torch.compile(dynamic=True, backend=get_compiler_backend()) -def resolve_future_token_ids(input_ids, future_token_ids_map): - input_ids[:] = torch.where( - input_ids < 0, - future_token_ids_map[torch.clamp(-input_ids, min=0)], - input_ids, - ) - - class TpModelWorkerClient: """A tensor parallel model worker.""" @@ -79,11 +71,7 @@ class TpModelWorkerClient: self.gpu_id = gpu_id # Init future mappings - self.future_token_ids_ct = 0 - self.future_token_ids_limit = self.max_running_requests * 3 - self.future_token_ids_map = torch.empty( - (self.max_running_requests * 5,), dtype=torch.int64, device=self.device - ) + self.future_map = FutureMap(self.max_running_requests, self.device) # Launch threads self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]() @@ -153,7 +141,7 @@ class TpModelWorkerClient: batch_lists: List = [None] * 2 while True: - model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get() + model_worker_batch, future_map_ct, sync_event = self.input_queue.get() if not model_worker_batch: break @@ -169,8 +157,7 @@ class TpModelWorkerClient: copy_done = torch.get_device_module(self.device).Event() # Resolve future tokens in the input - input_ids = model_worker_batch.input_ids - 
resolve_future_token_ids(input_ids, self.future_token_ids_map) + self.future_map.resolve_future(model_worker_batch) # Run forward logits_output, next_token_ids, can_run_cuda_graph = ( @@ -187,9 +174,9 @@ class TpModelWorkerClient: if model_worker_batch.is_prefill_only: # For prefill-only requests, create dummy token IDs on CPU next_token_ids = torch.zeros(bs, dtype=torch.long) - self.future_token_ids_map[ - future_token_ids_ct + 1 : future_token_ids_ct + bs + 1 - ] = next_token_ids + + # store the future indices into future map + self.future_map.store_to_map(future_map_ct, bs, next_token_ids) # Copy results to the CPU if model_worker_batch.return_logprob: @@ -255,20 +242,14 @@ class TpModelWorkerClient: sync_event.record(self.scheduler_stream) # Push a new batch to the queue - self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event)) - - # Allocate output future objects bs = len(model_worker_batch.seq_lens) - future_next_token_ids = torch.arange( - -(self.future_token_ids_ct + 1), - -(self.future_token_ids_ct + 1 + bs), - -1, - dtype=torch.int64, - device=self.device, + cur_future_map_ct = self.future_map.update_ct(bs) + self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event)) + + # get this forward batch's future token ids + future_next_token_ids = self.future_map.update_next_future( + cur_future_map_ct, bs ) - self.future_token_ids_ct = ( - self.future_token_ids_ct + bs - ) % self.future_token_ids_limit return None, future_next_token_ids, False def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):