diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 710250b2e..c9ea1105b 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -92,7 +92,7 @@ jobs:
           python3 test_data_parallelism.py
 
       - name: Evaluate MLA accuracy (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 10
         run: |
           cd test/srt
           python3 test_mla.py
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index f6d43b671..bd4b41983 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -146,7 +146,10 @@ class LogitsProcessor(nn.Module):
 
         # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if not logits_metadata.extend_return_logprob:
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index b78d205f2..64cc9de10 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
+if TYPE_CHECKING:
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -565,9 +571,13 @@ class ScheduleBatch:
     # Has grammar
     has_grammar: bool = False
 
-    # device
+    # Device
     device: str = "cuda"
 
+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+
     @classmethod
     def init_new(
         cls,
@@ -577,6 +587,7 @@ class ScheduleBatch:
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
+        speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
     ):
         return cls(
             reqs=reqs,
@@ -589,6 +600,7 @@ class ScheduleBatch:
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
+            spec_algorithm=speculative_algorithm,
         )
 
     def batch_size(self):
@@ -1103,6 +1115,9 @@ class ScheduleBatch:
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
 
+        if self.spec_info:
+            self.spec_info.merge_batch(other.spec_info)
+
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
@@ -1144,6 +1159,8 @@ class ScheduleBatch:
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            spec_algorithm=self.spec_algorithm,
+            spec_info=self.spec_info,
         )
 
     def copy(self):
@@ -1214,6 +1231,10 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index ce284c04a..c8e14a746 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -150,12 +150,18 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
+        skip_sample: bool = False,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
             launch_done.set()
-        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
+        if skip_sample:
+            next_token_ids = None
+        else:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
         return logits_output, next_token_ids
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index c4e717ba2..d04560581 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -375,9 +375,7 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
-        # In normal decoding case, raw_bs == raw_num_token
-        # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
-        raw_num_token = forward_batch.input_ids.numel()
+        raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
         if self.enable_dp_attention:
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 140b0fe7c..2b5ee0919 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -96,7 +96,11 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND
 
     def is_cuda_graph(self):
-        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
+        return (
+            self == ForwardMode.DECODE
+            or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.IDLE
+        )
 
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
@@ -161,15 +165,15 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
-    # Speculative decoding
-    spec_info: SpecInfo = None
-    spec_algorithm: SpeculativeAlgorithm = None
-
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -258,6 +262,8 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            spec_algorithm=batch.spec_algorithm,
+            spec_info=batch.spec_info,
             input_embeds=batch.input_embeds,
         )
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 58a6a6a82..3a0c102f5 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -108,14 +108,6 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -125,6 +117,21 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"
 
+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
@@ -602,43 +609,6 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The type of heavy channels in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -678,6 +648,75 @@ class ServerArgs:
             help="Choose the backend for grammar-guided decoding.",
         )
 
+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from the draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of tokens sampled from the draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of tokens sampled from the draft model at each step in EAGLE-2.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The threshold for sparse decoding in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
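
Usage sketch (illustrative only, not part of the patch): parsing the new speculative-decoding flags added above into a ServerArgs instance. It assumes the usual ServerArgs.add_cli_args / ServerArgs.from_cli_args helpers (from_cli_args is not shown in this diff) and uses placeholder model paths.

    import argparse

    from sglang.srt.server_args import ServerArgs

    # Build a parser with all ServerArgs CLI flags, including the new speculative ones.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args(
        [
            "--model-path", "meta-llama/Llama-2-7b-chat-hf",  # placeholder target model
            "--speculative-algorithm", "EAGLE",
            "--speculative-draft-model-path", "path/to/eagle-draft",  # placeholder draft weights
            "--speculative-num-steps", "5",
            "--speculative-eagle-topk", "8",
            "--speculative-num-draft-tokens", "64",
        ]
    )
    # Convert the parsed namespace back into the ServerArgs dataclass.
    server_args = ServerArgs.from_cli_args(args)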