From 59cbf476264d1385405dba4db12effda32cc2053 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 19 Oct 2024 23:19:26 -0700
Subject: [PATCH] Unify the memory pool api and tp worker API (#1724)

---
 python/sglang/srt/managers/schedule_batch.py  | 36 +++++++++++++------
 python/sglang/srt/managers/scheduler.py       | 32 ++++++++++++-----
 python/sglang/srt/managers/tp_worker.py       |  8 +++--
 python/sglang/srt/mem_cache/memory_pool.py    |  8 ++++-
 python/sglang/srt/mem_cache/radix_cache.py    |  7 ++--
 .../srt/model_executor/forward_batch_info.py  |  2 ++
 .../sglang/srt/model_executor/model_runner.py |  7 ++++
 .../srt/sampling/sampling_batch_info.py       | 12 ++++++-
 8 files changed, 87 insertions(+), 25 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 5f32614d0..842e609b3 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It is transferred from the CPU scheduler to the GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
@@ -522,12 +524,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
-                    req.prefix_indices
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                 )
-
-            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
             )

             # Compute the relative logprob_start_len in an extend batch
@@ -765,9 +767,8 @@ class ScheduleBatch:

         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
-
-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
         )
         self.seq_lens.add_(1)
@@ -848,7 +849,6 @@ class ScheduleBatch:
             extend_logprob_start_lens = self.extend_logprob_start_lens

         image_inputs = [r.image_inputs for r in self.reqs]
-        lora_paths = [req.lora_path for req in self.reqs]
         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
@@ -869,13 +869,14 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
             image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,
         )
@@ -911,6 +912,9 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
@@ -940,6 +944,7 @@ class ModelWorkerBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=self.extend_seq_lens,
@@ -950,3 +955,14 @@ class ModelWorkerBatch:
             sampling_info=self.sampling_info.copy(),
             mrope_positions_delta=self.mrope_positions_delta,
         )
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e9a8dc3d3..0a9e9db0d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
@@ -144,25 +145,27 @@ class Scheduler:
         )
         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.server_args.enable_overlap_schedule:
+            TpWorkerClass = TpModelWorker
+        else:
+            TpWorkerClass = TpModelWorker
+        self.tp_worker = TpWorkerClass(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-
-        # Init states for overlap schedule
         if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
         else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
@@ -172,9 +175,14 @@ class Scheduler:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-        ) = self.tp_worker.get_token_and_memory_info()
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)

         set_random_seed(self.random_seed)
         # Print debug info
@@ -266,6 +274,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:
@@ -296,6 +305,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None
@@ -572,6 +582,7 @@ class Scheduler:
             else set([])
         )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths
@@ -673,6 +684,7 @@ class Scheduler:

         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch
@@ -712,6 +724,7 @@ class Scheduler:
         batch.prepare_for_decode()

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
@@ -933,6 +946,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to the detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -1030,6 +1044,7 @@ class Scheduler:
             )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
@@ -1070,6 +1085,7 @@ class Scheduler:
                 break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index bf9155f1e..2c390d9a8 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -27,7 +27,7 @@ import torch
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -111,7 +111,7 @@ class TpModelWorker:
         if server_args.enable_overlap_schedule:
             self.init_overlap_status()

-    def get_token_and_memory_info(self):
+    def get_worker_info(self):
         return (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
@@ -119,6 +119,10 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
+            global_server_args_dict,
+            self.model_runner.req_to_token_pool.size,
+            self.model_runner.req_to_token_pool.max_context_len,
+            self.model_runner.token_to_kv_pool.size,
         )

     def get_pad_input_ids_func(self):
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index f5ae3b00a..21bf4f4b1 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -56,6 +56,12 @@ class ReqToTokenPool:
     def clear(self):
         self.free_slots = list(range(self.size))

+    def write(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def get_write_records(self):
+        return None
+

 class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
+        self.device = device

         self.free_slots = None
         self.is_not_in_free_group = True
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index ca294c3bd..8cd8354b6 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
         assert len(new_indices) == len(token_ids)
-        self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-        ] = new_indices[len(req.prefix_indices) :]
+        self.req_to_token_pool.write(
+            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+            new_indices[len(req.prefix_indices) :],
+        )

         self.dec_lock_ref(req.last_node)
         self.inc_lock_ref(new_last_node)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 629d5cb3a..eaf268cc2 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It is transferred from the CPU scheduler to the GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 4250880c3..4bab0db79 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -131,6 +131,13 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True

+        if self.server_args.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler is enabled. This is an experimental feature. "
+                "Sampling penalizers (e.g., frequency and repetition penalties), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results."
+            )
+
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index e6a593fcc..5f5f92ca0 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -78,7 +78,7 @@ class SamplingBatchInfo:
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
-            device=batch.input_ids.device,
+            device=device,
         )

         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -224,3 +224,13 @@ class SamplingBatchInfo:
             vocab_size=self.vocab_size,
             device=self.device,
         )
+
+    def to(self, device: str):
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            value = getattr(self, item)
+            setattr(self, item, value.to(device, non_blocking=True))
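
The hunks above only add the `write()` indirection and a `get_write_records()` hook that currently returns `None`; the records travel to the worker through the new `req_to_token_pool_records` field, whose value tensors `ModelWorkerBatch.to()` moves to the device. Below is a minimal runnable sketch (not part of the patch) of the record-and-replay pattern this API makes possible. `RecordingReqToTokenPool`, its `write_records` buffer, and the replay loop are hypothetical illustrations, not sglang APIs.

import torch
from typing import List, Optional, Tuple


class ReqToTokenPool:
    """Minimal stand-in for the pool in memory_pool.py."""

    def __init__(self, size: int, max_context_len: int):
        self.size = size
        self.max_context_len = max_context_len
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)

    def write(self, indices, values):
        # The unified entry point from the patch: plain fancy-index assignment.
        self.req_to_token[indices] = values

    def get_write_records(self):
        # As in the patch, the default pool applies writes eagerly and records nothing.
        return None


class RecordingReqToTokenPool(ReqToTokenPool):
    """Hypothetical variant: buffer (indices, values) pairs for later replay."""

    def __init__(self, size: int, max_context_len: int):
        super().__init__(size, max_context_len)
        self.write_records: List[Tuple[Tuple, torch.Tensor]] = []

    def write(self, indices, values):
        super().write(indices, values)
        self.write_records.append((indices, values))

    def get_write_records(self) -> Optional[List[Tuple[Tuple, torch.Tensor]]]:
        # Hand the buffered records off (e.g., into a ModelWorkerBatch) and reset.
        records, self.write_records = self.write_records, []
        return records


if __name__ == "__main__":
    pool = RecordingReqToTokenPool(size=4, max_context_len=8)
    # Same call shape as prepare_for_extend in schedule_batch.py.
    pool.write((0, slice(0, 3)), torch.tensor([10, 11, 12], dtype=torch.int32))
    records = pool.get_write_records()
    # A worker-side pool can replay the records; in the patch, ModelWorkerBatch.to()
    # would first move each value tensor to the worker's device.
    replica = ReqToTokenPool(size=4, max_context_len=8)
    for indices, values in records:
        replica.write(indices, values)
    assert torch.equal(replica.req_to_token, pool.req_to_token)

The same `(indices, values)` shape covers the decode-path call `write((req_pool_indices, seq_lens), out_cache_loc)` seen above, since a pair of index tensors is also a valid fancy index for a 2-D tensor.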