diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index 0eecc41d8..30a009c2e 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -17,6 +17,9 @@ limitations under the License.
 
 import random
 from collections import defaultdict
+from contextlib import contextmanager
+
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 
 
 class PolicyScheduler:
@@ -83,3 +86,122 @@ class PolicyScheduler:
         for child in childs:
             self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
         q.extend(last_node_to_reqs[cur_node])
+
+
+class PrefillAdder:
+    def __init__(
+        self,
+        tree_cache,
+        rem_total_tokens,
+        rem_input_tokens,
+        rem_chunk_tokens,
+    ):
+        self.tree_cache = tree_cache
+        self.rem_total_tokens = rem_total_tokens
+        self.rem_input_tokens = rem_input_tokens
+        self.rem_chunk_tokens = rem_chunk_tokens
+
+        self.can_run_list = []
+        self.new_inflight_req = None
+        self.log_hit_tokens = 0
+        self.log_input_tokens = 0
+
+    def no_remaining_tokens(self):
+        return (
+            self.rem_total_tokens <= 0
+            or self.rem_input_tokens <= 0
+            or (
+                self.rem_chunk_tokens <= 0
+                if self.rem_chunk_tokens is not None
+                else False
+            )
+        )
+
+    def remove_running_tokens(
+        self, running_batch: ScheduleBatch, new_token_ratio: float
+    ):
+        self.rem_total_tokens -= sum(
+            [
+                (r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
+                for r in running_batch.reqs
+            ]
+        )
+
+    def _prefill_one_req(
+        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
+    ):
+        self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_input_tokens -= extend_input_len
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= extend_input_len
+
+        self.log_hit_tokens += prefix_len
+        self.log_input_tokens += extend_input_len
+
+    def add_inflight_req(self, req: Req):
+        req.input_ids = req.origin_input_ids + req.output_ids
+        req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            len(req.prefix_indices),
+            req.extend_input_len,
+            req.sampling_params.max_new_tokens if not truncated else 0,
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
+    @contextmanager
+    def _lock_node(self, last_node):
+        try:
+            delta = self.tree_cache.inc_lock_ref(last_node)
+            self.rem_total_tokens += delta
+            yield None
+        finally:
+            delta = self.tree_cache.dec_lock_ref(last_node)
+            self.rem_total_tokens += delta
+
+    def add_one_req(self, req: Req):
+        total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
+        input_tokens = req.extend_input_len
+        prefix_len = len(req.prefix_indices)
+
+        if total_tokens >= self.rem_total_tokens:
+            return False
+
+        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+            return False
+
+        with self._lock_node(req.last_node):
+            if total_tokens > self.rem_total_tokens:
+                return False
+
+            if (
+                self.rem_chunk_tokens is None
+                or input_tokens <= self.rem_chunk_tokens
+                or (req.return_logprob and req.normalized_prompt_logprob is None)
+            ):
+                # Non-chunked prefill
+                self.can_run_list.append(req)
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(
+                    prefix_len, input_tokens, req.sampling_params.max_new_tokens
+                )
+            else:
+                # Chunked prefill
+                trunc_len = self.rem_chunk_tokens
+                if trunc_len == 0:
+                    return False
+
+                req.extend_input_len = trunc_len
+                req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
+                self.can_run_list.append(req)
+                self.new_inflight_req = req
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(prefix_len, trunc_len, 0)
+
+        return True
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 8471ad187..7d9091157 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
@@ -377,151 +377,57 @@ class ModelTpServer:
         # Get priority queue
         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
 
-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
         )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
 
-        # Handle the current inflight request
-        take_inflight = 0
-        if self.current_inflight_req:
-            take_inflight = 1
-            r = self.current_inflight_req
-            r.input_ids = r.origin_input_ids + r.output_ids
-            truncated = (
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
-            )
-            r.extend_input_len = min(
-                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
-            )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
 
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
+            )
 
         for req in self.waiting_queue:
-            if req.return_logprob and req.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                if req.extend_input_len < 2:
-                    delta = 2 - req.extend_input_len
-                    req.extend_input_len += delta
-                    req.prefix_indices = req.prefix_indices[:-delta]
-                    if req.image_offset is not None:
-                        req.image_offset += delta
-            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-
+            res = adder.add_one_req(req)
             if (
-                req.extend_input_len
-                + req.sampling_params.max_new_tokens
-                + new_batch_total_tokens
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
                 break
 
-            if running_bs + len(can_run_list) >= self.max_running_requests:
-                break
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
 
         if len(can_run_list) == 0:
             return None
 
         # Print stats
         if self.tp_rank == 0:
-            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
             self.tree_cache_metrics["total"] += (
-                hit_tokens + new_batch_input_tokens
+                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
-            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
             tree_cache_hit_rate = (
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {new_batch_input_tokens}, "
-                f"#cached-token: {hit_tokens}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
 
         # Return the new batch
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6426c8e69..9a285b337 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -130,7 +130,7 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flash_infer()
+        self.init_flashinfer()
 
         # Capture cuda graphs
         self.init_cuda_graphs()
@@ -287,7 +287,7 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flash_infer(self):
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 539554fa8..db87624d2 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index f3105ad45..f96f7e0e4 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -46,8 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         )
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)