diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 90f981c57..3384f5efa 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None

     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)

-        logits = logits[:, : self.config.vocab_size].float()
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()

         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index e5a8cc872..39120f2cd 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -375,6 +375,11 @@ class CudaGraphRunner:
                 dtype=torch.bool,
                 device="cuda",
             )
+            self.next_token_logits_buffer = torch.zeros(
+                (self.max_num_token, self.model_runner.model_config.vocab_size),
+                dtype=torch.float,
+                device="cuda",
+            )

         # Capture
         try:
@@ -520,6 +525,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens

         # pipeline parallelism
@@ -582,6 +588,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 92eeb6860..5f8cc0ed4 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -189,6 +189,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None

     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index f4ed31d7e..08d823a0b 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -142,6 +142,22 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.global_num_tokens_for_logprob_gpu = None
             self.gathered_buffer = None

+        if hasattr(
+            self.model_runner.model_config.hf_config, "draft_vocab_size"
+        ):  # llama_eagle
+            vocab_size = self.model_runner.model_config.hf_config.draft_vocab_size
+        elif hasattr(
+            self.model_runner.model_config.hf_config, "hot_vocab_size"
+        ):  # llama_eagle3
+            vocab_size = self.model_runner.model_config.hf_config.hot_vocab_size
+        else:
+            vocab_size = self.model_runner.model_config.vocab_size
+
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_bs, vocab_size),
+            dtype=torch.float,
+        )
+
         # Capture
         try:
             with model_capture_mode():
@@ -189,6 +205,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
+        next_token_logits_buffer = self.next_token_logits_buffer[:bs]

         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -238,6 +255,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 2babeefc1..eea5623dc 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -564,6 +564,7 @@ class TboForwardBatchPreparer:
                 mm_inputs=None,
                 top_logprobs_nums=None,
                 token_ids_logprobs=None,
+                next_token_logits_buffer=None,
             )
         )
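Note on the pattern above: CUDA graph replay expects the tensors touched inside the captured region to keep stable storage addresses, and `logits.float()` allocates a fresh tensor on every call. The diff therefore preallocates a worst-case `(max_num_token, vocab_size)` float32 buffer once, slices it to the active token count each step, and `copy_`s the cast logits into that slice. The snippet below is a minimal, self-contained sketch of that buffer-reuse idea under toy sizes; `process_logits` and the constants are illustrative stand-ins, not sglang's API, and plain CPU tensors are used so it runs anywhere.

from typing import Optional

import torch

MAX_NUM_TOKEN, VOCAB_SIZE = 8, 32

# Allocated once, outside the hot path, so its storage address never changes
# (in the real runner this lives on the GPU and outlives graph capture).
next_token_logits_buffer = torch.zeros((MAX_NUM_TOKEN, VOCAB_SIZE), dtype=torch.float)


def process_logits(
    raw_logits: torch.Tensor, buffer: Optional[torch.Tensor]
) -> torch.Tensor:
    if buffer is not None:
        # In-place cast-and-copy into the preallocated slice: no new
        # allocation, and the output address stays stable across calls.
        assert buffer.dtype == torch.float
        buffer.copy_(raw_logits[:, :VOCAB_SIZE])
        return buffer
    # Fallback (no buffer supplied): .float() allocates a fresh tensor.
    return raw_logits[:, :VOCAB_SIZE].float()


num_tokens = 4
raw = torch.randn(num_tokens, VOCAB_SIZE, dtype=torch.bfloat16)
out = process_logits(raw, next_token_logits_buffer[:num_tokens])
assert out.data_ptr() == next_token_logits_buffer.data_ptr()  # same storage reused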