[overlap-spec] Make plan stream an option (#11724)

This commit is contained in:
Liangsheng Yin
2025-10-17 15:48:57 +08:00
committed by GitHub
parent ce11dd82dc
commit d88ac9bc9a
5 changed files with 23 additions and 23 deletions

View File

@@ -221,6 +221,9 @@ class Envs:
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096) SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256) SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
# Overlap Spec V2
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
# VLM # VLM
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28) SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
SGLANG_RESIZE_RESAMPLE = EnvStr("") SGLANG_RESIZE_RESAMPLE = EnvStr("")

View File

@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
) )
kv_indptr = kv_indptr[: bs + 1] kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty( kv_indices = torch.empty(
forward_batch.extend_prefix_lens.sum().item(), sum(forward_batch.extend_prefix_lens_cpu),
dtype=torch.int64, dtype=torch.int64,
device=self.device, device=self.device,
) )

View File

@@ -404,6 +404,8 @@ class ForwardBatch:
if ret.positions is None: if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens) ret.positions = clamp_position(batch.seq_lens)
else: else:
assert isinstance(batch.extend_seq_lens, list)
assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor( ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32 batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True) ).to(device, non_blocking=True)

View File

@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin:
num_draft_tokens: int, num_draft_tokens: int,
draft_model_runner: Any, draft_model_runner: Any,
): ):
seq_lens_cpu_backup = batch.seq_lens_cpu seq_lens_cpu_ = batch.seq_lens_cpu
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
batch.spec_info = self batch.spec_info = self
@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin:
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
batch.seq_lens_sum += extend_num_tokens batch.seq_lens_sum += extend_num_tokens
batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))] batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
batch.extend_prefix_lens = seq_lens_cpu_backup.tolist() batch.extend_prefix_lens = seq_lens_cpu_.tolist()
batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
batch.extend_num_tokens = extend_num_tokens batch.extend_num_tokens = extend_num_tokens
batch.capture_hidden_mode = CaptureHiddenMode.FULL batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2 batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2

View File

@@ -1,9 +1,11 @@
import contextlib
import logging import logging
from typing import List, Optional from typing import List, Optional
import torch import torch
from torch.cuda import Stream as CudaStream from torch.cuda import Stream as CudaStream
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.scheduler import GenerationBatchResult
@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker):
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
) )
self.tree_mask_mode = TreeMaskMode.FULL_MASK self.tree_mask_mode = TreeMaskMode.FULL_MASK
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
# TODO(lsyin): potential bugs with a separate plan stream if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream) self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
else:
self.plan_stream = None
self.plan_stream_ctx = contextlib.nullcontext()
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode(): if model_worker_batch.forward_mode.is_decode():
@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker):
batch: ModelWorkerBatch, batch: ModelWorkerBatch,
pre_draft_allocate_lens: torch.Tensor, pre_draft_allocate_lens: torch.Tensor,
): ):
# Since batch.seq_lens is allocated in another stream, we need
# record_stream() to prevent pytorch gc and reuse the gpu memory
# while forward_stream is still running.
batch.seq_lens.record_stream(torch.cuda.current_stream())
# Parse args # Parse args
verify_input: EagleVerifyInput = batch.spec_info verify_input: EagleVerifyInput = batch.spec_info
seq_lens_backup = batch.seq_lens
bs = len(batch.seq_lens) bs = len(batch.seq_lens)
# Batch 1: Target verify # Batch 1: Target verify
@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker):
accept_length, accept_length,
accept_index, accept_index,
) = verify_input.sample(batch, logits_output) ) = verify_input.sample(batch, logits_output)
new_seq_lens = seq_lens_backup + accept_length new_seq_lens = batch.seq_lens + accept_length
verify_done = torch.cuda.Event() verify_done = torch.cuda.Event()
# Move the accepted tokens to the target KV cache locations
batch.seq_lens = seq_lens_backup
self.move_accepted_tokens_to_target_kvcache(
batch,
accept_index,
accept_length,
)
verify_done.record() verify_done.record()
all_verified_id = predict[accept_index] all_verified_id = predict[accept_index]
@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker):
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states ret_hidden_states = draft_logits_output.hidden_states
# Since seq_lens_backup's tensor is allocated in another stream, we
# need record_stream() to prevent pytorch gc and reuse the gpu memory
# while forward_stream is still running.
seq_lens_backup.record_stream(torch.cuda.current_stream())
# Construct the return values # Construct the return values
next_draft_input = EagleDraftInput( next_draft_input = EagleDraftInput(
topk_p=ret_topk_p, topk_p=ret_topk_p,