From 9c6ba2484f03be55aa3732b5be50ad062a2d8720 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 30 Dec 2024 04:51:38 -0800 Subject: [PATCH] Refactor logprob computation to return the real logprob used in sampling (#2664) --- python/sglang/srt/layers/logits_processor.py | 374 ++++++++---------- python/sglang/srt/layers/sampler.py | 76 +++- python/sglang/srt/managers/scheduler.py | 24 +- .../srt/managers/tp_worker_overlap_thread.py | 7 +- .../srt/model_executor/cuda_graph_runner.py | 33 +- .../sglang/srt/model_executor/model_runner.py | 41 +- .../srt/sampling/sampling_batch_info.py | 23 ++ .../test_srt_endpoint_with_penalizers.py | 4 +- test/srt/test_srt_endpoint.py | 35 ++ 9 files changed, 305 insertions(+), 312 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 31820d37a..ac3a4a4cc 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -17,6 +17,8 @@ import dataclasses from typing import List, Optional, Union import torch +import triton +import triton.language as tl from torch import nn from vllm.distributed import ( get_tensor_model_parallel_world_size, @@ -33,76 +35,77 @@ from sglang.srt.model_executor.forward_batch_info import ( @dataclasses.dataclass class LogitsProcessorOutput: + ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor - # The logprobs of the next tokens. shape: [#seq, vocab_size] - next_token_logprobs: torch.Tensor = None + # Used by speculative decoding (EAGLE) + # The last hidden layers + hidden_states: Optional[torch.Tensor] = None + ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler. + # The logprobs of the next tokens. shape: [#seq] + next_token_logprobs: Optional[torch.Tensor] = None + # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k] + next_token_top_logprobs_val: Optional[List] = None + next_token_top_logprobs_idx: Optional[List] = None + + ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only. # The normlaized logprobs of prompts. shape: [#seq] normalized_prompt_logprobs: torch.Tensor = None - # The logprobs of input tokens. shape: [#token, vocab_size] + # The logprobs of input tokens. shape: [#token] input_token_logprobs: torch.Tensor = None - - # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] + # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] input_top_logprobs_val: List = None input_top_logprobs_idx: List = None - # The logprob and id of the top-k tokens in output positions. 
shape [#seq, #token, k] - output_top_logprobs_val: List = None - output_top_logprobs_idx: List = None - - # Used by speculative decoding (EAGLE) - # The output of transformer layers - hidden_states: Optional[torch.Tensor] = None @dataclasses.dataclass class LogitsMetadata: forward_mode: ForwardMode - top_logprobs_nums: Optional[List[int]] - - return_logprob: bool = False - return_top_logprob: bool = False + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL + extend_return_logprob: bool = False + extend_return_top_logprob: bool = False extend_seq_lens: Optional[torch.Tensor] = None extend_seq_lens_cpu: Optional[List[int]] = None - extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_pruned_lens_cpu: Optional[List[int]] = None - - capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL + top_logprobs_nums: Optional[List[int]] = None @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): - extend_logprob_pruned_lens_cpu = None - - if forward_batch.return_logprob: - return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) - if forward_batch.forward_mode.is_extend(): - extend_logprob_pruned_lens_cpu = [ - extend_len - start_len - for extend_len, start_len in zip( - forward_batch.extend_seq_lens_cpu, - forward_batch.extend_logprob_start_lens_cpu, - ) - ] - else: - return_top_logprob = False - if forward_batch.spec_info: capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode else: capture_hidden_mode = CaptureHiddenMode.NULL + if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob: + extend_return_logprob = True + extend_return_top_logprob = any( + x > 0 for x in forward_batch.top_logprobs_nums + ) + extend_logprob_pruned_lens_cpu = [ + extend_len - start_len + for extend_len, start_len in zip( + forward_batch.extend_seq_lens_cpu, + forward_batch.extend_logprob_start_lens_cpu, + ) + ] + else: + extend_return_logprob = extend_return_top_logprob = ( + extend_logprob_pruned_lens_cpu + ) = False + return cls( forward_mode=forward_batch.forward_mode, - top_logprobs_nums=forward_batch.top_logprobs_nums, - return_logprob=forward_batch.return_logprob, - return_top_logprob=return_top_logprob, + capture_hidden_mode=capture_hidden_mode, + extend_return_logprob=extend_return_logprob, + extend_return_top_logprob=extend_return_top_logprob, extend_seq_lens=forward_batch.extend_seq_lens, extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, - capture_hidden_mode=capture_hidden_mode, + top_logprobs_nums=forward_batch.top_logprobs_nums, ) @@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module): ): if isinstance(logits_metadata, ForwardBatch): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) - assert isinstance(logits_metadata, LogitsMetadata) # Get the last hidden states and last logits for the next token prediction if ( @@ -142,18 +144,10 @@ class LogitsProcessor(nn.Module): last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] + # Compute logits last_logits = self._get_logits(last_hidden, lm_head) - if self.do_tensor_parallel_all_gather: - last_logits = tensor_model_parallel_all_gather(last_logits) - last_logits = last_logits[:, : self.config.vocab_size].float() - - if self.final_logit_softcapping: - last_logits.div_(self.final_logit_softcapping) - torch.tanh(last_logits, out=last_logits) - 
last_logits.mul_(self.final_logit_softcapping) - - # Return only last_logits if logprob is not requested - if not logits_metadata.return_logprob: + if not logits_metadata.extend_return_logprob: + # Decode mode or extend mode without return_logprob. return LogitsProcessorOutput( next_token_logits=last_logits, hidden_states=( @@ -167,95 +161,60 @@ class LogitsProcessor(nn.Module): ), ) else: - last_logprobs = self.compute_temp_top_p_normalized_logprobs( - last_logits, logits_metadata + # Slice the requested tokens to compute logprob + pt, pruned_states, pruned_input_ids = 0, [], [] + for start_len, extend_len in zip( + logits_metadata.extend_logprob_start_lens_cpu, + logits_metadata.extend_seq_lens_cpu, + ): + pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) + pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) + pt += extend_len + + # Compute the logits of all required tokens + pruned_states = torch.cat(pruned_states) + del hidden_states + input_token_logits = self._get_logits(pruned_states, lm_head) + del pruned_states + + # Normalize the logprob w/o temperature, top-p + input_logprobs = input_token_logits + input_logprobs = self.compute_temp_top_p_normalized_logprobs( + input_logprobs, logits_metadata ) - if logits_metadata.forward_mode.is_decode(): - if logits_metadata.return_top_logprob: - output_top_logprobs_val, output_top_logprobs_idx = ( - self.get_top_logprobs(last_logprobs, logits_metadata)[2:4] - ) - else: - output_top_logprobs_val = output_top_logprobs_idx = None - return LogitsProcessorOutput( - next_token_logits=last_logits, - next_token_logprobs=last_logprobs, - output_top_logprobs_val=output_top_logprobs_val, - output_top_logprobs_idx=output_top_logprobs_idx, - ) + # Get the logprob of top-k tokens + if logits_metadata.extend_return_top_logprob: + ( + input_top_logprobs_val, + input_top_logprobs_idx, + ) = self.get_top_logprobs(input_logprobs, logits_metadata) else: - # Slice the requested tokens to compute logprob - pt, states, pruned_input_ids = 0, [], [] - for start_len, extend_len in zip( - logits_metadata.extend_logprob_start_lens_cpu, - logits_metadata.extend_seq_lens_cpu, - ): - states.append(hidden_states[pt + start_len : pt + extend_len]) - pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) - pt += extend_len + input_top_logprobs_val = input_top_logprobs_idx = None - # Compute the logits and logprobs for all required tokens - states = torch.cat(states, dim=0) - all_logits = self._get_logits(states, lm_head) - if self.do_tensor_parallel_all_gather: - all_logits = tensor_model_parallel_all_gather(all_logits) + # Compute the normalized logprobs for the requested tokens. + # Note that we pad a zero at the end for easy batching. + input_token_logprobs = input_logprobs[ + torch.arange(input_logprobs.shape[0], device="cuda"), + torch.cat( + [ + torch.cat(pruned_input_ids)[1:], + torch.tensor([0], device="cuda"), + ] + ), + ] + normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( + input_token_logprobs, + logits_metadata, + ) - # The LM head's weights may be zero-padded for parallelism. Remove any - # extra logits that this padding may have produced. 
- all_logits = all_logits[:, : self.config.vocab_size].float() - - if self.final_logit_softcapping: - all_logits.div_(self.final_logit_softcapping) - torch.tanh(all_logits, out=all_logits) - all_logits.mul_(self.final_logit_softcapping) - - all_logprobs = all_logits - del all_logits, hidden_states - - all_logprobs = self.compute_temp_top_p_normalized_logprobs( - all_logprobs, logits_metadata - ) - - # Get the logprob of top-k tokens - if logits_metadata.return_top_logprob: - ( - input_top_logprobs_val, - input_top_logprobs_idx, - output_top_logprobs_val, - output_top_logprobs_idx, - ) = self.get_top_logprobs(all_logprobs, logits_metadata) - else: - input_top_logprobs_val = input_top_logprobs_idx = ( - output_top_logprobs_val - ) = output_top_logprobs_idx = None - - # Compute the normalized logprobs for the requested tokens. - # Note that we pad a zero at the end for easy batching. - input_token_logprobs = all_logprobs[ - torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat( - [ - torch.cat(pruned_input_ids)[1:], - torch.tensor([0], device="cuda"), - ] - ), - ] - normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - input_token_logprobs, - logits_metadata, - ) - - return LogitsProcessorOutput( - next_token_logits=last_logits, - next_token_logprobs=last_logprobs, - normalized_prompt_logprobs=normalized_prompt_logprobs, - input_token_logprobs=input_token_logprobs, - input_top_logprobs_val=input_top_logprobs_val, - input_top_logprobs_idx=input_top_logprobs_idx, - output_top_logprobs_val=output_top_logprobs_val, - output_top_logprobs_idx=output_top_logprobs_idx, - ) + return LogitsProcessorOutput( + next_token_logits=last_logits, + normalized_prompt_logprobs=normalized_prompt_logprobs, + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, + ) def _get_logits( self, @@ -269,9 +228,19 @@ class LogitsProcessor(nn.Module): # GGUF models logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) - # Optional scaling factor if self.logit_scale is not None: - logits.mul_(self.logit_scale) # In-place multiply + logits.mul_(self.logit_scale) + + if self.do_tensor_parallel_all_gather: + logits = tensor_model_parallel_all_gather(logits) + + # Compute the normalized logprobs for the requested tokens. + # Note that we pad a zero at the end for easy batching. 
+ logits = logits[:, : self.config.vocab_size].float() + + if self.final_logit_softcapping: + fused_softcap(logits, self.final_logit_softcapping) + return logits @staticmethod @@ -302,90 +271,73 @@ class LogitsProcessor(nn.Module): values = ret.values.tolist() indices = ret.indices.tolist() - if logits_metadata.forward_mode.is_decode(): - output_top_logprobs_val = [] - output_top_logprobs_idx = [] - for i, k in enumerate(logits_metadata.top_logprobs_nums): - output_top_logprobs_val.append(values[i][:k]) - output_top_logprobs_idx.append(indices[i][:k]) - return None, None, output_top_logprobs_val, output_top_logprobs_idx - else: - input_top_logprobs_val, input_top_logprobs_idx = [], [] - output_top_logprobs_val, output_top_logprobs_idx = [], [] + input_top_logprobs_val, input_top_logprobs_idx = [], [] - pt = 0 - for k, pruned_len in zip( - logits_metadata.top_logprobs_nums, - logits_metadata.extend_logprob_pruned_lens_cpu, - ): - if pruned_len <= 0: - input_top_logprobs_val.append([]) - input_top_logprobs_idx.append([]) - output_top_logprobs_val.append([]) - output_top_logprobs_idx.append([]) - continue + pt = 0 + for k, pruned_len in zip( + logits_metadata.top_logprobs_nums, + logits_metadata.extend_logprob_pruned_lens_cpu, + ): + if pruned_len <= 0: + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) + continue - input_top_logprobs_val.append( - [values[pt + j][:k] for j in range(pruned_len - 1)] - ) - input_top_logprobs_idx.append( - [indices[pt + j][:k] for j in range(pruned_len - 1)] - ) - output_top_logprobs_val.append( - list( - values[pt + pruned_len - 1][:k], - ) - ) - output_top_logprobs_idx.append( - list( - indices[pt + pruned_len - 1][:k], - ) - ) - pt += pruned_len - - return ( - input_top_logprobs_val, - input_top_logprobs_idx, - output_top_logprobs_val, - output_top_logprobs_idx, + input_top_logprobs_val.append( + [values[pt + j][:k] for j in range(pruned_len - 1)] ) + input_top_logprobs_idx.append( + [indices[pt + j][:k] for j in range(pruned_len - 1)] + ) + pt += pruned_len + + return input_top_logprobs_val, input_top_logprobs_idx @staticmethod def compute_temp_top_p_normalized_logprobs( last_logits: torch.Tensor, logits_metadata: LogitsMetadata ) -> torch.Tensor: + # TODO: Implement the temp and top-p normalization return torch.nn.functional.log_softmax(last_logits, dim=-1) -def test(): - all_logprobs = torch.tensor( - # s s s - [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]], - dtype=torch.float32, - device="cuda", +@triton.jit +def fused_softcap_kernel( + full_logits_ptr, + softcapping_value, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load values + x = tl.load(full_logits_ptr + offsets, mask=mask) + + # Perform operations in-place + x = x / softcapping_value + + # Manual tanh implementation using exp + exp2x = tl.exp(2 * x) + x = (exp2x - 1) / (exp2x + 1) + + x = x * softcapping_value + + # Store result + tl.store(full_logits_ptr + offsets, x, mask=mask) + + +def fused_softcap(full_logits, final_logit_softcapping): + n_elements = full_logits.numel() + BLOCK_SIZE = 1024 + grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1) + + fused_softcap_kernel[grid]( + full_logits_ptr=full_logits, + softcapping_value=final_logit_softcapping, + n_elements=n_elements, + BLOCK_SIZE=BLOCK_SIZE, ) - seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda") - input_ids = 
torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda") - - token_logprobs = all_logprobs[ - torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), - ] - logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32) - - len_cumsum = torch.cumsum(seq_lens, dim=0) - start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0) - end = start + seq_lens - 2 - start.clamp_(min=0, max=token_logprobs.shape[0] - 1) - end.clamp_(min=0, max=token_logprobs.shape[0] - 1) - sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start] - - # assert logprobs == [2, _, 2, 4, _] - print("token logprobs", token_logprobs) - print("start", start) - print("end", end) - print("sum_logp", sum_logp) - - -if __name__ == "__main__": - test() + return full_logits diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 8a4dcc8ae..bed770e39 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,5 +1,5 @@ import logging -from typing import Union +from typing import List import torch from torch import nn @@ -28,13 +28,12 @@ class Sampler(nn.Module): def forward( self, - logits: Union[torch.Tensor, LogitsProcessorOutput], + logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo, + return_logprob: bool, + top_logprobs_nums: List[int], ): - if isinstance(logits, LogitsProcessorOutput): - logits = logits.next_token_logits - - logits = logits.contiguous() + logits = logits_output.next_token_logits if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") @@ -47,6 +46,8 @@ class Sampler(nn.Module): if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling batch_next_token_ids = torch.argmax(logits, -1) + if return_logprob: + logprobs = torch.nn.functional.log_softmax(logits, dim=-1) else: # Post process logits logits.div_(sampling_info.temperatures) @@ -54,6 +55,12 @@ class Sampler(nn.Module): del logits if global_server_args_dict["sampling_backend"] == "flashinfer": + if return_logprob: + # NOTE: the top_p_renorm_prob from flashinfer has numerical problems + logprobs = torch.log( + top_p_normalize_probs_torch(probs, sampling_info.top_ps) + ) + max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device @@ -76,6 +83,7 @@ class Sampler(nn.Module): if self.use_nan_detectioin and not torch.all(success): logger.warning("Detected errors during sampling!") batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + elif global_server_args_dict["sampling_backend"] == "pytorch": # A slower fallback implementation with torch native operations. 
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( @@ -85,12 +93,31 @@ class Sampler(nn.Module): sampling_info.min_ps, sampling_info.need_min_p_sampling, ) + if return_logprob: + logprobs = torch.log( + top_p_normalize_probs_torch(probs, sampling_info.top_ps) + ) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" ) - return batch_next_token_ids.to(torch.int32) + batch_next_token_ids = batch_next_token_ids.to(torch.int32) + + # Attach logprobs to logits_output (in-place modification) + if return_logprob: + if any(x > 0 for x in top_logprobs_nums): + ( + logits_output.next_token_top_logprobs_val, + logits_output.next_token_top_logprobs_idx, + ) = get_top_logprobs(logprobs, top_logprobs_nums) + + logits_output.next_token_logprobs = logprobs[ + torch.arange(len(batch_next_token_ids), device=sampling_info.device), + batch_next_token_ids, + ] + + return batch_next_token_ids def top_k_top_p_min_p_sampling_from_probs_torch( @@ -120,20 +147,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch( return batch_next_token_ids -def top_p_normalize_probs( +def top_p_normalize_probs_torch( probs: torch.Tensor, top_ps: torch.Tensor, ): - if global_server_args_dict["sampling_backend"] == "flashinfer": - return top_p_renorm_prob(probs, top_ps) - elif global_server_args_dict["sampling_backend"] == "pytorch": - # See also top_k_top_p_min_p_sampling_from_probs_torch - probs_sort, probs_idx = probs.sort(dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) - else: - raise ValueError( - f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" - ) + # See also top_k_top_p_min_p_sampling_from_probs_torch + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) + + +def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): + max_k = max(top_logprobs_nums) + ret = logprobs.topk(max_k, dim=1) + values = ret.values.tolist() + indices = ret.indices.tolist() + + output_top_logprobs_val = [] + output_top_logprobs_idx = [] + for i, k in enumerate(top_logprobs_nums): + output_top_logprobs_val.append(values[i][:k]) + output_top_logprobs_idx.append(indices[i][:k]) + return output_top_logprobs_val, output_top_logprobs_idx diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3abaa1a6c..4bf41aaf3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -974,12 +974,10 @@ class Scheduler: logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) else: # Move next_token_ids and logprobs to cpu + next_token_ids = next_token_ids.tolist() if batch.return_logprob: logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=self.device), - next_token_ids, - ].tolist() + logits_output.next_token_logprobs.tolist() ) logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.tolist() @@ -987,7 +985,6 @@ class Scheduler: logits_output.normalized_prompt_logprobs = ( 
logits_output.normalized_prompt_logprobs.tolist() ) - next_token_ids = next_token_ids.tolist() # Check finish conditions logprob_pt = 0 @@ -1064,13 +1061,9 @@ class Scheduler: logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) next_token_logprobs = logits_output.next_token_logprobs else: - # Move next_token_ids and logprobs to cpu - if batch.return_logprob: - next_token_logprobs = logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=self.device), - next_token_ids, - ].tolist() next_token_ids = next_token_ids.tolist() + if batch.return_logprob: + next_token_logprobs = logits_output.next_token_logprobs.tolist() self.token_to_kv_pool.free_group_begin() @@ -1095,10 +1088,10 @@ class Scheduler: req.output_token_logprobs_idx.append(next_token_id) if req.top_logprobs_num > 0: req.output_top_logprobs_val.append( - logits_output.output_top_logprobs_val[i] + logits_output.next_token_top_logprobs_val[i] ) req.output_top_logprobs_idx.append( - logits_output.output_top_logprobs_idx[i] + logits_output.next_token_top_logprobs_idx[i] ) if req.grammar is not None: @@ -1200,8 +1193,9 @@ class Scheduler: req.output_top_logprobs_idx.extend( output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] ) - req.output_top_logprobs_val.append(output.output_top_logprobs_val[i]) - req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i]) + + req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) return num_input_logprobs diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 4600bf99a..4c98c6be2 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -144,10 +144,9 @@ class TpModelWorkerClient: # Copy results to the CPU if model_worker_batch.return_logprob: - logits_output.next_token_logprobs = logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=self.device), - next_token_ids, - ].to("cpu", non_blocking=True) + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs.to("cpu", non_blocking=True) + ) if logits_output.input_token_logprobs is not None: logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.to("cpu", non_blocking=True) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 65418c703..a9c2c3781 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -392,34 +392,7 @@ class CudaGraphRunner: self.graphs[bs].replay() next_token_logits = self.output_buffers[bs][:raw_bs] - # Extract logprobs - if forward_batch.return_logprob: - logits_metadata = LogitsMetadata( - forward_mode=ForwardMode.DECODE, - top_logprobs_nums=forward_batch.top_logprobs_nums, - ) - next_token_logprobs = ( - LogitsProcessor.compute_temp_top_p_normalized_logprobs( - next_token_logits, logits_metadata - ) - ) - logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits, - next_token_logprobs=next_token_logprobs, - ) - return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) - if return_top_logprob: - ( - logits_output.output_top_logprobs_val, - logits_output.output_top_logprobs_idx, - ) = LogitsProcessor.get_top_logprobs( - next_token_logprobs, logits_metadata - )[ - 2:4 - ] - else: - 
logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits, - ) - + logits_output = LogitsProcessorOutput( + next_token_logits=next_token_logits, + ) return logits_output diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7cb7d5da7..67640947a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -36,7 +36,7 @@ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.sampler import Sampler +from sglang.srt.layers.sampler import Sampler, get_top_logprobs from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -48,7 +48,6 @@ from sglang.srt.mem_cache.memory_pool import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model -from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( enable_show_time_cost, @@ -192,7 +191,8 @@ class ModelRunner: torch.get_device_module(self.device).set_device(self.gpu_id) if self.device == "cuda": backend = "nccl" - # ToDO(liangan1):Just use gloo to bypass the initilization fail + + # TODO(liangan1):Just use gloo to bypass the initilization fail # Need to use xccl for xpu backend in the future elif self.device == "xpu": backend = "gloo" @@ -704,6 +704,7 @@ class ModelRunner: def sample( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ) -> torch.Tensor: + # Apply logit bias sampling_info = forward_batch.sampling_info if sampling_info.sampling_info_done: # Overlap mode: the function update_regex_vocab_mask was executed @@ -714,35 +715,17 @@ class ModelRunner: # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass. sampling_info.update_regex_vocab_mask() sampling_info.update_penalties() - logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info) + sampling_info.apply_logits_bias(logits_output.next_token_logits) - # Sample the next tokens. - next_token_ids = self.sampler(logits, sampling_info) + # Sample the next tokens + next_token_ids = self.sampler( + logits_output, + sampling_info, + forward_batch.return_logprob, + forward_batch.top_logprobs_nums, + ) return next_token_ids - def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): - # Apply logit_bias - if sampling_info.logit_bias is not None: - logits.add_(sampling_info.logit_bias) - - # min-token, presence, frequency - if sampling_info.linear_penalties is not None: - logits.add_(sampling_info.linear_penalties) - - # repetition - if sampling_info.scaling_penalties is not None: - logits = torch.where( - logits > 0, - logits / sampling_info.scaling_penalties, - logits * sampling_info.scaling_penalties, - ) - - # Apply regex vocab_mask - if sampling_info.vocab_mask is not None: - sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask) - - return logits - @property def model_is_mrope(self) -> bool: """Detect if the model has "mrope" rope_scaling type. 
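
Editor's note (not part of the patch): after this refactor, the decode-time logprobs returned to the scheduler come from the same distribution the token was sampled from — penalties and logit bias are applied via SamplingBatchInfo.apply_logits_bias, the Sampler divides by temperature and renormalizes under top-p, and the logprob of the sampled id is gathered directly. Below is a minimal standalone sketch of that path; sampled_logprobs_sketch is a hypothetical helper, not code from the patch, and it mirrors top_p_normalize_probs_torch and the final gather in Sampler.forward.

import torch

def sampled_logprobs_sketch(logits, temperatures, top_ps, sampled_ids):
    # logits: (batch, vocab); temperatures: (batch, 1); top_ps, sampled_ids: (batch,)
    probs = torch.softmax(logits / temperatures, dim=-1)
    # Same renormalization as top_p_normalize_probs_torch: zero out tokens whose
    # preceding cumulative mass already exceeds top_p, then rescale to sum to 1.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    renorm = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
    # Gather the logprob of the token that was actually sampled, mirroring the
    # final indexing in Sampler.forward.
    logprobs = torch.log(renorm)
    return logprobs[torch.arange(len(sampled_ids)), sampled_ids]
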
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index a64a84a62..5d4aaa41b 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -232,3 +232,26 @@ class SamplingBatchInfo: self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) + + def apply_logits_bias(self, logits: torch.Tensor): + # Apply logit_bias + if self.logit_bias is not None: + logits.add_(self.logit_bias) + + # min-token, presence, frequency + if self.linear_penalties is not None: + logits.add_(self.linear_penalties) + + # repetition + if self.scaling_penalties is not None: + logits = torch.where( + logits > 0, + logits / self.scaling_penalties, + logits * self.scaling_penalties, + ) + + # Apply regex vocab_mask + if self.vocab_mask is not None: + self.apply_mask(logits=logits, vocab_mask=self.vocab_mask) + + return logits diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 0eccb3407..5245905f7 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -6,7 +6,7 @@ import requests from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index aff1d4a78..e974821b1 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase): max_diff = np.max(diff) self.assertLess(max_diff, 0.25) + def test_logprob_grammar(self): + prompts = "Question: Is Paris the Capital of France? Answer:" + allowed_tokens = [" Yes", " No"] + + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompts, + "sampling_params": { + "temperature": 1.0, + "max_new_tokens": 1, + "regex": "( Yes| No)", + }, + "return_logprob": True, + "top_logprobs_num": 5, + "return_text_in_logprobs": True, + }, + ) + response_json = response.json() + output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0] + print(f"{output_top_logprobs=}") + + # Parse results + # This is becaues the grammar constraint allows all prefix tokens + logprobs = [None] * 2 + for i in range(len(output_top_logprobs)): + try: + idx = allowed_tokens.index(output_top_logprobs[i][2]) + except ValueError: + # Not found + continue + logprobs[idx] = output_top_logprobs[i][0] + + self.assertTrue(all(x is not None for x in logprobs)) + def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") response_json = response.json()
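
Editor's note (not part of the patch): test_logprob_grammar above exercises the new next_token_top_logprobs_* values through the /generate endpoint. The following is a hedged client-side sketch of the same call, assuming a locally running sglang server at http://127.0.0.1:30000; each returned entry is a (logprob, token_id, token_text) triple, as the test's indexing implies.

import requests

base_url = "http://127.0.0.1:30000"  # assumed local sglang server address
response = requests.post(
    base_url + "/generate",
    json={
        "text": "Question: Is Paris the Capital of France? Answer:",
        "sampling_params": {"temperature": 1.0, "max_new_tokens": 1},
        "return_logprob": True,
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
    },
)
meta_info = response.json()["meta_info"]
# Top-k logprobs of the first generated token; with this patch they reflect the
# real (temperature- and top-p-adjusted) distribution used for sampling.
for logprob, token_id, token_text in meta_info["output_top_logprobs"][0]:
    print(f"{token_text!r} (id={token_id}): logprob={logprob:.4f}")
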