diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 3d8259249..5bb52f5bb 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -100,9 +100,154 @@ class LogitsProcessor(nn.Module): self.do_tensor_parallel_all_gather = ( not skip_all_gather and get_tensor_model_parallel_world_size() > 1 ) + self.final_logit_softcapping = getattr( + self.config, "final_logit_softcapping", None + ) - def _get_normalized_prompt_logprobs( + def forward( self, + input_ids, + hidden_states, + lm_head: VocabParallelEmbedding, + logits_metadata: Union[LogitsMetadata, ForwardBatch], + ): + if isinstance(logits_metadata, ForwardBatch): + logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) + assert isinstance(logits_metadata, LogitsMetadata) + + # Get the last hidden states and last logits for the next token prediction + if logits_metadata.forward_mode.is_decode(): + last_index = None + last_hidden = hidden_states + else: + last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 + last_hidden = hidden_states[last_index] + + last_logits = self._get_logits(last_hidden, lm_head) + if self.do_tensor_parallel_all_gather: + last_logits = tensor_model_parallel_all_gather(last_logits) + last_logits = last_logits[:, : self.config.vocab_size].float() + + if self.final_logit_softcapping: + last_logits.div_(self.final_logit_softcapping) + torch.tanh(last_logits, out=last_logits) + last_logits.mul_(self.final_logit_softcapping) + + # Return only last_logits if logprob is not requested + if not logits_metadata.return_logprob: + return LogitsProcessorOutput( + next_token_logits=last_logits, + ) + else: + last_logprobs = self.compute_temp_top_p_normalized_logprobs( + last_logits, logits_metadata + ) + + if logits_metadata.forward_mode.is_decode(): + if logits_metadata.return_top_logprob: + output_top_logprobs_val, output_top_logprobs_idx = ( + self.get_top_logprobs(last_logprobs, logits_metadata)[2:4] + ) + else: + output_top_logprobs_val = output_top_logprobs_idx = None + return LogitsProcessorOutput( + next_token_logits=last_logits, + next_token_logprobs=last_logprobs, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, + ) + else: + # Slice the requested tokens to compute logprob + pt, states, pruned_input_ids = 0, [], [] + for start_len, extend_len in zip( + logits_metadata.extend_logprob_start_lens_cpu, + logits_metadata.extend_seq_lens_cpu, + ): + states.append(hidden_states[pt + start_len : pt + extend_len]) + pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) + pt += extend_len + + # Compute the logits and logprobs for all required tokens + states = torch.cat(states, dim=0) + all_logits = self._get_logits(states, lm_head) + if self.do_tensor_parallel_all_gather: + all_logits = tensor_model_parallel_all_gather(all_logits) + + # The LM head's weights may be zero-padded for parallelism. Remove any + # extra logits that this padding may have produced. + all_logits = all_logits[:, : self.config.vocab_size].float() + + if self.final_logit_softcapping: + all_logits.div_(self.final_logit_softcapping) + torch.tanh(all_logits, out=all_logits) + all_logits.mul_(self.final_logit_softcapping) + + all_logprobs = all_logits + del all_logits, hidden_states + + all_logprobs = self.compute_temp_top_p_normalized_logprobs( + all_logprobs, logits_metadata + ) + + # Get the logprob of top-k tokens + if logits_metadata.return_top_logprob: + ( + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) = self.get_top_logprobs(all_logprobs, logits_metadata) + else: + input_top_logprobs_val = input_top_logprobs_idx = ( + output_top_logprobs_val + ) = output_top_logprobs_idx = None + + # Compute the normalized logprobs for the requested tokens. + # Note that we pad a zero at the end for easy batching. + input_token_logprobs = all_logprobs[ + torch.arange(all_logprobs.shape[0], device="cuda"), + torch.cat( + [ + torch.cat(pruned_input_ids)[1:], + torch.tensor([0], device="cuda"), + ] + ), + ] + normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( + input_token_logprobs, + logits_metadata, + ) + + return LogitsProcessorOutput( + next_token_logits=last_logits, + next_token_logprobs=last_logprobs, + normalized_prompt_logprobs=normalized_prompt_logprobs, + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, + ) + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if hasattr(lm_head, "weight"): + logits = torch.matmul(hidden_states, lm_head.weight.T) + else: + # GGUF models + logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) + + # Optional scaling factor + if self.logit_scale is not None: + logits.mul_(self.logit_scale) # In-place multiply + return logits + + @staticmethod + def _get_normalized_prompt_logprobs( input_token_logprobs: torch.Tensor, logits_metadata: LogitsMetadata, ): @@ -177,142 +322,11 @@ class LogitsProcessor(nn.Module): output_top_logprobs_idx, ) - def forward( - self, - input_ids, - hidden_states, - lm_head: VocabParallelEmbedding, - logits_metadata: Union[LogitsMetadata, ForwardBatch], - ): - if isinstance(logits_metadata, ForwardBatch): - logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) - assert isinstance(logits_metadata, LogitsMetadata) - - # Get the last hidden states and last logits for the next token prediction - if logits_metadata.forward_mode.is_decode(): - last_index = None - last_hidden = hidden_states - else: - last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 - last_hidden = hidden_states[last_index] - - last_logits = self._get_logits(last_hidden, lm_head) - if self.do_tensor_parallel_all_gather: - last_logits = tensor_model_parallel_all_gather(last_logits) - last_logits = last_logits[:, : self.config.vocab_size].float() - - if hasattr(self.config, "final_logit_softcapping"): - last_logits.div_(self.config.final_logit_softcapping) - torch.tanh(last_logits, out=last_logits) - last_logits.mul_(self.config.final_logit_softcapping) - - # Return only last_logits if logprob is not requested - if not logits_metadata.return_logprob: - return LogitsProcessorOutput( - next_token_logits=last_logits, - ) - else: - last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) - - if logits_metadata.forward_mode.is_decode(): - if logits_metadata.return_top_logprob: - output_top_logprobs_val, output_top_logprobs_idx = ( - self.get_top_logprobs(last_logprobs, logits_metadata)[2:4] - ) - else: - output_top_logprobs_val = output_top_logprobs_idx = None - return LogitsProcessorOutput( - next_token_logits=last_logits, - next_token_logprobs=last_logprobs, - output_top_logprobs_val=output_top_logprobs_val, - output_top_logprobs_idx=output_top_logprobs_idx, - ) - else: - # Slice the requested tokens to compute logprob - pt, states, pruned_input_ids = 0, [], [] - for start_len, extend_len in zip( - logits_metadata.extend_logprob_start_lens_cpu, - logits_metadata.extend_seq_lens_cpu, - ): - states.append(hidden_states[pt + start_len : pt + extend_len]) - pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) - pt += extend_len - - # Compute the logits and logprobs for all required tokens - states = torch.cat(states, dim=0) - all_logits = self._get_logits(states, lm_head) - if self.do_tensor_parallel_all_gather: - all_logits = tensor_model_parallel_all_gather(all_logits) - - # The LM head's weights may be zero-padded for parallelism. Remove any - # extra logits that this padding may have produced. - all_logits = all_logits[:, : self.config.vocab_size].float() - - if hasattr(self.config, "final_logit_softcapping"): - all_logits.div_(self.config.final_logit_softcapping) - torch.tanh(all_logits, out=all_logits) - all_logits.mul_(self.config.final_logit_softcapping) - - all_logprobs = all_logits - del all_logits, hidden_states - all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) - - # Get the logprob of top-k tokens - if logits_metadata.return_top_logprob: - ( - input_top_logprobs_val, - input_top_logprobs_idx, - output_top_logprobs_val, - output_top_logprobs_idx, - ) = self.get_top_logprobs(all_logprobs, logits_metadata) - else: - input_top_logprobs_val = input_top_logprobs_idx = ( - output_top_logprobs_val - ) = output_top_logprobs_idx = None - - # Compute the normalized logprobs for the requested tokens. - # Note that we pad a zero at the end for easy batching. - input_token_logprobs = all_logprobs[ - torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat( - [ - torch.cat(pruned_input_ids)[1:], - torch.tensor([0], device="cuda"), - ] - ), - ] - normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - input_token_logprobs, - logits_metadata, - ) - - return LogitsProcessorOutput( - next_token_logits=last_logits, - next_token_logprobs=last_logprobs, - normalized_prompt_logprobs=normalized_prompt_logprobs, - input_token_logprobs=input_token_logprobs, - input_top_logprobs_val=input_top_logprobs_val, - input_top_logprobs_idx=input_top_logprobs_idx, - output_top_logprobs_val=output_top_logprobs_val, - output_top_logprobs_idx=output_top_logprobs_idx, - ) - - def _get_logits( - self, - hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor] = None, + @staticmethod + def compute_temp_top_p_normalized_logprobs( + last_logits: torch.Tensor, logits_metadata: LogitsMetadata ) -> torch.Tensor: - if hasattr(lm_head, "weight"): - logits = torch.matmul(hidden_states, lm_head.weight.T) - else: - # GGUF models - logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) - - # Optional scaling factor, backported from vLLM 0.4 - if self.logit_scale is not None: - logits.mul_(self.logit_scale) # In-place multiply - return logits + return torch.nn.functional.log_softmax(last_logits, dim=-1) def test(): diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index b0dfda3e8..8a4dcc8ae 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -51,7 +51,6 @@ class Sampler(nn.Module): # Post process logits logits.div_(sampling_info.temperatures) probs = torch.softmax(logits, dim=-1) - logits = None del logits if global_server_args_dict["sampling_backend"] == "flashinfer": @@ -84,6 +83,7 @@ class Sampler(nn.Module): sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps, + sampling_info.need_min_p_sampling, ) else: raise ValueError( @@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch( top_ks: torch.Tensor, top_ps: torch.Tensor, min_ps: torch.Tensor, + need_min_p_sampling: bool, ): """A top-k, top-p and min-p sampling implementation with native pytorch operations.""" probs_sort, probs_idx = probs.sort(dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) - min_p_thresholds = probs_sort[:, 0] * min_ps - probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 probs_sort[ torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1) ] = 0.0 - probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 - probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 + + if need_min_p_sampling: + min_p_thresholds = probs_sort[:, 0] * min_ps + probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 + sampled_index = torch.multinomial(probs_sort, num_samples=1) # int32 range is enough to represent the token ids probs_idx = probs_idx.to(torch.int32) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) return batch_next_token_ids + + +def top_p_normalize_probs( + probs: torch.Tensor, + top_ps: torch.Tensor, +): + if global_server_args_dict["sampling_backend"] == "flashinfer": + return top_p_renorm_prob(probs, top_ps) + elif global_server_args_dict["sampling_backend"] == "pytorch": + # See also top_k_top_p_min_p_sampling_from_probs_torch + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) + else: + raise ValueError( + f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" + ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 65c029a16..59e31410c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1086,9 +1086,9 @@ class ScheduleBatch: self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.reqs.extend(other.reqs) - self.return_logprob = self.return_logprob or other.return_logprob - self.has_stream = self.has_stream or other.has_stream - self.has_grammar = self.has_grammar or other.has_grammar + self.return_logprob |= other.return_logprob + self.has_stream |= other.has_stream + self.has_grammar |= other.has_grammar def get_model_worker_batch(self): if self.forward_mode.is_decode() or self.forward_mode.is_idle(): @@ -1115,7 +1115,6 @@ class ScheduleBatch: seq_lens=self.seq_lens, out_cache_loc=self.out_cache_loc, seq_lens_sum=self.seq_lens_sum, - req_to_token_pool_records=self.req_to_token_pool.get_write_records(), return_logprob=self.return_logprob, top_logprobs_nums=self.top_logprobs_nums, global_num_tokens=self.global_num_tokens, @@ -1170,9 +1169,6 @@ class ModelWorkerBatch: # The sum of all sequence lengths seq_lens_sum: int - # The memory pool operation records - req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]] - # For logprob return_logprob: bool top_logprobs_nums: Optional[List[int]] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 93c3b250c..a113fb9c0 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -387,8 +387,14 @@ class CudaGraphRunner: # Extract logprobs if forward_batch.return_logprob: - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 + logits_metadata = LogitsMetadata( + forward_mode=ForwardMode.DECODE, + top_logprobs_nums=forward_batch.top_logprobs_nums, + ) + next_token_logprobs = ( + LogitsProcessor.compute_temp_top_p_normalized_logprobs( + next_token_logits, logits_metadata + ) ) logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, @@ -396,10 +402,6 @@ class CudaGraphRunner: ) return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) if return_top_logprob: - logits_metadata = LogitsMetadata( - forward_mode=ForwardMode.DECODE, - top_logprobs_nums=forward_batch.top_logprobs_nums, - ) ( logits_output.output_top_logprobs_val, logits_output.output_top_logprobs_idx, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 943233a4f..4c751809a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -698,11 +698,6 @@ class ServerArgs: action="store_true", help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.", ) - parser.add_argument( - "--disable-nan-detection", - action="store_true", - help="Disable the NaN detection for better performance.", - ) parser.add_argument( "--disable-overlap-schedule", action="store_true",