diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index 399125805..113e49f6c 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -221,6 +221,9 @@ class Envs:
     SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
     SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
 
+    # Overlap Spec V2
+    SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
+
     # VLM
     SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
     SGLANG_RESIZE_RESAMPLE = EnvStr("")
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 97b869473..731758286 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
             )
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.extend_prefix_lens.sum().item(),
+                sum(forward_batch.extend_prefix_lens_cpu),
                 dtype=torch.int64,
                 device=self.device,
             )
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 0243fe89e..f34f36d70 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -404,6 +404,8 @@ class ForwardBatch:
         if ret.positions is None:
             ret.positions = clamp_position(batch.seq_lens)
         else:
+            assert isinstance(batch.extend_seq_lens, list)
+            assert isinstance(batch.extend_prefix_lens, list)
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py
index bf450934e..339576965 100644
--- a/python/sglang/srt/speculative/eagle_info_v2.py
+++ b/python/sglang/srt/speculative/eagle_info_v2.py
@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin:
         num_draft_tokens: int,
         draft_model_runner: Any,
     ):
-        seq_lens_cpu_backup = batch.seq_lens_cpu
+        seq_lens_cpu_ = batch.seq_lens_cpu
        extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
 
         batch.spec_info = self
@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin:
         batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
         batch.seq_lens_sum += extend_num_tokens
         batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
-        batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
-        batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
+        batch.extend_prefix_lens = seq_lens_cpu_.tolist()
         batch.extend_num_tokens = extend_num_tokens
         batch.capture_hidden_mode = CaptureHiddenMode.FULL
         batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py
index 1b67b0e96..0104a370d 100644
--- a/python/sglang/srt/speculative/eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/eagle_worker_v2.py
@@ -1,9 +1,11 @@
+import contextlib
 import logging
 from typing import List, Optional
 
 import torch
 from torch.cuda import Stream as CudaStream
 
+from sglang.srt.environ import envs
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
 from sglang.srt.managers.scheduler import GenerationBatchResult
@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker):
             self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
         )
         self.tree_mask_mode = TreeMaskMode.FULL_MASK
-        self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
-        # TODO(lsyin): potential bugs with a separate plan stream
-        self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
+
+        if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
+            self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
+            self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
+        else:
+            self.plan_stream = None
+            self.plan_stream_ctx = contextlib.nullcontext()
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         if model_worker_batch.forward_mode.is_decode():
@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker):
         batch: ModelWorkerBatch,
         pre_draft_allocate_lens: torch.Tensor,
     ):
+        # Since batch.seq_lens is allocated in another stream, we need
+        # record_stream() to prevent PyTorch from garbage-collecting and
+        # reusing the GPU memory while forward_stream is still running.
+        batch.seq_lens.record_stream(torch.cuda.current_stream())
+
         # Parse args
         verify_input: EagleVerifyInput = batch.spec_info
-        seq_lens_backup = batch.seq_lens
         bs = len(batch.seq_lens)
 
         # Batch 1: Target verify
@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker):
             accept_length,
             accept_index,
         ) = verify_input.sample(batch, logits_output)
-        new_seq_lens = seq_lens_backup + accept_length
+        new_seq_lens = batch.seq_lens + accept_length
         verify_done = torch.cuda.Event()
-
-        # Move the accepted tokens to the target KV cache locations
-        batch.seq_lens = seq_lens_backup
-        self.move_accepted_tokens_to_target_kvcache(
-            batch,
-            accept_index,
-            accept_length,
-        )
-
         verify_done.record()
 
         all_verified_id = predict[accept_index]
@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker):
         ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
         ret_hidden_states = draft_logits_output.hidden_states
 
-        # Since seq_lens_backup's tensor is allocated in another stream, we
-        # need record_stream() to prevent pytorch gc and reuse the gpu memory
-        # while forward_stream is still running.
-        seq_lens_backup.record_stream(torch.cuda.current_stream())
-
        # Construct the return values
         next_draft_input = EagleDraftInput(
             topk_p=ret_topk_p,
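
Reviewer note on the SGLANG_ENABLE_OVERLAP_PLAN_STREAM gate in eagle_worker_v2.py: contextlib.nullcontext() is what lets call sites keep writing "with self.plan_stream_ctx:" whether or not a separate plan stream exists. Below is a minimal sketch of that toggle, assuming a CUDA device; make_plan_stream and use_plan_stream are hypothetical names standing in for the env check, not SGLang APIs.

    import contextlib
    import torch

    def make_plan_stream(use_plan_stream: bool):
        # Hypothetical helper for illustration only.
        if use_plan_stream:
            # Launch planning work on a dedicated side stream.
            plan_stream = torch.cuda.Stream()
            return plan_stream, torch.cuda.stream(plan_stream)
        # No-op context manager: "with ctx:" still works, but all planning
        # work stays on the current (default) stream.
        return None, contextlib.nullcontext()

    plan_stream, plan_stream_ctx = make_plan_stream(use_plan_stream=False)
    with plan_stream_ctx:
        pass  # planning kernels would be launched here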
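
Reviewer note on the record_stream() usage touched above: a tensor allocated while one stream is current must be registered with the consuming stream before its last Python reference dies, otherwise PyTorch's caching allocator may recycle the block for a new allocation that races with kernels still running on the consumer stream. A minimal sketch of the pattern, assuming a CUDA device; alloc_stream and the tensor shape are made up for illustration and are not the SGLang code.

    import torch

    alloc_stream = torch.cuda.Stream()            # stands in for the plan/scheduler stream
    with torch.cuda.stream(alloc_stream):
        seq_lens = torch.ones(8, dtype=torch.int64, device="cuda")

    forward_stream = torch.cuda.current_stream()  # stands in for the forward stream
    forward_stream.wait_stream(alloc_stream)      # order the two streams
    seq_lens.record_stream(forward_stream)        # keep the block alive for forward_stream
    new_seq_lens = seq_lens + 1                   # safe to consume on forward_stream
    del seq_lens                                  # allocator will not reuse the block too early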