diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index eb27c5c25..9cc435ad8 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,5 +1,4 @@
 import logging
-import os
 from typing import Union
 
 import torch
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index b6a9be71b..615301154 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -136,6 +136,7 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+    # Qwen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
@@ -187,11 +188,10 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
-
         self.sampling_params = sampling_params
         self.lora_path = lora_path
 
-        # Memory info
+        # Memory pool info
        self.req_pool_idx = None
 
         # Check finish
@@ -428,7 +428,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch."""
+    """Store all information of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -438,9 +438,9 @@ class ScheduleBatch:
 
     # For utility
     model_config: ModelConfig = None
-    forward_mode: ForwardMode = None
 
     sampling_info: SamplingBatchInfo = None
+    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
@@ -509,7 +509,7 @@ class ScheduleBatch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs):
+    def alloc_req_slots(self, num_reqs: int):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
@@ -610,7 +610,7 @@ class ScheduleBatch:
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self):
+    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.forward_mode = ForwardMode.EXTEND
         bs = len(self.reqs)
 
@@ -704,7 +704,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            global_server_args_dict["disable_penalizer"],
+            enable_overlap_schedule=enable_overlap_schedule,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -746,6 +746,7 @@ class ScheduleBatch:
         return False
 
     def retract_decode(self):
+        """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]
 
         # TODO(lsyin): improve retraction policy for radix cache
@@ -886,18 +887,10 @@ class ScheduleBatch:
 
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
-        self.input_ids = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
+        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
@@ -1063,7 +1056,6 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
-            sampling_info=self.sampling_info,
        )
 
     def __str__(self):
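The `prepare_for_idle` hunk above swaps an allocate-then-copy pattern for direct on-device allocation. A minimal standalone sketch of the difference (not sglang code, just the PyTorch idiom):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: allocate on the host, then issue a (possibly async) copy to the device.
# Even for a zero-element tensor this creates an extra CPU tensor and a copy op.
x = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)

# After: allocate directly on the target device in a single step.
y = torch.empty(0, dtype=torch.int32, device=device)

assert x.device == y.device and x.numel() == y.numel() == 0
```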
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1df8499af..e555c0d94 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -15,6 +15,7 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import dataclasses
 import logging
 import os
 import threading
@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -220,8 +222,12 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
+        # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The last forward batch
+        self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
@@ -336,15 +342,12 @@ class Scheduler:
 
     @torch.no_grad()
     def event_loop_normal(self):
-        """A normal blocking scheduler loop."""
-        self.last_batch = None
-
+        """A normal scheduler loop."""
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
-
             if self.server_args.enable_dp_attention:
                 batch = self.prepare_dp_attn_batch(batch)
 
@@ -353,20 +356,8 @@ class Scheduler:
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        if self.server_args.enable_dp_attention:
-                            batch = self.prepare_dp_attn_batch(batch)
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
             else:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -377,9 +368,6 @@ class Scheduler:
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
 
-        self.last_batch = None
-        self.running_batch = None
-
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -390,10 +378,24 @@ class Scheduler:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
 
+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for the overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
             if self.last_batch:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -806,7 +808,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
         )
-        new_batch.prepare_for_extend()
+        new_batch.prepare_for_extend(self.enable_overlap)
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
@@ -893,14 +895,15 @@ class Scheduler:
         return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
-        if batch.forward_mode.is_idle():
-            return
-
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
-        else:
+        elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_dummy_first():
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -953,6 +956,10 @@
                 else:
                     req.is_being_chunked -= 1
 
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
         else:  # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -1022,6 +1029,10 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs)
 
         self.token_to_kv_pool.free_group_end()
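The overlap loop above defers result processing by one iteration: batch `i` is launched, then the CPU processes batch `i-1` while the GPU is still busy. A toy sketch of that pattern, with `get_next_batch`, `run_batch`, and `process_batch_result` as simplified stand-ins for the scheduler methods:

```python
from collections import deque

def event_loop_overlap_sketch(get_next_batch, run_batch, process_batch_result):
    """Toy version of the overlap loop: CPU-side result processing of the
    previous batch happens while the current batch runs on the GPU."""
    result_queue = deque()
    last_batch = None
    while True:
        batch = get_next_batch()
        if batch is not None:
            # Launch batch i (asynchronously, on the GPU) and defer its result.
            result_queue.append((batch, run_batch(batch)))
        if last_batch is not None:
            # Process batch i-1 now; this CPU work overlaps batch i's GPU work.
            prev_batch, prev_result = result_queue.popleft()
            process_batch_result(prev_batch, prev_result)
        elif batch is None:
            break  # idle: nothing in flight and nothing new
        last_batch = batch
```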
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 805a687f7..253900f35 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -18,7 +18,6 @@ limitations under the License.
 
 import dataclasses
 import logging
 import threading
-import time
 from queue import Queue
 from typing import Optional
@@ -96,9 +95,7 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
         while True:
-            model_worker_batch, future_token_ids_ct, compute_info_done = (
-                self.input_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
             self.launch_done = threading.Event()
@@ -109,7 +106,6 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
             # Run forward
-            compute_info_done.wait()
             logits_output, next_token_ids = self.worker.forward_batch_generation(
                 model_worker_batch, self.launch_done
             )
@@ -160,15 +156,16 @@ class TpModelWorkerClient:
         return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        # A CUDA stream sync here to avoid CUDA illegal memory access errors.
+        _ = model_worker_batch.seq_lens[0].item()
+
+        # Push a new batch to the queue
         model_worker_batch.sampling_info = dataclasses.replace(
-            model_worker_batch.sampling_info
-        )
-
-        compute_info_done = torch.cuda.Event()
-        compute_info_done.record()
-        self.input_queue.put(
-            (model_worker_batch, self.future_token_ids_ct, compute_info_done)
+            model_worker_batch.sampling_info,
+            sampling_info_done=threading.Event(),
         )
+        self.cur_sampling_info = model_worker_batch.sampling_info
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
 
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
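The `seq_lens[0].item()` trick above relies on a PyTorch guarantee: reading a scalar from a GPU tensor performs a blocking device-to-host copy, which waits for the kernels already queued on the current stream. A standalone illustration (assumes a CUDA device is available; not sglang code):

```python
import torch

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device="cuda")
    b = a @ a  # queued asynchronously on the current CUDA stream

    # .item() copies one element to the host and blocks until the pending
    # matmul that produces `b` has finished on this stream.
    first = b.view(-1)[0].item()
    print(first)
```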
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index e044dd65e..cd51eb6b2 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -52,15 +52,19 @@ if TYPE_CHECKING:
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both EXTEND and DECODE.
+    # Contains both EXTEND and DECODE when doing chunked prefill.
     MIXED = auto()
-    # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence allocated.
+    # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
     IDLE = auto()
 
+    # A dummy first batch to start the pipeline for the overlap scheduler.
+    # It is now used for triggering the sampling_info_done event for the first prefill batch.
+    DUMMY_FIRST = auto()
+
     def is_prefill(self):
         return self == ForwardMode.PREFILL
 
@@ -76,6 +80,9 @@ class ForwardMode(IntEnum):
     def is_idle(self):
         return self == ForwardMode.IDLE
 
+    def is_dummy_first(self):
+        return self == ForwardMode.DUMMY_FIRST
+
 
 @dataclass
 class ForwardBatch:
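`DUMMY_FIRST` exists only to bootstrap the event chain: each batch's `sampling_info_done` is set while processing the *previous* batch's result, so the very first batch needs a placeholder predecessor. A minimal sketch of that two-thread handshake, with hypothetical names (not the sglang classes):

```python
import threading

def producer(done: threading.Event):
    # Scheduler side: finish the CPU-heavy work (stands in for
    # update_regex_vocab_mask() in process_batch_result), then signal.
    done.set()

def consumer(done: threading.Event) -> str:
    # GPU-worker side: block until the mask is ready before sampling.
    done.wait()
    return "sample with mask"

done = threading.Event()
t = threading.Thread(target=producer, args=(done,))
t.start()
print(consumer(done))
t.join()
```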
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index c3e14c1ec..efd4fc214 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -142,7 +142,6 @@ class ModelRunner:
             "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
             "disable_mla": server_args.disable_mla,
             "torchao_config": server_args.torchao_config,
-            "disable_penalizer": server_args.disable_penalizer,
             "enable_nan_detection": server_args.enable_nan_detection,
             "enable_dp_attention": server_args.enable_dp_attention,
         }
@@ -636,10 +635,18 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.update_regex_vocab_mask()
-        sampling_info.update_penalties()
+
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+            sampling_info.update_penalties()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
 
         # Sample the next tokens.
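For context on what the grammar path produces: `update_regex_vocab_mask` builds a boolean vocab mask, and applying it amounts to setting forbidden logits to `-inf` before sampling. A self-contained sketch under that assumption (the real mask allocation and `apply_mask` come from the grammar backend):

```python
import torch

def sample_with_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
    # vocab_mask: bool tensor where True marks tokens the grammar forbids.
    logits = logits.masked_fill(vocab_mask, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(2, 8)                   # batch of 2, vocab of 8
mask = torch.zeros(2, 8, dtype=torch.bool)
mask[:, 4:] = True                           # forbid token ids 4..7
print(sample_with_vocab_mask(logits, mask))  # sampled ids are always < 4
```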
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 41b88e966..6be15e6ac 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -1,12 +1,17 @@
 from __future__ import annotations
 
 import dataclasses
+import logging
+import threading
 from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
 
+logger = logging.getLogger(__name__)
+
+
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
@@ -28,6 +33,7 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
+    sampling_info_done: Optional[threading.Event] = None
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -42,10 +48,7 @@ class SamplingBatchInfo:
 
     @classmethod
     def from_schedule_batch(
-        cls,
-        batch: ScheduleBatch,
-        vocab_size: int,
-        disable_penalizer: bool,
+        cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
     ):
         reqs = batch.reqs
         device = batch.device
@@ -79,6 +82,33 @@ class SamplingBatchInfo:
         )
 
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
+        if enable_overlap_schedule:
+            # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
+            # so it is kind of tricky to make them work with the overlap scheduler.
+            # It requires correctly updating the penalty logits before sampling and syncing the events.
+            # We will support them later.
+            penalizers = {
+                penaltylib.BatchedMinNewTokensPenalizer,
+            }
+            if (
+                any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
+                or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
+                or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
+            ):
+                logger.warning(
+                    "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
+                    "when using the default overlap scheduler. They will be ignored. "
+                    "Please add `--disable-overlap` when launching the server if you need these features. "
+                    "The speed will be slower in that case."
+                )
+        else:
+            penalizers = {
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            }
+
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
         # should not add hefty computation overhead other than simple checks.
@@ -85,21 +115,13 @@ class SamplingBatchInfo:
         #
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge_batch()} cases as well.
-        if disable_penalizer:
-            ret.penalizer_orchestrator = None
-        else:
-            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-                vocab_size=vocab_size,
-                batch=batch,
-                device=batch.device,
-                Penalizers={
-                    penaltylib.BatchedFrequencyPenalizer,
-                    penaltylib.BatchedMinNewTokensPenalizer,
-                    penaltylib.BatchedPresencePenalizer,
-                    penaltylib.BatchedRepetitionPenalizer,
-                },
-            )
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=batch.device,
+            Penalizers=penalizers,
+        )
 
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -133,13 +155,13 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
-        if not self.grammars or not any(grammar for grammar in self.grammars):
+        if not self.grammars:
             self.vocab_mask = None
             self.apply_mask = None
             return
 
         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar is not None)
+        grammar = next(grammar for grammar in self.grammars if grammar)
 
         # maybe we can reuse the existing mask?
         self.vocab_mask = grammar.allocate_vocab_mask(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 204e98da1..75afecbed 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -123,7 +123,6 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    disable_penalizer: bool = False
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -200,12 +199,7 @@ class ServerArgs:
         )
 
         if self.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler mode is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results. "
-            )
-            self.disable_penalizer = True
+            self.disable_jump_forward = True
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -622,11 +616,6 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
-        parser.add_argument(
-            "--disable-penalizer",
-            action="store_true",
-            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
-        )
         parser.add_argument(
             "--disable-nan-detection",
             action="store_true",
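The comment block above describes penalizers that turn themselves into no-ops when no request needs them. A toy version of that `_is_required()` pattern, with hypothetical names rather than the real `penaltylib` API:

```python
import torch

class ToyRepetitionPenalizer:
    """Applies a multiplicative penalty only if some request asks for one."""

    def __init__(self, penalties: list[float]):
        self.penalties = penalties  # one value per request in the batch

    def _is_required(self) -> bool:
        # Cheap scalar check, mirroring the design described in the comment.
        return any(p != 1.0 for p in self.penalties)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self._is_required():  # no tensor work when nobody needs it
            return logits
        scale = torch.tensor(self.penalties).unsqueeze(-1)
        # Standard repetition-penalty shape: shrink positive logits, grow negative.
        return torch.where(logits > 0, logits / scale, logits * scale)

logits = torch.randn(2, 8)
print(ToyRepetitionPenalizer([1.0, 1.0]).apply(logits) is logits)  # True: no-op
print(ToyRepetitionPenalizer([1.0, 1.2]).apply(logits).shape)      # penalized
```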