diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 8336af2aa..49ca46a99 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner: req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens.sum(), + seq_lens_sum=seq_lens.sum().item(), return_logprob=False, positions=positions, spec_algorithm=self.model_runner.spec_algorithm, @@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner: forward_batch.positions = self.positions[:num_tokens] # Special handle for seq_len_cpu used when flashinfer mla is used - if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs): + if forward_batch.seq_lens_cpu is not None and bs != raw_bs: self.seq_lens_cpu.fill_(1) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs] diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 6894d4df2..d6313ca40 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner: req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens.sum(), + seq_lens_sum=seq_lens.sum().item(), return_logprob=False, positions=positions, spec_algorithm=self.model_runner.spec_algorithm, diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 23fa1a2ed..389eb7442 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,8 +1,10 @@ from __future__ import annotations +import logging import os +import time from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -12,6 +14,7 @@ import triton.language as tl from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ( Req, ScheduleBatch, @@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import ( ) from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode -from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2 @@ -34,15 +36,15 @@ if is_cuda(): elif is_hip(): from sgl_kernel import verify_tree_greedy -if TYPE_CHECKING: - from sglang.srt.managers.schedule_batch import ScheduleBatch - -import logging logger = logging.getLogger(__name__) +# Simulate acceptance length for benchmarking purposes SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN") +SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial") + +TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly @dataclass @@ -84,9 +86,9 @@ class EagleDraftInput: self, batch: ScheduleBatch, speculative_num_steps: int, + context_length: int, pad_input: bool = False, ): - assert len(self.verified_id) == len(batch.out_cache_loc) accept_length_cpu = batch.spec_info.accept_length_cpu batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.extend_num_tokens = sum(batch.extend_lens) @@ -112,49 +114,49 @@ class EagleDraftInput: batch.input_ids = self.verified_id self.verified_id = new_verified_id - if pad_input: - batch_size = sum(not req.finished() for req in batch.reqs) - # Total constant input length after padding - static_len = speculative_num_steps + 1 - # Total size after padding - padded_input_size = batch_size * static_len + if not pad_input: + return - padded_len = padded_input_size - batch.input_ids.shape[0] - if padded_len > 0: - new_input_ids = torch.nn.functional.pad( - batch.input_ids, (0, padded_len), value=0 - ) - position_padding = torch.arange( - padded_len, device=self.positions.device - ) - new_positions = torch.cat([self.positions, position_padding]) + batch_size = sum(not req.finished() for req in batch.reqs) + # Total constant input length after padding + static_len = speculative_num_steps + 1 + # Total size after padding + padded_input_size = batch_size * static_len - # need dummy hidden states for the padded positions - hidden_states_dim = self.hidden_states.shape[-1] - new_hidden_states = torch.cat( - [ - self.hidden_states, - torch.zeros( - (padded_len, hidden_states_dim), - dtype=self.hidden_states.dtype, - device=self.hidden_states.device, - ), - ], - dim=0, - ) + padded_len = padded_input_size - batch.input_ids.shape[0] + if padded_len > 0: + new_input_ids = torch.nn.functional.pad( + batch.input_ids, (0, padded_len), value=0 + ) + position_padding = torch.arange(padded_len, device=self.positions.device) + new_positions = torch.cat([self.positions, position_padding]) - # allocate KV cache location for the padded tokens - padded_cache_loc = torch.zeros( - padded_len, - dtype=batch.out_cache_loc.dtype, - device=batch.out_cache_loc.device, - ) - new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc]) + # need dummy hidden states for the padded positions + hidden_states_dim = self.hidden_states.shape[-1] + new_hidden_states = torch.cat( + [ + self.hidden_states, + torch.zeros( + (padded_len, hidden_states_dim), + dtype=self.hidden_states.dtype, + device=self.hidden_states.device, + ), + ], + dim=0, + ) - batch.input_ids = new_input_ids - self.hidden_states = new_hidden_states - self.positions = new_positions - batch.out_cache_loc = new_out_cache_loc + # allocate KV cache location for the padded tokens + padded_cache_loc = torch.zeros( + padded_len, + dtype=batch.out_cache_loc.dtype, + device=batch.out_cache_loc.device, + ) + new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc]) + + batch.input_ids = new_input_ids + self.hidden_states = new_hidden_states + self.positions = new_positions + batch.out_cache_loc = new_out_cache_loc def generate_attn_arg_prefill( self, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index a9193150b..af54a8619 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker): batch.spec_info.prepare_extend_after_decode( batch, self.speculative_num_steps, + self.server_args.context_length, pad_input=self.cuda_graph_runner_for_draft_extend is not None, ) batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 7662ca333..1caf447ec 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -23,6 +23,7 @@ from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_ci, popen_launch_server, run_logprob_check, ) @@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer): ) +@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") class TestEAGLEDraftExtend(CustomTestCase): @classmethod def setUpClass(cls): @@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend): cls.accept_len_threshold = 1.50 +@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend): @classmethod def setUpClass(cls): @@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend): cls.accept_len_threshold = 1.50 +@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend): @classmethod def setUpClass(cls):