From b121bc03a3c30888caeffd49e96d5ffef473edbf Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 20 Oct 2024 19:47:14 -0700
Subject: [PATCH] Simplify batch result resolution (#1735)

---
 python/sglang/srt/managers/schedule_batch.py   | 27 ++-----
 python/sglang/srt/managers/scheduler.py        | 34 ++++-----
 .../srt/managers/tp_worker_overlap_thread.py   | 74 ++++++++-----
 .../sglang/srt/model_executor/model_runner.py  |  9 +--
 python/sglang/srt/server_args.py               | 10 +++
 5 files changed, 64 insertions(+), 90 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index d71cf55f6..7fd153e80 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
 }


-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""

@@ -407,7 +407,7 @@ class Req:
 bid = 0


-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all information of a batch."""

@@ -902,7 +902,7 @@ class ScheduleBatch:
         )


-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
     mrope_positions_delta: List[List[int]]

     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
-            out_cache_loc=self.out_cache_loc,
-            req_to_token_pool_records=self.req_to_token_pool_records,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-            mrope_positions_delta=self.mrope_positions_delta,
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

     def to(self, device: str):
         self.input_ids = self.input_ids.to(device, non_blocking=True)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index df4b5dfb4..990fbeaa8 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -149,12 +149,8 @@ class Scheduler:
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
         else:
             TpWorkerClass = TpModelWorker
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()

         self.tp_worker = TpWorkerClass(
             server_args=server_args,
@@ -756,9 +752,12 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
-            if batch.return_logprob:
-                # Move logprobs to cpu
-                if logits_output.next_token_logprobs is not None:
+
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
                             torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()

             # Check finish conditions
             logprob_pt = 0
@@ -825,14 +823,16 @@
             logits_output, next_token_ids, bid = result
             self.num_generated_tokens += len(batch.reqs)

-            # Move logprobs to cpu
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=self.device),
-                    next_token_ids,
-                ].tolist()
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
+                    next_token_logprobs = logits_output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=self.device),
+                        next_token_ids,
+                    ].tolist()
+                next_token_ids = next_token_ids.tolist()

             self.token_to_kv_pool.free_group_begin()

diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 5cc130a6f..0141d2113 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device

-        # Create future mappings
-        self.future_logits_output_dict = dict()
-        self.future_logits_output_ct = 0
+        # Init future mappings
         self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_output = dict()

         # Launch a thread
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
+        self.input_queue = Queue()
+        self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
@@ -90,9 +87,7 @@
     def forward_thread_func_(self):
         while True:
             tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()

             # Resolve future tokens in the input
             tic2 = time.time()
@@ -107,17 +102,22 @@
                 model_worker_batch
             )

-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
-
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            future_next_token_ids = torch.arange(
+                -(future_token_ids_ct + bs),
+                -(future_token_ids_ct),
+                dtype=torch.int32,
+                device=self.device,
+            )
             self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
                 torch.int32
             )
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
+
+            # Set the result
+            next_token_ids = next_token_ids.tolist()
+            assert logits_output.next_token_logprobs is None, "Not supported"
+            self.output_queue.put((None, next_token_ids))

             if False:
                 tic3 = time.time()
@@ -128,38 +128,26 @@
                     f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
                 )

-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)
+    def resolve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
+
+        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + bs),
+            -(self.future_token_ids_ct),
+            dtype=torch.int32,
+            device=self.device,
+        )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
+        return None, future_next_token_ids

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 898e5cc1a..291528e07 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -120,7 +120,7 @@ class ModelRunner:
         )

         if self.is_multimodal_model:
-            logger.info(
+            logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -131,13 +131,6 @@
         ]:
             server_args.disable_cuda_graph = True

-        if self.server_args.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results."
-            )
-
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 722e30f6b..6ccd89185 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -177,6 +177,16 @@ class ServerArgs:
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"

+        if self.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler mode is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results. "
+                "The NaN detection is also disabled."
+            )
+            self.disable_penalizer = True
+            self.disable_nan_detection = True
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(
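
The new ModelWorkerBatch.copy() above leans on dataclasses.replace(), which rebuilds a dataclass instance from its current fields and overrides only the keyword arguments given. Below is a standalone sketch of the idiom, not part of the patch; Batch and SamplingInfo here are illustrative stand-ins rather than the real sglang types:

import dataclasses

@dataclasses.dataclass
class SamplingInfo:
    temperature: float

    def copy(self):
        return dataclasses.replace(self)

@dataclasses.dataclass
class Batch:
    bid: int
    input_ids: list
    sampling_info: SamplingInfo

    def copy(self):
        # Shallow-copy every field, but give sampling_info its own copy,
        # matching what the removed hand-written constructor call did.
        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

b = Batch(bid=0, input_ids=[1, 2, 3], sampling_info=SamplingInfo(temperature=0.7))
c = b.copy()
assert c is not b
assert c.input_ids is b.input_ids                # untouched fields are shared
assert c.sampling_info is not b.sampling_info    # the override is a fresh object

Unlike the replaced 16-field constructor call, this form cannot silently fall out of sync when a new field is added to the dataclass.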
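
The heart of the patch is the overlap scheduler's placeholder scheme: forward_batch_generation() immediately returns negative "future" token ids in the range [-(ct + bs), -ct), the background forward thread later writes the real ids into future_token_ids_map at index -placeholder, and negative ids appearing in a later batch's input are resolved against that map. The following is a self-contained sketch of that mechanism, assuming CPU tensors and hypothetical method names (allocate/publish/resolve); the real logic lives in TpModelWorkerClient:

import torch

class FutureTokenIds:
    """Illustrative bookkeeping for token ids that are not yet computed."""

    def __init__(self, max_running_requests: int, device: str = "cpu"):
        self.ct = 0
        self.limit = max_running_requests * 3
        # Slot i holds the real token id behind placeholder -i.
        self.map = torch.empty(
            (max_running_requests * 5,), dtype=torch.int32, device=device
        )

    def allocate(self, bs: int) -> torch.Tensor:
        """Scheduler side: hand out negative placeholder ids for bs tokens."""
        ids = torch.arange(
            -(self.ct + bs), -self.ct, dtype=torch.int32, device=self.map.device
        )
        self.ct = (self.ct + bs) % self.limit
        return ids

    def publish(self, placeholders: torch.Tensor, real_ids: torch.Tensor) -> None:
        """Forward thread: store the real token ids under their placeholders."""
        self.map[(-placeholders).long()] = real_ids.to(torch.int32)

    def resolve(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace any negative ids in the next batch's input with real ids."""
        neg = input_ids < 0
        input_ids[neg] = self.map[(-input_ids[neg]).long()]
        return input_ids

fut = FutureTokenIds(max_running_requests=8)
placeholders = fut.allocate(bs=2)                 # tensor([-2, -1], dtype=torch.int32)
fut.publish(placeholders, torch.tensor([42, 7]))  # the forward pass finished
print(fut.resolve(torch.tensor([5, -2, -1], dtype=torch.int32)))  # [5, 42, 7]

Because the placeholders are strictly negative, they can never collide with real vocabulary ids, and the scheduler can assemble the next decode batch before the previous forward pass has produced its tokens.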
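
The patch also swaps the per-batch future_event_map/threading.Event bookkeeping for a plain pair of queues: batches enter input_queue in scheduling order and results leave output_queue in the same order, so a blocking get() pairs each result with its batch by position alone, which is why the bid argument of the new resolve_batch_result() goes unused. Here is a toy model of that handoff, not the real worker; the increment stands in for the forward pass and the None sentinel exists only for this demo:

import threading
from queue import Queue

input_queue: Queue = Queue()
output_queue: Queue = Queue()

def forward_thread_func():
    while True:
        batch = input_queue.get()          # blocks until the scheduler enqueues work
        if batch is None:                  # shutdown sentinel, demo only
            break
        next_token_ids = [t + 1 for t in batch]   # stand-in for the real forward pass
        output_queue.put((None, next_token_ids))  # (logits_output, next_token_ids)

worker = threading.Thread(target=forward_thread_func, daemon=True)
worker.start()

input_queue.put([1, 2, 3])               # forward_batch_generation: enqueue, keep going
logits_output, ids = output_queue.get()  # resolve_batch_result: block only when needed
print(ids)                               # [2, 3, 4]
input_queue.put(None)

The FIFO ordering guarantee is what makes this simplification safe: as long as every enqueued batch produces exactly one result, no per-batch identifier or event is needed.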