diff --git a/benchmark/tree_of_thought_deep/bench_sglang.py b/benchmark/tree_of_thought_deep/bench_sglang.py
index b60f1f00f..bfb2a4113 100644
--- a/benchmark/tree_of_thought_deep/bench_sglang.py
+++ b/benchmark/tree_of_thought_deep/bench_sglang.py
@@ -103,6 +103,7 @@ def tree_search(s, question, num_branches):
 
 def main(args):
     lines = read_jsonl(args.data_path)
+    lines = list(lines)
 
     # Construct prompts
     num_branches = 2
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 654c944ca..f1055dcb4 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -226,8 +226,10 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-        self.output_ids = []  # Each decode stage's output ids
-        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds
 
@@ -265,6 +267,7 @@ class Req:
 
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None
@@ -280,10 +283,10 @@ class Req:
         self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
-        self.input_token_logprobs_val = None
-        self.input_token_logprobs_idx = None
-        self.input_top_logprobs_val = None
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
             self.output_token_logprobs_val = []
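Note on the schedule_batch.py comments above: `fill_ids` is always `origin_input_ids + output_ids`, and with chunked prefill both it and `prefix_indices` are recomputed on every chunk. A minimal standalone sketch of that relation (the helper name and `prefix_len` parameter are illustrative, not sglang API):

    from typing import List

    def tokens_to_prefill(
        origin_input_ids: List[int], output_ids: List[int], prefix_len: int
    ) -> List[int]:
        # fill_ids = origin_input_ids + output_ids. Only the tokens past the
        # cached prefix are run through prefill; with chunked prefill the
        # prefix grows chunk by chunk, so this slice shrinks on every pass
        # until the whole request has been prefilled.
        fill_ids = origin_input_ids + output_ids
        return fill_ids[prefix_len:]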
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62dc22ef2..9bebbcd92 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -22,8 +22,9 @@ import time
 import warnings
 from collections import deque
 from concurrent import futures
+from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import psutil
 import setproctitle
@@ -102,6 +103,19 @@
 logger = logging.getLogger(__name__)
 
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
+@dataclass
+class GenerationBatchResult:
+    logits_output: LogitsProcessorOutput
+    next_token_ids: List[int]
+    bid: int
+
+
+@dataclass
+class EmbeddingBatchResult:
+    embeddings: torch.Tensor
+    bid: int
+
+
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
@@ -411,16 +425,16 @@ class Scheduler:
         self.watchdog_last_time = time.time()
 
         while True:
+            current = time.time()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
-                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
                         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = time.time()
-            time.sleep(self.watchdog_timeout / 2)
-
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
         # Wait sometimes so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
@@ -1018,7 +1032,9 @@ class Scheduler:
         batch.prepare_for_decode()
         return batch
 
-    def run_batch(self, batch: ScheduleBatch):
+    def run_batch(
+        self, batch: ScheduleBatch
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
 
@@ -1040,15 +1056,26 @@ class Scheduler:
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"
             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids, model_worker_batch.bid
+
+            ret = GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                bid=model_worker_batch.bid,
+            )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = embeddings, model_worker_batch.bid
+            ret = EmbeddingBatchResult(
+                embeddings=embeddings, bid=model_worker_batch.bid
+            )
         return ret
 
-    def process_batch_result(self, batch: ScheduleBatch, result):
+    def process_batch_result(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
@@ -1057,17 +1084,29 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result[-1])
+                self.tp_worker.resolve_batch_result(result.bid)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
-    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         skip_stream_req = None
 
         if self.is_generation:
-            logits_output, next_token_ids, bid = result
+            (
+                logits_output,
+                next_token_ids,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.bid,
+            )
 
             if self.enable_overlap:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1125,7 +1164,7 @@ class Scheduler:
                     batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
-            embeddings, bid = result
+            embeddings, bid = result.embeddings, result.bid
             embeddings = embeddings.tolist()
 
             # Check finish conditions
@@ -1149,8 +1188,16 @@ class Scheduler:
 
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
-    def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids, bid = result
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
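The point of `GenerationBatchResult`/`EmbeddingBatchResult` above is that `process_batch_result` can read `result.bid` by name instead of `result[-1]`, which only worked because `bid` happened to come last in both tuple shapes. A minimal sketch of the pattern with placeholder field types (not the real `LogitsProcessorOutput` or `torch.Tensor`):

    from dataclasses import dataclass
    from typing import Any, List, Union

    @dataclass
    class GenerationBatchResult:
        logits_output: Any  # stand-in for LogitsProcessorOutput
        next_token_ids: List[int]
        bid: int

    @dataclass
    class EmbeddingBatchResult:
        embeddings: Any  # stand-in for torch.Tensor
        bid: int

    def batch_id(result: Union[GenerationBatchResult, EmbeddingBatchResult]) -> int:
        # Named access is shape-independent: the generation tuple had three
        # fields and the embedding tuple two, so index -1 was the only
        # position that meant the same thing in both cases.
        return result.bid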
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index d238c9195..6e7e69cff 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
+    get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -532,7 +533,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
-                self.model_config.get_num_kv_heads(self.tp_size)
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
@@ -626,7 +627,7 @@ class ModelRunner:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
@@ -637,7 +638,7 @@ class ModelRunner:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 4fbe20846..2ed9006c0 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None
 
 
 class FunctionResponse(BaseModel):
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index af0f2a08d..a41b94301 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -842,7 +842,6 @@ class Engine:
 
         generator = ret.body_iterator
 
         async def generator_wrapper():
-            offset = 0
 
             while True:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index df98bdeb3..9d4ec90e9 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -239,8 +239,8 @@ class ServerArgs:
 
         # Others
         if self.enable_dp_attention:
-            assert self.tp_size % self.dp_size == 0
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             logger.warning(
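A note on the `get_attention_tp_size()` substitution in model_runner.py: with DP attention enabled, attention layers are tensor-parallel only within each data-parallel replica, so KV-cache sizing must divide the KV-head count by the attention TP group size (roughly `tp_size // dp_size`) instead of the full `tp_size`. A back-of-the-envelope sketch of the per-token cell size under that assumption (hypothetical helper, not the sglang API):

    def kv_cache_cell_size(
        total_kv_heads: int,
        head_dim: int,
        num_layers: int,
        tp_size: int,
        dp_size: int,
        dtype_bytes: int = 2,  # e.g. fp16/bf16 KV cache
    ) -> int:
        # Bytes of KV cache that one token occupies on one rank. Each
        # replica's attention TP group has tp_size // dp_size ranks, and
        # each rank holds that fraction of the KV heads (at least one).
        # The factor 2 accounts for storing both K and V.
        attn_tp_size = tp_size // dp_size
        heads_per_rank = max(1, total_kv_heads // attn_tp_size)
        return heads_per_rank * head_dim * num_layers * 2 * dtype_bytes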