diff --git a/python/pyproject.toml b/python/pyproject.toml
index a1a0dcdf5..babb31d4e 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -33,6 +33,7 @@ runtime_common = [
     "prometheus-client>=0.20.0",
     "psutil",
     "pydantic",
+    "pynvml",
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index fba806010..6f7ed8523 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -14,7 +14,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
-import triton
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        self.page_size = model_runner.page_size
 
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-            triton.next_power_of_2(num_seqs),
-            triton.next_power_of_2(self.speculative_num_steps),
-            triton.next_power_of_2(bs),
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )
 
         assert forward_batch.spec_info is not None
@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
         )
 
         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 45e1d4be2..dfdfef662 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )
         return req_pool_indices
 
-    def alloc_token_slots(self, num_tokens: int):
+    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
         if self.token_to_kv_pool_allocator.available_size() < num_tokens:
             if self.tree_cache is not None:
                 self.tree_cache.evict(num_tokens)
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
         if out_cache_loc is None:
             phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.tree_cache.pretty_print()
             raise RuntimeError(error_msg)
 
-        return out_cache_loc
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def alloc_paged_token_slots_extend(
         self,
@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()
@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
             )
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
             prefix_lens, seq_lens, last_loc, extend_num_tokens
         )
@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()
@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.tree_cache.evict(
                 len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
             )
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 63acbe2aa..b1431bf40 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1110,7 +1110,7 @@ class Scheduler(
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected! "
+                "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
@@ -1121,7 +1121,7 @@ class Scheduler(
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
-                "Memory pool leak detected!"
+                "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index e425c7927..ae50df114 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))
 
+    def backup_state(self):
+        return self.free_slots
+
+    def restore_state(self, free_slots):
+        self.free_slots = free_slots
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
diff --git a/python/sglang/srt/mem_cache/paged_allocator.py b/python/sglang/srt/mem_cache/paged_allocator.py
index 6b7402e06..37c4a8e5e 100644
--- a/python/sglang/srt/mem_cache/paged_allocator.py
+++ b/python/sglang/srt/mem_cache/paged_allocator.py
@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
             next_power_of_2(extend_num_tokens),
         )
 
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         merged_value = self.ret_values.item()
         num_new_pages = merged_value >> 32
         if num_new_pages > len(self.free_pages):
@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
             self.page_size,
         )
 
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         num_new_pages = self.ret_values.item()
         if num_new_pages > len(self.free_pages):
             return None
@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
         else:
             self.free_group.append(free_index)
 
+        if self.debug_mode:
+            assert len(torch.unique(self.free_pages)) == len(self.free_pages)
+
     def free_group_begin(self):
         self.is_not_in_free_group = False
         self.free_group = []
@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))
 
+    def backup_state(self):
+        return self.free_pages
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index f5ac35d40..1039ddaae 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+                capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
             else:
-                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
             # Since speculative decoding requires more cuda graph memory, we
             # capture less.
-            capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
+            capture_bs = (
+                list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+            )
 
     if _is_hip:
-        capture_bs += [i * 8 for i in range(21, 33)]
+        capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 2c7f53065..5d2e2bc82 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f8d5512cb..67a155f8f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -15,6 +15,7 @@
 
 import argparse
 import dataclasses
+import json
 import logging
 import os
 import random
@@ -132,9 +133,9 @@ class ServerArgs:
     # Speculative decoding
     speculative_algorithm: Optional[str] = None
     speculative_draft_model_path: Optional[str] = None
-    speculative_num_steps: int = 5
-    speculative_eagle_topk: int = 4
-    speculative_num_draft_tokens: int = 8
+    speculative_num_steps: Optional[int] = None
+    speculative_eagle_topk: Optional[int] = None
+    speculative_num_draft_tokens: Optional[int] = None
     speculative_accept_threshold_single: float = 1.0
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
@@ -313,12 +314,29 @@ class ServerArgs:
             or self.speculative_algorithm == "EAGLE3"
         ):
             if self.max_running_requests is None:
-                self.max_running_requests = 32
+                self.max_running_requests = 48
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )
+
+            # Auto choose parameters
+            if self.speculative_num_steps is None:
+                assert (
+                    self.speculative_eagle_topk is None
+                    and self.speculative_num_draft_tokens is None
+                )
+                (
+                    self.speculative_num_steps,
+                    self.speculative_eagle_topk,
+                    self.speculative_num_draft_tokens,
+                ) = auto_choose_speculative_params(self)
+
+            if self.page_size > 1 and self.speculative_eagle_topk > 1:
+                self.speculative_eagle_topk = 1
+                logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
             # assert self.speculative_num_steps < self.speculative_num_draft_tokens
@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):
 
     def __call__(self, parser, namespace, values, option_string=None):
         raise ValueError(self.help)
+
+
+def auto_choose_speculative_params(self: ServerArgs):
+    """
+    Automatically choose the parameters for speculative decoding.
+
+    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
+    """
+    if self.decrypted_config_file:
+        config_path = self.decrypted_config_file
+    else:
+        config_path = os.path.join(self.model_path, "config.json")
+    if not os.path.exists(config_path):
+        raise ValueError(f"{config_path} is not found.")
+
+    config = json.load(open(config_path))
+
+    arch = config.get("architectures", ["Unknown"])[0]
+
+    if arch in ["LlamaForCausalLM"]:
+        # The default value for llama
+        return (5, 4, 8)
+    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+        # The default value for deepseek
+        return (5, 4, 8)
+    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
+        return (5, 4, 8)
+    else:
+        # The default value for all other models
+        return (5, 4, 8)
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 0c5b9b4a5..19fa1807c 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -10,11 +11,15 @@ import triton.language as tl
 
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
 
 if is_cuda_available():
     from sgl_kernel import (
@@ -34,6 +39,9 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+
+
 @dataclass
 class EagleDraftInput:
     # The inputs for decode
@@ -93,7 +101,7 @@ class EagleDraftInput:
             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
             new_verified_id,
-            triton.next_power_of_2(speculative_num_steps + 1),
+            next_power_of_2(speculative_num_steps + 1),
         )
 
         batch.seq_lens_sum = sum(seq_lens_cpu)
@@ -225,18 +233,34 @@ class EagleVerifyInput:
             CaptureHiddenMode.FULL,
         )
 
-    def prepare_for_verify(self, batch: ScheduleBatch):
+    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
-        batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+
+        if page_size == 1:
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            end_offset = batch.seq_lens + self.draft_token_num
+        else:
+            prefix_lens = batch.seq_lens
+            end_offset = prefix_lens + self.draft_token_num
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+            )
+            self.last_loc = last_loc
+
         bs = batch.batch_size()
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
+            end_offset,
             batch.out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
+            next_power_of_2(bs),
         )
 
     def generate_attn_arg_prefill(
@@ -282,6 +306,7 @@
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -305,6 +330,7 @@
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
 
+        # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
             linear_penalty = torch.zeros(
@@ -317,6 +343,7 @@
                 torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )
 
+        # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
@@ -378,13 +405,24 @@
             deterministic=True,
         )
 
+        if SIMULATE_ACC_LEN:
+            # Do simulation
+            accept_index = _generate_simulated_accept_index(
+                accept_index=accept_index,
+                predict=predict,  # mutable
+                accept_length=accept_length,  # mutable
+                simulate_acc_len=SIMULATE_ACC_LEN,
+                bs=bs,
+                spec_steps=self.spec_steps,
+            )
+
         new_accept_index = []
         unfinished_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
 
-        # iterate every accepted token and check if req has finished after append the token
+        # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
             new_accept_index_ = []
@@ -407,13 +445,28 @@
                 unfinished_index.append(i)
             req.spec_verify_ct += 1
 
+        if has_finished:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        # Free the KV cache for unaccepted tokens
+        accept_index = accept_index[accept_index != -1]
+        verified_id = predict[accept_index]
+        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+        evict_mask[accept_index] = False
+
+        if page_size != 1:
+            align_evict_mask_to_page_size[len(batch.seq_lens),](
+                batch.seq_lens,
+                evict_mask,
+                page_size,
+                self.draft_token_num,
+                next_power_of_2(self.draft_token_num),
+            )
+
+        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+
+        # Construct EagleVerifyOutput
         if not has_finished:
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             batch.out_cache_loc = batch.out_cache_loc[accept_index]
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
@@ -422,7 +475,7 @@
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc,
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -443,13 +496,6 @@
                 accepeted_indices=accept_index,
             )
         else:
-            accept_length = (accept_index != -1).sum(dim=1) - 1
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
                 batch.req_to_token_pool.req_to_token,
@@ -457,7 +503,7 @@
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc[accept_index],
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -465,20 +511,21 @@
             draft_input = EagleDraftInput()
             if len(new_accept_index) > 0:
                 new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
                 draft_input.hidden_states = batch.spec_info.hidden_states[
                     new_accept_index
                 ]
                 draft_input.verified_id = predict[new_accept_index]
-                draft_input.accept_length = accept_length[unfinished_index]
                 draft_input.accept_length_cpu = [
                     accept_length_cpu[i] for i in unfinished_index
                 ]
+                draft_input.accept_length = accept_length[unfinished_index_device]
                 if has_finished:
                     draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                        unfinished_index
+                        unfinished_index_device
                     ]
                     draft_input.req_pool_indices_for_draft_extend = (
-                        batch.req_pool_indices[unfinished_index]
+                        batch.req_pool_indices[unfinished_index_device]
                     )
                 else:
                     draft_input.seq_lens_for_draft_extend = batch.seq_lens
@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 32
     pid = tl.program_id(axis=0)
     kv_start = tl.load(seq_lens + pid)
-    kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+
+    if page_size == 1 or topk == 1:
+        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+    else:
+        prefix_len = tl.load(seq_lens + pid)
+        last_page_len = prefix_len % page_size
+        num_new_page = (
+            last_page_len + speculative_num_steps + page_size - 1
+        ) // page_size
+        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
 
     num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
     for i in range(num_loop):
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
         tl.store(kv_indptr + zid, base + zid * iters)
 
 
+@triton.jit
+def align_evict_mask_to_page_size(
+    seq_lens,
+    evict_mask,
+    page_size: tl.constexpr,
+    num_draft_tokens: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    t_range = tl.arange(0, BLOCK_SIZE)
+
+    bid = tl.program_id(axis=0)
+    seq_len = tl.load(seq_lens + bid)
+    io_mask = t_range < num_draft_tokens
+    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+
+    num_trues = tl.sum(mask_row)
+    num_false = num_draft_tokens - num_trues
+
+    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
+    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
+        tl.store(evict_mask + bid * num_draft_tokens + i, False)
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def _generate_simulated_accept_index(
+    accept_index,
+    predict,
+    accept_length,
+    simulate_acc_len,
+    bs,
+    spec_steps,
+):
+    simulate_acc_len_float = float(simulate_acc_len)
+    simulated_values = torch.normal(
+        mean=simulate_acc_len_float,
+        std=1.0,
+        size=(1,),
+        device="cpu",
+    )
+    # clamp simulated values to be between 1 and self.spec_steps
+    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
+    simulate_acc_len = int(simulated_values.round().item())
+
+    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
+    sim_accept_index = torch.full(
+        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+    )
+    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
+        simulate_acc_len, device=accept_index.device
+    )
+    accept_length.fill_(simulate_acc_len - 1)
+    predict.fill_(100)  # some legit token id
+    return sim_accept_index
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 598e5ac4a..234e13209 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
         self.gpu_id = gpu_id
         self.device = server_args.device
         self.target_worker = target_worker
+        self.page_size = server_args.page_size
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
-                spec_info, to_free_cache_loc = self.draft(batch)
+                spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )
 
-            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
-            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
-
             # If it is None, it means all requests are finished
             if batch.spec_info.verified_id is not None:
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Allocate cache locations
-        out_cache_loc = batch.alloc_token_slots(
-            num_seqs * self.topk * self.speculative_num_steps
-        )
+        if self.page_size == 1:
+            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
+                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+            )
+        else:
+            if self.topk == 1:
+                prefix_lens = batch.seq_lens
+                seq_lens = prefix_lens + self.speculative_num_steps
+                extend_num_tokens = num_seqs * self.speculative_num_steps
+            else:
+                # In this case, the last partial page needs to be duplicated.
+                # KV cache layout in batch.req_to_token_pool.req_to_token:
+                #
+                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
+                #    prefix     top-k = 0    top-k = 1    top-k = 2
+                #
+                # "-" means prefix tokens
+                # "x" means speculative draft tokens
+                # "." means padded tokens
+
+                # TODO: fuse these ops
+                prefix_lens = batch.seq_lens
+                last_page_lens = prefix_lens % self.page_size
+                num_new_pages = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens = (
+                    prefix_lens // self.page_size * self.page_size
+                    + num_new_pages * (self.page_size * self.topk)
+                )
+                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
+                raise NotImplementedError(
+                    "page_size > 1 and top_k > 1 are not supported."
+                )
+                # TODO: Support page_size > 1 and top_k > 1
+                # 1. Duplicate the KV cache in the last partial page for all top-k segments
+                # 2. Modify generate_draft_decode_kv_indices accordingly
+
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            out_cache_loc, token_to_kv_pool_state_backup = (
+                batch.alloc_paged_token_slots_extend(
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    extend_num_tokens,
+                    backup_state=True,
+                )
+            )
+
         assign_draft_cache_locs[(num_seqs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
+            self.page_size,
         )
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
         ret = EagleVerifyInput.create(
             spec_info.verified_id,
             score_list,
@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
         )
-        return ret, out_cache_loc
+        return ret
 
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
         return score_list, token_list, parents_list
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
-        spec_info.prepare_for_verify(batch)
+        spec_info.prepare_for_verify(batch, self.page_size)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
-            batch, logits_output, self.token_to_kv_pool_allocator
+            batch,
+            logits_output,
+            self.token_to_kv_pool_allocator,
+            self.page_size,
         )
 
         # Post process based on verified outputs.
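
Note on the draft-phase allocation above: the draft step now allocates KV slots with backup_state=True, runs the draft forward passes, and then calls restore_state(...) so that verify() can re-allocate only the tokens it actually keeps. The following is a minimal, self-contained sketch of that backup/restore pattern; ToyAllocator and its sizes are illustrative stand-ins, not the real TokenToKVPoolAllocator / PagedTokenToKVPoolAllocator from the patch.

    import torch

    class ToyAllocator:
        """Illustrative free-list allocator mirroring backup_state()/restore_state()."""

        def __init__(self, num_slots: int):
            # Slot 0 is reserved for padded tokens, as in the real pools.
            self.free_slots = torch.arange(1, num_slots + 1, dtype=torch.int64)

        def alloc(self, n: int):
            if len(self.free_slots) < n:
                return None
            out = self.free_slots[:n]
            self.free_slots = self.free_slots[n:]
            return out

        def backup_state(self):
            # Returning the tensor is enough because alloc() rebinds free_slots
            # instead of mutating it in place, matching the real allocators.
            return self.free_slots

        def restore_state(self, free_slots):
            self.free_slots = free_slots

    allocator = ToyAllocator(num_slots=16)
    backup = allocator.backup_state()   # snapshot before speculative allocation
    draft_slots = allocator.alloc(8)    # slots used only by the draft forward passes
    # ... draft forward steps would run here ...
    allocator.restore_state(backup)     # roll back; verify() allocates the kept tokens
    assert allocator.free_slots.numel() == 16
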
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 6eb5c663c..4c432ae70 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -76,11 +76,14 @@ def is_in_ci():
 
 
 if is_in_ci():
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
 else:
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
+DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
 
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -1009,6 +1012,9 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
 
 
 class CustomTestCase(unittest.TestCase):
+    pass
+
+    """
     def _callTestMethod(self, method):
         max_retry = int(
             os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
@@ -1017,3 +1023,4 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
+    """
diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index 3f397aea7..5672e380c 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -18,7 +18,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
 pip install sgl-kernel==0.0.5.post4 --force-reinstall
 pip install torch_memory_saver
 
-pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm
+pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm torchaudio
 
 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 625d5518e..628e3946d 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -26,7 +26,7 @@ suites = {
         TestFile("test_abort.py", 51),
         TestFile("test_block_int8.py", 22),
         TestFile("test_chunked_prefill.py", 336),
-        TestFile("test_eagle_infer.py", 447),
+        TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
         TestFile("test_fp8_kernel.py", 2),
         TestFile("test_embedding_openai_server.py", 36),
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index e29b097c8..48d0d908b 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase):
         print(f"{metrics=}")
         self.assertGreater(metrics["accuracy"], 0.20)
 
-        server_info = requests.get(self.base_url + "/get_server_info")
-        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+        server_info = requests.get(self.base_url + "/get_server_info").json()
+        avg_spec_accept_length = server_info["avg_spec_accept_length"]
         print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 3.5)
+
+        speculative_eagle_topk = server_info["speculative_eagle_topk"]
+
+        if speculative_eagle_topk == 1:
+            self.assertGreater(avg_spec_accept_length, 2.5)
+        else:
+            self.assertGreater(avg_spec_accept_length, 3.5)
 
         # Wait a little bit so that the memory check happens.
         time.sleep(4)
@@ -535,5 +541,36 @@
         )
 
 
+class TestEAGLEServerPageSize(TestEAGLEServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                1,
+                "--speculative-num-draft-tokens",
+                6,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--page-size",
+                4,
+            ],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py
index 4fa3eb58f..81de94d3a 100644
--- a/test/srt/test_mla_flashinfer.py
+++ b/test/srt/test_mla_flashinfer.py
@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.60)
 
         server_info = requests.get(self.base_url + "/get_server_info")
+        print(f"{server_info=}")
         avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
         print(f"{avg_spec_accept_length=}")
         self.assertGreater(avg_spec_accept_length, 2.5)
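
For readability, here is a rough torch-level restatement of what the align_evict_mask_to_page_size kernel in eagle_utils.py above does when page_size > 1: the page that will contain the last token kept after verification is still partially in use, so any draft slots that fall inside that page are excluded from eviction. This reference version is a sketch for illustration only (plain Python loop over a 2-D mask view), not the Triton kernel shipped in the patch.

    import torch

    def align_evict_mask_to_page_size_ref(seq_lens, evict_mask, page_size, num_draft_tokens):
        # evict_mask: [bs, num_draft_tokens] bool, True = "free this draft slot".
        for bid in range(seq_lens.numel()):
            seq_len = int(seq_lens[bid])
            num_kept = int((~evict_mask[bid]).sum())  # "num_false" in the kernel
            # Offset, within the draft range, of the page holding the last kept token.
            start = (seq_len + num_kept - 1) // page_size * page_size - seq_len
            lo, hi = max(start, 0), min(start + page_size, num_draft_tokens)
            evict_mask[bid, lo:hi] = False            # keep the whole partial page
        return evict_mask

    # Example: seq_len=6, page_size=4, 4 draft tokens, only the first draft token accepted.
    seq_lens = torch.tensor([6])
    evict_mask = torch.tensor([[False, True, True, True]])
    align_evict_mask_to_page_size_ref(seq_lens, evict_mask, page_size=4, num_draft_tokens=4)
    # The last kept token sits at position 6 and its page spans positions 4-7, so the
    # draft slots at positions 6 and 7 (draft indices 0 and 1) stay allocated:
    # evict_mask is now [[False, False, True, True]]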