From 501dfa6b42a9b7f21756bb21952d685edbdadc07 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 7 Oct 2025 21:34:25 +0800 Subject: [PATCH] Remove sampling info events and overlap thread file (#11300) --- python/sglang/srt/disaggregation/decode.py | 13 - python/sglang/srt/disaggregation/prefill.py | 15 - python/sglang/srt/managers/schedule_batch.py | 1 - python/sglang/srt/managers/scheduler.py | 25 +- .../scheduler_output_processor_mixin.py | 3 - .../srt/managers/tp_worker_overlap_thread.py | 307 ------------------ .../srt/model_executor/forward_batch_info.py | 7 - .../sglang/srt/model_executor/model_runner.py | 14 +- .../srt/sampling/sampling_batch_info.py | 21 +- 9 files changed, 13 insertions(+), 393 deletions(-) delete mode 100644 python/sglang/srt/managers/tp_worker_overlap_thread.py diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index cf87d62d7..fa3b2bc1f 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -783,16 +783,6 @@ class SchedulerDisaggregationDecodeMixin: self.prepare_mlp_sync_batch(batch) result = self.run_batch(batch) self.result_queue.append((batch.copy(), result)) - - if (self.last_batch is None) or (not self.last_batch_in_queue): - # Create a dummy first batch to start the pipeline for overlap schedule. - # It is now used for triggering the sampling_info_done event. - tmp_batch = ScheduleBatch( - reqs=None, - forward_mode=ForwardMode.DUMMY_FIRST, - next_batch_sampling_info=self.tp_worker.cur_sampling_info, - ) - self.set_next_batch_sampling_info_done(tmp_batch) last_batch_in_queue = True elif prepare_mlp_sync_flag: @@ -806,9 +796,6 @@ class SchedulerDisaggregationDecodeMixin: # Process the results of the previous batch but skip if the last batch is extend if self.last_batch and self.last_batch_in_queue: tmp_batch, tmp_result = self.result_queue.popleft() - tmp_batch.next_batch_sampling_info = ( - self.tp_worker.cur_sampling_info if batch else None - ) self.process_batch_result(tmp_batch, tmp_result) queue_size = ( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index b761ad7ac..020d3f5aa 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -338,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin: result = self.run_batch(batch) self.result_queue.append((batch.copy(), result)) - if self.last_batch is None: - # Create a dummy first batch to start the pipeline for overlap schedule. - # It is now used for triggering the sampling_info_done event. - tmp_batch = ScheduleBatch( - reqs=None, - forward_mode=ForwardMode.DUMMY_FIRST, - next_batch_sampling_info=self.tp_worker.cur_sampling_info, - ) - self.set_next_batch_sampling_info_done(tmp_batch) - if self.last_batch: tmp_batch, tmp_result = self.result_queue.popleft() - tmp_batch.next_batch_sampling_info = ( - self.tp_worker.cur_sampling_info if batch else None - ) self.process_batch_result_disagg_prefill(tmp_batch, tmp_result) if len(self.disagg_prefill_inflight_queue) > 0: @@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin: if self.enable_overlap: self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx) - # We need to remove the sync in the following function for overlap schedule. 
- self.set_next_batch_sampling_info_done(batch) self.maybe_send_health_check_signal() def process_disagg_prefill_inflight_queue( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2403070e3..cfd607cc5 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -891,7 +891,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Sampling info sampling_info: SamplingBatchInfo = None - next_batch_sampling_info: SamplingBatchInfo = None # Batched arguments to model runner input_ids: torch.Tensor = None # shape: [b], int64 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a6b8b0b11..68203d51e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1012,22 +1012,9 @@ class Scheduler( result = self.run_batch(batch) self.result_queue.append((batch.copy(), result)) - if self.last_batch is None: - # Create a dummy first batch to start the pipeline for overlap schedule. - # It is now used for triggering the sampling_info_done event. - tmp_batch = ScheduleBatch( - reqs=None, - forward_mode=ForwardMode.DUMMY_FIRST, - next_batch_sampling_info=self.tp_worker.cur_sampling_info, - ) - self.process_batch_result(tmp_batch, None) - if self.last_batch: # Process the results of the last batch tmp_batch, tmp_result = self.result_queue.popleft() - tmp_batch.next_batch_sampling_info = ( - self.tp_worker.cur_sampling_info if batch else None - ) self.process_batch_result(tmp_batch, tmp_result) elif batch is None: # When the server is idle, do self-check and re-init some states @@ -2100,7 +2087,7 @@ class Scheduler( self.record_batch_in_overlap(model_worker_batch) # Sampling info will be modified during forward - model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = ( + model_worker_batch.sampling_info = ( model_worker_batch.sampling_info.copy_for_forward() ) @@ -2219,9 +2206,6 @@ class Scheduler( if self.enable_overlap: if result.copy_done is not None: result.copy_done.synchronize() - self.set_next_batch_sampling_info_done(batch) - elif batch.forward_mode.is_dummy_first(): - self.set_next_batch_sampling_info_done(batch) self.maybe_send_health_check_signal() @@ -2431,13 +2415,6 @@ class Scheduler( self._add_request_to_queue(req) self.grammar_queue = self.grammar_queue[num_ready_reqs:] - def set_next_batch_sampling_info_done(self, batch: ScheduleBatch): - if batch.next_batch_sampling_info: - if batch.next_batch_sampling_info.grammars is not None: - batch.next_batch_sampling_info.update_regex_vocab_mask() - self.default_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() - def watchdog_thread(self): """A watch dog thread that will try to kill the server itself if one forward batch takes too long.""" self.watchdog_last_forward_ct = 0 diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 5a14ba4fa..b31bf92a7 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -173,8 +173,6 @@ class SchedulerOutputProcessorMixin: ) logprob_pt += num_input_logprobs - self.set_next_batch_sampling_info_done(batch) - else: # embedding or reward model embeddings = result.embeddings.tolist() @@ -295,7 +293,6 @@ class SchedulerOutputProcessorMixin: self.abort_request(AbortReq(rid=req.rid)) 
req.grammar.finished = req.finished() - self.set_next_batch_sampling_info_done(batch) self.stream_output(batch.reqs, batch.return_logprob) self.token_to_kv_pool_allocator.free_group_end() diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py deleted file mode 100644 index 3491dce7d..000000000 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A tensor parallel worker.""" -from __future__ import annotations - -import dataclasses -import logging -import signal -import threading -from queue import Queue -from typing import TYPE_CHECKING, List, Optional, Tuple - -import psutil -import torch - -from sglang.srt.managers.io_struct import ( - DestroyWeightsUpdateGroupReqInput, - GetWeightsByNameReqInput, - InitWeightsSendGroupForRemoteInstanceReqInput, - InitWeightsUpdateGroupReqInput, - LoadLoRAAdapterReqInput, - SendWeightsToRemoteInstanceReqInput, - UnloadLoRAAdapterReqInput, - UpdateWeightFromDiskReqInput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, -) -from sglang.srt.managers.overlap_utils import FutureMap -from sglang.srt.managers.schedule_batch import ModelWorkerBatch -from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import DynamicGradMode -from sglang.utils import get_exception_traceback - -if TYPE_CHECKING: - from sglang.srt.managers.cache_controller import LayerDoneCounter - -logger = logging.getLogger(__name__) - - -class TpModelWorkerClient: - """A tensor parallel model worker.""" - - def __init__( - self, - server_args: ServerArgs, - gpu_id: int, - tp_rank: int, - moe_ep_rank: int, - pp_rank: int, - dp_rank: Optional[int], - nccl_port: int, - ): - # Load the model - self.worker = TpModelWorker( - server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port - ) - self.max_running_requests = self.worker.max_running_requests - self.device = self.worker.device - self.gpu_id = gpu_id - - # Init future mappings - self.future_map = FutureMap(self.max_running_requests, self.device) - - # Launch threads - self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]() - self.output_queue = Queue() - self.forward_stream = torch.get_device_module(self.device).Stream() - self.forward_thread = threading.Thread( - target=self.forward_thread_func, - ) - self.forward_thread.start() - self.parent_process = psutil.Process().parent() - self.scheduler_stream = torch.get_device_module(self.device).current_stream() - if self.device == "cpu": - self.scheduler_stream.synchronize = lambda: None # No-op for CPU - - self.hicache_layer_transfer_counter = None - - def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter): - 
self.hicache_layer_transfer_counter = counter - - def get_worker_info(self): - return self.worker.get_worker_info() - - def get_tokens_per_layer_info(self): - return self.worker.get_tokens_per_layer_info() - - @property - def sliding_window_size(self) -> Optional[int]: - return self.worker.sliding_window_size - - @property - def is_hybrid(self) -> bool: - return self.worker.is_hybrid - - def get_pad_input_ids_func(self): - return self.worker.get_pad_input_ids_func() - - def get_tp_group(self): - return self.worker.get_tp_group() - - def get_attention_tp_group(self): - return self.worker.get_attention_tp_group() - - def get_attention_tp_cpu_group(self): - return self.worker.get_attention_tp_cpu_group() - - def get_memory_pool(self): - return ( - self.worker.model_runner.req_to_token_pool, - self.worker.model_runner.token_to_kv_pool_allocator, - ) - - def get_kv_cache(self): - return self.worker.model_runner.token_to_kv_pool - - def forward_thread_func(self): - try: - with torch.get_device_module(self.device).stream(self.forward_stream): - self.forward_thread_func_() - except Exception: - traceback = get_exception_traceback() - logger.error(f"TpModelWorkerClient hit an exception: {traceback}") - self.parent_process.send_signal(signal.SIGQUIT) - - @DynamicGradMode() - def forward_thread_func_(self): - batch_pt = 0 - batch_lists: List = [None] * 2 - - while True: - model_worker_batch, future_map_ct, sync_event = self.input_queue.get() - if not model_worker_batch: - break - - sync_event.wait() - - # Keep a reference of model_worker_batch by storing it into a list. - # Otherwise, the tensor members of model_worker_batch will be released - # by pytorch and cause CUDA illegal memory access errors. - batch_lists[batch_pt % 2] = model_worker_batch - batch_pt += 1 - - # Create event - copy_done = torch.get_device_module(self.device).Event() - - # Resolve future tokens in the input - self.future_map.resolve_future(model_worker_batch) - - # Run forward - forward_batch_output = self.worker.forward_batch_generation( - model_worker_batch, - model_worker_batch.launch_done, - ) - - logits_output, next_token_ids, can_run_cuda_graph = ( - forward_batch_output.logits_output, - forward_batch_output.next_token_ids, - forward_batch_output.can_run_cuda_graph, - ) - - # Update the future token ids map - bs = len(model_worker_batch.seq_lens) - if model_worker_batch.is_prefill_only: - # For prefill-only requests, create dummy token IDs on CPU - next_token_ids = torch.zeros(bs, dtype=torch.long) - - # store the future indices into future map - self.future_map.store_to_map(future_map_ct, bs, next_token_ids) - - # Copy results to the CPU - if model_worker_batch.return_logprob: - if logits_output.next_token_logprobs is not None: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs.to("cpu", non_blocking=True) - ) - if logits_output.input_token_logprobs is not None: - logits_output.input_token_logprobs = ( - logits_output.input_token_logprobs.to("cpu", non_blocking=True) - ) - if logits_output.hidden_states is not None: - logits_output.hidden_states = logits_output.hidden_states.to( - "cpu", non_blocking=True - ) - # Only copy to CPU if not already on CPU - if next_token_ids.device.type != "cpu": - next_token_ids = next_token_ids.to("cpu", non_blocking=True) - copy_done.record() - - self.output_queue.put( - (copy_done, logits_output, next_token_ids, can_run_cuda_graph) - ) - - def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None): - """ - This function is called to 
resolve the last batch result and - wait for the current batch to be launched. Used in overlap mode. - """ - copy_done, logits_output, next_token_ids, can_run_cuda_graph = ( - self.output_queue.get() - ) - - if launch_done is not None: - launch_done.wait() - copy_done.synchronize() - - if logits_output.next_token_logprobs is not None: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs.tolist() - ) - if logits_output.input_token_logprobs is not None: - logits_output.input_token_logprobs = tuple( - logits_output.input_token_logprobs.tolist() - ) - next_token_ids = next_token_ids.tolist() - return logits_output, next_token_ids, can_run_cuda_graph - - def forward_batch_generation( - self, model_worker_batch: ModelWorkerBatch - ) -> ForwardBatchOutput: - # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch. - model_worker_batch.sampling_info = self.cur_sampling_info = ( - model_worker_batch.sampling_info.copy_for_forward() - ) - - # A cuda stream sync here to avoid the cuda illegal memory access error. - sync_event = torch.get_device_module(self.device).Event() - sync_event.record(self.scheduler_stream) - - # Push a new batch to the queue - bs = len(model_worker_batch.seq_lens) - cur_future_map_ct = self.future_map.update_ct(bs) - self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event)) - - # get this forward batch's future token ids - future_next_token_ids = self.future_map.update_next_future( - cur_future_map_ct, bs - ) - return ForwardBatchOutput( - next_token_ids=future_next_token_ids, - can_run_cuda_graph=False, - ) - - def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): - success, message = self.worker.update_weights_from_disk(recv_req) - return success, message - - def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): - success, message = self.worker.init_weights_update_group(recv_req) - return success, message - - def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput): - success, message = self.worker.destroy_weights_update_group(recv_req) - return success, message - - def init_weights_send_group_for_remote_instance( - self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput - ): - success, message = self.worker.init_weights_send_group_for_remote_instance( - recv_req - ) - return success, message - - def send_weights_to_remote_instance( - self, recv_req: SendWeightsToRemoteInstanceReqInput - ): - success, message = self.worker.send_weights_to_remote_instance(recv_req) - return success, message - - def update_weights_from_distributed( - self, recv_req: UpdateWeightsFromDistributedReqInput - ): - success, message = self.worker.update_weights_from_distributed(recv_req) - return success, message - - def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): - success, message = self.worker.update_weights_from_tensor(recv_req) - return success, message - - def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): - return self.worker.get_weights_by_name(recv_req) - - def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): - return self.worker.load_lora_adapter(recv_req) - - def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): - return self.worker.unload_lora_adapter(recv_req) - - def can_run_lora_batch(self, lora_ids: list[str]) -> bool: - return self.worker.can_run_lora_batch(lora_ids) - - def __delete__(self): - self.input_queue.put((None, None)) - 
self.copy_queue.put((None, None, None)) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a9a29fbc7..e16458e02 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -75,10 +75,6 @@ class ForwardMode(IntEnum): # Used in speculative decoding: extend a batch in the draft model. DRAFT_EXTEND = auto() - # A dummy first batch to start the pipeline for overlap scheduler. - # It is now used for triggering the sampling_info_done event for the first prefill batch. - DUMMY_FIRST = auto() - # Split Prefill for PD multiplexing SPLIT_PREFILL = auto() @@ -128,9 +124,6 @@ class ForwardMode(IntEnum): def is_cpu_graph(self): return self == ForwardMode.DECODE - def is_dummy_first(self): - return self == ForwardMode.DUMMY_FIRST - def is_split_prefill(self): return self == ForwardMode.SPLIT_PREFILL diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 83f8c8046..e4569ed20 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2057,15 +2057,11 @@ class ModelRunner: def _preprocess_logits( self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo ): - # Apply logit bias - if sampling_info.sampling_info_done: - # Overlap mode: the function update_regex_vocab_mask was executed - # in process_batch_result of the last batch. - if sampling_info.grammars: - sampling_info.sampling_info_done.wait() - else: - # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass. - sampling_info.update_regex_vocab_mask() + # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample) + # was executed after we processed last batch's results. + + # Calculate logits bias and apply it to next_token_logits. 
+ sampling_info.update_regex_vocab_mask() sampling_info.apply_logits_bias(logits_output.next_token_logits) def sample( diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index d246ac3c3..d636ccdd0 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -44,12 +44,9 @@ class SamplingBatchInfo: vocab_mask: Optional[torch.Tensor] = None apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None - # An event used for overlap schedule - sampling_info_done: Optional[threading.Event] = None - # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None - linear_penalty: torch.Tensor = None + acc_linear_penalties: torch.Tensor = None # Used in the overlap mode # Whether any request has custom logit processor has_custom_logit_processor: bool = False @@ -217,19 +214,19 @@ class SamplingBatchInfo: def update_penalties(self): if self.penalizer_orchestrator.is_required: - self.linear_penalty = torch.zeros( + self.acc_linear_penalties = torch.zeros( (len(self.temperatures), self.vocab_size), dtype=torch.float32, device=self.temperatures.device, ) - self.penalizer_orchestrator.apply(self.linear_penalty) + self.penalizer_orchestrator.apply(self.acc_linear_penalties) else: - self.linear_penalty = None + self.acc_linear_penalties = None def apply_logits_bias(self, logits: torch.Tensor): - if self.linear_penalty is not None: + if self.acc_linear_penalties is not None: # Used in the overlap mode - logits.add_(self.linear_penalty) + logits.add_(self.acc_linear_penalties) if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required: # Used in the non-overlap mode @@ -373,11 +370,7 @@ class SamplingBatchInfo: def copy_for_forward(self): # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later self.update_penalties() - return dataclasses.replace( - self, - sampling_info_done=threading.Event(), - penalizer_orchestrator=None, - ) + return dataclasses.replace(self, penalizer_orchestrator=None) def merge_bias_tensor(
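
The sampling_batch_info.py hunk above keeps the overlap-mode penalty path but renames the dense buffer to `acc_linear_penalties` and drops the `sampling_info_done` event from `copy_for_forward`. Below is a minimal, self-contained sketch of that accumulate-then-apply pattern under simplified assumptions: `_SamplingSnapshotSketch` and its `per_token_penalty` argument are hypothetical stand-ins for illustration, not the actual SGLang `SamplingBatchInfo` API.

import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class _SamplingSnapshotSketch:
    """Simplified stand-in for SamplingBatchInfo (illustration only)."""

    temperatures: torch.Tensor  # shape: [batch_size, 1]
    vocab_size: int
    # Dense penalty buffer used in the overlap mode, mirroring `acc_linear_penalties`.
    acc_linear_penalties: Optional[torch.Tensor] = None

    def update_penalties(self, per_token_penalty: torch.Tensor) -> None:
        # Accumulate all active penalties into one pre-allocated buffer so the
        # forward pass no longer needs the penalizer orchestrator.
        self.acc_linear_penalties = torch.zeros(
            (len(self.temperatures), self.vocab_size),
            dtype=torch.float32,
            device=self.temperatures.device,
        )
        self.acc_linear_penalties.add_(per_token_penalty)

    def copy_for_forward(self) -> "_SamplingSnapshotSketch":
        # Snapshot the info for the in-flight batch; the scheduler may keep
        # mutating the original object while this copy is consumed by sampling.
        return dataclasses.replace(self)

    def apply_logits_bias(self, logits: torch.Tensor) -> None:
        # Single in-place add of the accumulated penalties onto the logits.
        if self.acc_linear_penalties is not None:
            logits.add_(self.acc_linear_penalties)


if __name__ == "__main__":
    batch_size, vocab_size = 2, 8
    info = _SamplingSnapshotSketch(
        temperatures=torch.ones(batch_size, 1), vocab_size=vocab_size
    )
    info.update_penalties(torch.full((batch_size, vocab_size), -0.5))
    snapshot = info.copy_for_forward()

    logits = torch.zeros(batch_size, vocab_size)
    snapshot.apply_logits_bias(logits)
    print(logits)  # every entry shifted by the accumulated penalty (-0.5)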