[overlap-spec] Make plan stream an option (#11724)
This commit is contained in:
@@ -221,6 +221,9 @@ class Envs:
|
||||
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
||||
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
||||
|
||||
# Overlap Spec V2
|
||||
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
|
||||
|
||||
# VLM
|
||||
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
|
||||
SGLANG_RESIZE_RESAMPLE = EnvStr("")
|
||||
|
||||
@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
forward_batch.extend_prefix_lens.sum().item(),
|
||||
sum(forward_batch.extend_prefix_lens_cpu),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -404,6 +404,8 @@ class ForwardBatch:
|
||||
if ret.positions is None:
|
||||
ret.positions = clamp_position(batch.seq_lens)
|
||||
else:
|
||||
assert isinstance(batch.extend_seq_lens, list)
|
||||
assert isinstance(batch.extend_prefix_lens, list)
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin:
|
||||
num_draft_tokens: int,
|
||||
draft_model_runner: Any,
|
||||
):
|
||||
seq_lens_cpu_backup = batch.seq_lens_cpu
|
||||
seq_lens_cpu_ = batch.seq_lens_cpu
|
||||
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
|
||||
|
||||
batch.spec_info = self
|
||||
@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin:
|
||||
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
|
||||
batch.seq_lens_sum += extend_num_tokens
|
||||
batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
|
||||
batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
|
||||
batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
|
||||
batch.extend_prefix_lens = seq_lens_cpu_.tolist()
|
||||
batch.extend_num_tokens = extend_num_tokens
|
||||
batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.cuda import Stream as CudaStream
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
|
||||
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||
@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
|
||||
)
|
||||
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
||||
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
||||
# TODO(lsyin): potential bugs with a separate plan stream
|
||||
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
|
||||
|
||||
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
|
||||
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
||||
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
|
||||
else:
|
||||
self.plan_stream = None
|
||||
self.plan_stream_ctx = contextlib.nullcontext()
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
if model_worker_batch.forward_mode.is_decode():
|
||||
@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
batch: ModelWorkerBatch,
|
||||
pre_draft_allocate_lens: torch.Tensor,
|
||||
):
|
||||
# Since batch.seq_lens is allocated in another stream, we need
|
||||
# record_stream() to prevent pytorch gc and reuse the gpu memory
|
||||
# while forward_stream is still running.
|
||||
batch.seq_lens.record_stream(torch.cuda.current_stream())
|
||||
|
||||
# Parse args
|
||||
verify_input: EagleVerifyInput = batch.spec_info
|
||||
seq_lens_backup = batch.seq_lens
|
||||
bs = len(batch.seq_lens)
|
||||
|
||||
# Batch 1: Target verify
|
||||
@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
accept_length,
|
||||
accept_index,
|
||||
) = verify_input.sample(batch, logits_output)
|
||||
new_seq_lens = seq_lens_backup + accept_length
|
||||
new_seq_lens = batch.seq_lens + accept_length
|
||||
verify_done = torch.cuda.Event()
|
||||
|
||||
# Move the accepted tokens to the target KV cache locations
|
||||
batch.seq_lens = seq_lens_backup
|
||||
self.move_accepted_tokens_to_target_kvcache(
|
||||
batch,
|
||||
accept_index,
|
||||
accept_length,
|
||||
)
|
||||
|
||||
verify_done.record()
|
||||
|
||||
all_verified_id = predict[accept_index]
|
||||
@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
ret_hidden_states = draft_logits_output.hidden_states
|
||||
|
||||
# Since seq_lens_backup's tensor is allocated in another stream, we
|
||||
# need record_stream() to prevent pytorch gc and reuse the gpu memory
|
||||
# while forward_stream is still running.
|
||||
seq_lens_backup.record_stream(torch.cuda.current_stream())
|
||||
|
||||
# Construct the return values
|
||||
next_draft_input = EagleDraftInput(
|
||||
topk_p=ret_topk_p,
|
||||
|
||||
Reference in New Issue
Block a user