diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md
index 9c61af88f..146346a75 100644
--- a/benchmark/deepseek_v3/README.md
+++ b/benchmark/deepseek_v3/README.md
@@ -61,10 +61,10 @@ For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `
 ```bash
 # node 1
-python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
+python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code

 # node 2
-python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
+python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
 ```

 If you have two H100 nodes, the usage is similar to the aforementioned H20.

diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 8ab3e4524..63787addf 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -63,6 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
@@ -214,6 +215,7 @@ def extend(reqs, model_runner):
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,
+        spec_algorithm=SpeculativeAlgorithm.NONE,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py
index 2543e9de6..140755ff5 100644
--- a/python/sglang/srt/layers/attention/__init__.py
+++ b/python/sglang/srt/layers/attention/__init__.py
@@ -26,7 +26,7 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 1cd5c56cf..8b823cc5a 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -227,7 +227,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
@@ -243,9 +243,11 @@
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                        :num_tokens
+                    ],
                 )
             )
         seq_lens_sum = seq_lens.sum().item()
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index a2bc50478..04327b162 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 64cc9de10..2a5db9084 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -575,8 +575,8 @@ class ScheduleBatch:
     device: str = "cuda"

     # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
-    spec_algorithm: Optional[SpeculativeAlgorithm] = None

     @classmethod
     def init_new(
@@ -587,7 +587,7 @@
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
-        speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
+        spec_algorithm: SpeculativeAlgorithm,
     ):
         return cls(
             reqs=reqs,
@@ -600,7 +600,7 @@
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
-            spec_algorithm=speculative_algorithm,
+            spec_algorithm=spec_algorithm,
         )

     def batch_size(self):
@@ -1010,6 +1010,8 @@
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
+        if self.spec_algorithm.is_eagle():
+            return
         self.input_ids = self.output_ids
         self.output_ids = None
@@ -1172,6 +1174,7 @@
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
+            spec_algorithm=self.spec_algorithm,
         )

     def __str__(self):
@@ -1232,8 +1235,8 @@
     input_embeds: Optional[torch.tensor] = None

     # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
-    spec_algorithm: Optional[SpeculativeAlgorithm] = None


 @triton.jit
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 4bf41aaf3..0d51c695a 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -116,6 +117,14 @@ class Scheduler:
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.decode_mem_cache_buf_multiplier = (
+            self.server_args.speculative_num_draft_tokens
+            if not self.spec_algorithm.is_none()
+            else 1
+        )

         # Init inter-process communication
         context = zmq.Context(2)
@@ -199,6 +208,21 @@
             nccl_port=port_args.nccl_port,
         )

+        # Launch a worker for speculative decoding if needed
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -855,6 +879,7 @@
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )

         new_batch.prepare_for_extend()
@@ -888,11 +913,15 @@
             return None

         # Check if decode out of memory
-        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
+        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
+            test_retract and batch.batch_size() > 10
+        ):
             old_ratio = self.new_token_ratio

             retracted_reqs, new_token_ratio = batch.retract_decode()
             self.new_token_ratio = new_token_ratio
+            if self.draft_worker:
+                self.draft_worker.finish_request(retracted_reqs)

             logger.info(
                 "Decode out of memory happened. "
@@ -926,11 +955,17 @@
         self.forward_ct += 1

         if self.is_generation:
-            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                if self.spec_algorithm.is_none():
+                    model_worker_batch = batch.get_model_worker_batch()
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    logits_output, next_token_ids, model_worker_batch, spec_info = (
+                        self.draft_worker.forward_batch_speculative_generation(batch)
+                    )
+                    batch.spec_info = spec_info
             elif batch.forward_mode.is_idle():
                 model_worker_batch = batch.get_model_worker_batch()
                 self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -1077,7 +1112,10 @@
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue

-            req.output_ids.append(next_token_id)
+            if batch.spec_algorithm.is_none():
+                # The speculative worker handles output_ids during speculative decoding.
+                req.output_ids.append(next_token_id)
+
             req.check_finished()

             if req.finished():
@@ -1252,6 +1290,9 @@
                 # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                 or (not req.stream and len(req.output_ids) % 50 == 0)
             ):
+                if self.draft_worker and req.finished():
+                    self.draft_worker.finish_request(req)
+
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
@@ -1383,6 +1424,7 @@
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 6168441d1..25a1c85f2 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -45,13 +45,18 @@ class TpModelWorker:
         tp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.tp_rank = tp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path,
+            (
+                server_args.model_path
+                if not is_draft_worker
+                else server_args.speculative_draft_model_path
+            ),
             trust_remote_code=server_args.trust_remote_code,
             revision=server_args.revision,
             context_length=server_args.context_length,
@@ -68,6 +73,7 @@
             tp_size=server_args.tp_size,
             nccl_port=nccl_port,
             server_args=server_args,
+            is_draft_worker=is_draft_worker,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index d04560581..1689f7d66 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
+from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -106,11 +106,6 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024


-@maybe_torch_compile(dynamic=True)
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
@@ -157,6 +152,17 @@
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1

+        if model_runner.spec_algorithm.is_eagle():
+            if self.model_runner.is_draft_worker:
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_eagle_topk
+                )
+            else:
+                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_num_draft_tokens
+                )
+
         self.compile_bs = (
             [
                 bs
@@ -192,6 +198,13 @@
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

+        # Speculative inference
+        if model_runner.spec_algorithm.is_eagle():
+            self.hidden_states = torch.zeros(
+                (self.max_num_token, self.model_runner.model_config.hidden_size),
+                dtype=self.model_runner.dtype,
+            )
+
         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
             self.encoder_lens = torch.full(
@@ -234,9 +247,6 @@
         self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        if not forward_batch.forward_mode.is_cuda_graph():
-            return False
-
         if self.enable_dp_attention:
             min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                 forward_batch.global_num_tokens
             )
@@ -291,21 +301,18 @@
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
-        num_token = bs * self.num_tokens_per_bs
+        num_tokens = bs * self.num_tokens_per_bs

         # Common inputs
-        input_ids = self.input_ids[:num_token]
+        input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:num_token]
-        positions = self.positions[:num_token]
-
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-
-        seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]

         if self.enable_dp_attention:
@@ -325,20 +332,22 @@
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens_sum,
+            seq_lens_sum=seq_lens.sum(),
             encoder_lens=encoder_lens,
             return_logprob=False,
-            top_logprobs_nums=[0] * num_token,
+            top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
             mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=self.get_spec_info(num_tokens, positions),
         )

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
-            num_token,
+            num_tokens,
             req_pool_indices,
             seq_lens,
             encoder_lens,
@@ -394,14 +403,16 @@
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
-        positions = clamp_position(forward_batch.seq_lens)
-        self.positions[:raw_num_token].copy_(positions)
+        self.positions[:raw_num_token].copy_(forward_batch.positions)

         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)

+        if hasattr(forward_batch.spec_info, "hidden_states"):
+            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -424,3 +435,36 @@
             ),
         )
         return logits_output
+
+    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+        spec_info = None
+        if self.model_runner.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_utils import (
+                EAGLEDraftInput,
+                EagleVerifyInput,
+            )
+
+            if self.model_runner.is_draft_worker:
+                spec_info = EAGLEDraftInput()
+                spec_info.hidden_states = self.hidden_states[:num_tokens]
+                spec_info.positions = positions
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                spec_info.init(self.model_runner.server_args)
+            else:
+                spec_info = EagleVerifyInput(
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.model_runner.server_args.speculative_num_draft_tokens,
+                )
+                spec_info.custom_mask = torch.zeros(
+                    (num_tokens * self.model_runner.model_config.context_len),
+                    dtype=torch.bool,
+                    device="cuda",
+                )
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+
+        return spec_info
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 2b5ee0919..926961149 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -38,6 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import maybe_torch_compile

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -276,10 +277,21 @@
         )

         if ret.forward_mode.is_idle():
+            ret.positions = torch.empty((0,), device=device)
             return ret

+        # Override the positions with spec_info
+        if (
+            ret.spec_info is not None
+            and getattr(ret.spec_info, "positions", None) is not None
+        ):
+            ret.positions = ret.spec_info.positions
+
         # Init position information
-        if not ret.forward_mode.is_decode():
+        if ret.forward_mode.is_decode():
+            if ret.positions is None:
+                ret.positions = clamp_position(batch.seq_lens)
+        else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
@@ -288,13 +300,15 @@
             if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
-                ret.positions, ret.extend_start_loc = compute_position_triton(
+                positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
                 )
             else:
-                ret.positions, ret.extend_start_loc = compute_position_torch(
+                positions, ret.extend_start_loc = compute_position_torch(
                     ret.extend_prefix_lens, ret.extend_seq_lens
                 )
+            if ret.positions is None:
+                ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
@@ -383,6 +397,11 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc


+@maybe_torch_compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
     FULL = auto()
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 53b48ce78..41905e272 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
@@ -74,6 +75,7 @@ class ModelRunner:
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.model_config = model_config
@@ -84,8 +86,12 @@
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
+        self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )

         # Model-specific adjustment
         if (
@@ -205,14 +211,18 @@
         else:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
-        init_distributed_environment(
-            backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
-            local_rank=self.gpu_id,
-            distributed_init_method=dist_init_method,
-        )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
+        if not self.is_draft_worker:
+            # Only initialize the distributed environment on the target model worker.
+            init_distributed_environment(
+                backend=backend,
+                world_size=self.tp_size,
+                rank=self.tp_rank,
+                local_rank=self.gpu_id,
+                distributed_init_method=dist_init_method,
+            )
+            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
@@ -407,7 +417,6 @@
         target_dtype = (
             dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
         )
-        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

         assert (
             self._model_update_group is not None
@@ -506,6 +515,28 @@
         )

         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+
+        if max_num_reqs is None:
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                4096,
+            )
+
+        if not self.spec_algorithm.is_none():
+            if self.is_draft_worker:
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+            else:
+                self.server_args.draft_runner_cache_size = (
+                    self.max_total_num_tokens
+                    + max_num_reqs * self.server_args.speculative_num_steps
+                    + 100
+                )
+
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
                 logging.warning(
@@ -520,17 +551,6 @@
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

-        if max_num_reqs is None:
-            max_num_reqs = min(
-                max(
-                    int(
-                        self.max_total_num_tokens / self.model_config.context_len * 512
-                    ),
-                    2048,
-                ),
-                4096,
-            )
-
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
@@ -650,10 +670,6 @@
             tensor_parallel(self.model, device_mesh)

     def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
-        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
         self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
@@ -683,14 +699,18 @@
         )

     def forward_idle(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if (
+            forward_batch.forward_mode.is_cuda_graph()
+            and self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(forward_batch)
+        ):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 3a0c102f5..b61b4b2dc 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch

 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -247,6 +248,17 @@
                 "Overlap scheduler is disabled."
             )

+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because EAGLE speculative decoding is used."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py
index 3b306e985..5f156b837 100644
--- a/python/sglang/srt/speculative/spec_info.py
+++ b/python/sglang/srt/speculative/spec_info.py
@@ -2,8 +2,12 @@ from enum import IntEnum, auto


 class SpeculativeAlgorithm(IntEnum):
+    NONE = auto()
     EAGLE = auto()

+    def is_none(self):
+        return self == SpeculativeAlgorithm.NONE
+
     def is_eagle(self):
         return self == SpeculativeAlgorithm.EAGLE

@@ -11,6 +15,7 @@
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
+            None: SpeculativeAlgorithm.NONE,
         }
         return name_map[name]
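
Taken together, the changes above thread a new speculative-decoding configuration from `ServerArgs` through `Scheduler`, `TpModelWorker`, `ModelRunner`, and `CudaGraphRunner`. A minimal launch sketch for the EAGLE path follows, assuming the `ServerArgs` fields referenced in this diff (`speculative_algorithm`, `speculative_draft_model_path`, `speculative_num_steps`, `speculative_eagle_topk`, `speculative_num_draft_tokens`) map to CLI flags by the usual underscore-to-dash convention; the model paths and numeric values are illustrative placeholders, not taken from this diff:

```bash
# Hypothetical EAGLE launch; flag names assume the underscore-to-dash
# mapping of the ServerArgs fields used in this diff, and the model
# paths and step/topk/token counts below are placeholders.
python -m sglang.launch_server \
    --model-path meta-llama/Llama-2-7b-chat-hf \
    --speculative-algorithm EAGLE \
    --speculative-draft-model-path <path-to-eagle-draft-model> \
    --speculative-num-steps 5 \
    --speculative-eagle-topk 8 \
    --speculative-num-draft-tokens 64 \
    --trust-remote-code
```

Per the `ServerArgs` hunk, selecting `EAGLE` already forces `disable_radix_cache`, `disable_cuda_graph_padding`, `disable_overlap_schedule`, and `chunked_prefill_size = -1`, so none of those need to be passed explicitly.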
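
The `spec_info.py` hunk defines the contract the rest of the diff relies on: `from_string` maps the absent CLI value (`None`) to `NONE`, and `is_none()` gates speculative-only behavior such as the decode memory multiplier set in `Scheduler.__init__`. A self-contained sketch of that interaction; the draft-token count of 64 is illustrative only:

```python
from enum import IntEnum, auto


class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()

    def is_none(self):
        return self == SpeculativeAlgorithm.NONE

    @staticmethod
    def from_string(name):
        # As in the diff: the absent flag (None) maps to NONE.
        name_map = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            None: SpeculativeAlgorithm.NONE,
        }
        return name_map[name]


def decode_mem_cache_buf_multiplier(algorithm, speculative_num_draft_tokens):
    # Mirrors Scheduler.__init__: a speculative decode step may commit up to
    # speculative_num_draft_tokens KV-cache slots per request, so
    # check_decode_mem() must reserve that many instead of one.
    return speculative_num_draft_tokens if not algorithm.is_none() else 1


assert SpeculativeAlgorithm.from_string(None).is_none()
assert decode_mem_cache_buf_multiplier(SpeculativeAlgorithm.from_string(None), 64) == 1
assert decode_mem_cache_buf_multiplier(SpeculativeAlgorithm.from_string("EAGLE"), 64) == 64
```

This also explains why `ScheduleBatch.init_new` now requires `spec_algorithm` instead of defaulting it to `None`: every caller must pass a real enum member so that `is_none()`/`is_eagle()` can be called without a `None` check.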