diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index f613f44e5..ca932bef0 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -1,5 +1,8 @@ """Logits processing.""" +import dataclasses +from typing import List + import torch from torch import nn from vllm.distributed import ( @@ -10,6 +13,24 @@ from vllm.distributed import ( from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata +@dataclasses.dataclass +class LogitProcessorOutput: + # The logits of the next tokens. shape: [#seq, vocab_size] + next_token_logits: torch.Tensor + # The logprobs of the next tokens. shape: [#seq, vocab_size] + next_token_logprobs: torch.Tensor + + # The normalized logprobs of prompts. shape: [#seq] + normalized_prompt_logprobs: torch.Tensor + # The logprobs of prefill tokens. shape: [#token, vocab_size] + prefill_token_logprobs: torch.Tensor + + # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id) + prefill_top_logprobs: List + # The logprob and id of the top-k tokens in decode positions. 
shape [#seq, #token, k] of Tuple(logprob, token_id) + decode_top_logprobs: List + + class LogitsProcessor(nn.Module): def __init__(self, config): super().__init__() @@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module): return normalized_prompt_logprobs def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata): + # TODO: vectorize the code below if input_metadata.forward_mode == ForwardMode.DECODE: decode_top_logprobs = [] for i in range(all_logprobs.shape[0]): @@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module): else: prefill_top_logprobs, decode_top_logprobs = [], [] pt = 0 - # NOTE: the GPU-CPU overhead can be reduced extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist() for i, extend_seq_len in enumerate(extend_seq_lens_cpu): if extend_seq_len == 0: @@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module): return prefill_top_logprobs, decode_top_logprobs def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata): - # Get last index for next token prediction, except for DECODE mode. 
- last_index = None - if input_metadata.forward_mode != ForwardMode.DECODE: + # Get the last hidden states and last logits for the next token prediction + if input_metadata.forward_mode == ForwardMode.DECODE: + last_index = None + last_hidden = hidden_states + else: last_index = ( torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long) - 1 ) - - # Get the last hidden states and last logits - if input_metadata.forward_mode == ForwardMode.DECODE: - last_hidden = hidden_states - else: last_hidden = hidden_states[last_index] last_logits = torch.matmul(last_hidden, weight.T) @@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module): # Return only last_logits if logprob is not requested if not input_metadata.return_logprob: - hidden_states = None - return last_logits, (None, None, None, None, None) + return LogitProcessorOutput( + next_token_logits=last_logits, + next_token_logprobs=None, + normalized_prompt_logprobs=None, + prefill_token_logprobs=None, + prefill_top_logprobs=None, + decode_top_logprobs=None, + ) else: # When logprob is requested, compute the logits for all tokens. 
if input_metadata.forward_mode == ForwardMode.DECODE: @@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module): del all_logits all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) + # Get the logprob of top-k tokens return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) if return_top_logprob: prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( @@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module): prefill_top_logprobs = decode_top_logprobs = None if input_metadata.forward_mode == ForwardMode.DECODE: - last_logprobs = all_logprobs - return last_logits, ( - None, - None, - None, - decode_top_logprobs, - last_logprobs, + return LogitProcessorOutput( + next_token_logits=last_logits, + next_token_logprobs=all_logprobs, + normalized_prompt_logprobs=None, + prefill_token_logprobs=None, + prefill_top_logprobs=None, + decode_top_logprobs=decode_top_logprobs, ) else: - # Compute the logprobs for the last token of each request. last_logprobs = all_logprobs[last_index] # Compute the logprobs and normalized logprobs for the prefill tokens. 
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module): normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( prefill_token_logprobs, input_metadata ) - return last_logits, ( - prefill_token_logprobs, - normalized_prompt_logprobs, - prefill_top_logprobs, - decode_top_logprobs, - last_logprobs, + + return LogitProcessorOutput( + next_token_logits=last_logits, + next_token_logprobs=last_logprobs, + normalized_prompt_logprobs=normalized_prompt_logprobs, + prefill_token_logprobs=prefill_token_logprobs, + prefill_top_logprobs=prefill_top_logprobs, + decode_top_logprobs=decode_top_logprobs, ) diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index c49d4b01e..82ddb6e48 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -441,33 +441,25 @@ class ModelTpServer: self.model_config.vocab_size, self.int_token_logit_bias ) + # Forward and sample the next tokens if batch.extend_num_tokens != 0: - # Forward - logits, ( - prefill_token_logprobs, - normalized_prompt_logprobs, - prefill_top_logprobs, - decode_top_logprobs, - last_logprobs, - ) = self.model_runner.forward(batch, ForwardMode.EXTEND) - if prefill_token_logprobs is not None: - prefill_token_logprobs = prefill_token_logprobs.tolist() - normalized_prompt_logprobs = normalized_prompt_logprobs.tolist() + output = self.model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids, _ = batch.sample(output.next_token_logits) - next_token_ids, _ = batch.sample(logits) - - # Only transfer the selected logprobs of the next token to CPU to reduce overhead. 
- if last_logprobs is not None: - last_token_logprobs = last_logprobs[ - torch.arange(len(batch.reqs), device=next_token_ids.device), + # Move logprobs to cpu + if output.next_token_logprobs is not None: + output.next_token_logprobs = output.next_token_logprobs[ + torch.arange(len(next_token_ids), device=next_token_ids.device), next_token_ids, ].tolist() + output.prefill_token_logprobs = output.prefill_token_logprobs.tolist() + output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist() next_token_ids = next_token_ids.tolist() else: next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) - # Check finish condition + # Check finish conditions pt = 0 for i, req in enumerate(batch.reqs): req.completion_tokens_wo_jump_forward += 1 @@ -475,58 +467,60 @@ class ModelTpServer: req.check_finished() if req.return_logprob: - if req.normalized_prompt_logprob is None: - req.normalized_prompt_logprob = normalized_prompt_logprobs[i] - - if req.prefill_token_logprobs is None: - # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. 
- req.prefill_token_logprobs = list( - zip( - prefill_token_logprobs[pt : pt + req.extend_input_len - 1], - req.input_ids[-req.extend_input_len + 1 :], - ) - ) - if req.logprob_start_len == 0: - req.prefill_token_logprobs = [ - (None, req.input_ids[0]) - ] + req.prefill_token_logprobs - - if req.last_update_decode_tokens != 0: - req.decode_token_logprobs.extend( - list( - zip( - prefill_token_logprobs[ - pt - + req.extend_input_len - - req.last_update_decode_tokens : pt - + req.extend_input_len - - 1 - ], - req.input_ids[-req.last_update_decode_tokens + 1 :], - ) - ) - ) - - req.decode_token_logprobs.append( - (last_token_logprobs[i], next_token_ids[i]) - ) - - if req.top_logprobs_num > 0: - if req.prefill_top_logprobs is None: - req.prefill_top_logprobs = prefill_top_logprobs[i] - if req.logprob_start_len == 0: - req.prefill_top_logprobs = [None] + req.prefill_top_logprobs - - if req.last_update_decode_tokens != 0: - req.decode_top_logprobs.extend( - prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :] - ) - req.decode_top_logprobs.append(decode_top_logprobs[i]) - - pt += req.extend_input_len + self.add_logprob_return_values(i, req, pt, next_token_ids, output) + pt += req.extend_input_len self.handle_finished_requests(batch) + def add_logprob_return_values(self, i, req, pt, next_token_ids, output): + if req.normalized_prompt_logprob is None: + req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] + + if req.prefill_token_logprobs is None: + # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. 
+ req.prefill_token_logprobs = list( + zip( + output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1], + req.input_ids[-req.extend_input_len + 1 :], + ) + ) + if req.logprob_start_len == 0: + req.prefill_token_logprobs = [ + (None, req.input_ids[0]) + ] + req.prefill_token_logprobs + + if req.last_update_decode_tokens != 0: + req.decode_token_logprobs.extend( + list( + zip( + output.prefill_token_logprobs[ + pt + + req.extend_input_len + - req.last_update_decode_tokens : pt + + req.extend_input_len + - 1 + ], + req.input_ids[-req.last_update_decode_tokens + 1 :], + ) + ) + ) + + req.decode_token_logprobs.append( + (output.next_token_logprobs[i], next_token_ids[i]) + ) + + if req.top_logprobs_num > 0: + if req.prefill_top_logprobs is None: + req.prefill_top_logprobs = output.prefill_top_logprobs[i] + if req.logprob_start_len == 0: + req.prefill_top_logprobs = [None] + req.prefill_top_logprobs + + if req.last_update_decode_tokens != 0: + req.decode_top_logprobs.extend( + output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :] + ) + req.decode_top_logprobs.append(output.decode_top_logprobs[i]) + def cache_filled_batch(self, batch: Batch): req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): @@ -540,7 +534,7 @@ class ModelTpServer: req.prefix_indices, req.last_node = new_prefix_indices, new_last_node def forward_decode_batch(self, batch: Batch): - # check if decode out of memory + # Check if decode out of memory if not batch.check_decode_mem(): old_ratio = self.new_token_ratio self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0) @@ -559,9 +553,8 @@ class ModelTpServer: ) if not self.disable_regex_jump_forward: - # check for jump-forward + # Check for jump-forward jump_forward_reqs = batch.check_for_jump_forward(self.model_runner) - self.forward_queue.extend(jump_forward_reqs) if batch.is_empty(): return @@ -570,23 +563,19 @@ class ModelTpServer: self.decode_forward_ct = 
(self.decode_forward_ct + 1) % (1 << 30) batch.prepare_for_decode() - # Forward - logits, ( - _, - _, - _, - decode_top_logprobs, - last_logprobs, - ) = self.model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids, _ = batch.sample(logits) - next_token_ids = next_token_ids.tolist() + # Forward and sample the next tokens + output = self.model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids, _ = batch.sample(output.next_token_logits) - # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead. - if last_logprobs is not None: - new_token_logprobs = last_logprobs[ - torch.arange(len(batch.reqs)), next_token_ids + # Move logprobs to cpu + if output.next_token_logprobs is not None: + next_token_logprobs = output.next_token_logprobs[ + torch.arange(len(next_token_ids), device=next_token_ids.device), + next_token_ids, ].tolist() + next_token_ids = next_token_ids.tolist() + # Check finish condition for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): req.completion_tokens_wo_jump_forward += 1 @@ -594,10 +583,9 @@ class ModelTpServer: req.check_finished() if req.return_logprob: - req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id)) - - if req.top_logprobs_num > 0: - req.decode_top_logprobs.append(decode_top_logprobs[i]) + req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id)) + if req.top_logprobs_num > 0: + req.decode_top_logprobs.append(output.decode_top_logprobs[i]) self.handle_finished_requests(batch) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index cdcd33be1..9b23c799f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg try: requests.get(url + "/get_model_info", timeout=5, headers=headers) break - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: pass # Send a 
warmup request @@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg "text": "The capital city of France is", "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": 8, }, }, headers=headers, timeout=600, ) assert res.status_code == 200 - except Exception: + except Exception as e: if pipe_finish_writer is not None: pipe_finish_writer.send(get_exception_traceback()) print(f"Initialization failed. warmup error: {e}")