[overlap-spec] Make plan stream an option (#11724)
This commit is contained in:
@@ -221,6 +221,9 @@ class Envs:
|
|||||||
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
||||||
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
||||||
|
|
||||||
|
# Overlap Spec V2
|
||||||
|
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
|
||||||
|
|
||||||
# VLM
|
# VLM
|
||||||
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
|
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
|
||||||
SGLANG_RESIZE_RESAMPLE = EnvStr("")
|
SGLANG_RESIZE_RESAMPLE = EnvStr("")
|
||||||
|
|||||||
@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
kv_indices = torch.empty(
|
kv_indices = torch.empty(
|
||||||
forward_batch.extend_prefix_lens.sum().item(),
|
sum(forward_batch.extend_prefix_lens_cpu),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -404,6 +404,8 @@ class ForwardBatch:
|
|||||||
if ret.positions is None:
|
if ret.positions is None:
|
||||||
ret.positions = clamp_position(batch.seq_lens)
|
ret.positions = clamp_position(batch.seq_lens)
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(batch.extend_seq_lens, list)
|
||||||
|
assert isinstance(batch.extend_prefix_lens, list)
|
||||||
ret.extend_seq_lens = torch.tensor(
|
ret.extend_seq_lens = torch.tensor(
|
||||||
batch.extend_seq_lens, dtype=torch.int32
|
batch.extend_seq_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin:
|
|||||||
num_draft_tokens: int,
|
num_draft_tokens: int,
|
||||||
draft_model_runner: Any,
|
draft_model_runner: Any,
|
||||||
):
|
):
|
||||||
seq_lens_cpu_backup = batch.seq_lens_cpu
|
seq_lens_cpu_ = batch.seq_lens_cpu
|
||||||
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
|
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
|
||||||
|
|
||||||
batch.spec_info = self
|
batch.spec_info = self
|
||||||
@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin:
|
|||||||
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
|
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
|
||||||
batch.seq_lens_sum += extend_num_tokens
|
batch.seq_lens_sum += extend_num_tokens
|
||||||
batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
|
batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
|
||||||
batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
|
batch.extend_prefix_lens = seq_lens_cpu_.tolist()
|
||||||
batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
|
|
||||||
batch.extend_num_tokens = extend_num_tokens
|
batch.extend_num_tokens = extend_num_tokens
|
||||||
batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
|
batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.cuda import Stream as CudaStream
|
from torch.cuda import Stream as CudaStream
|
||||||
|
|
||||||
|
from sglang.srt.environ import envs
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
|
||||||
from sglang.srt.managers.scheduler import GenerationBatchResult
|
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||||
@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
|
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
|
||||||
)
|
)
|
||||||
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
||||||
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
|
||||||
# TODO(lsyin): potential bugs with a separate plan stream
|
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
|
||||||
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
|
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
||||||
|
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
|
||||||
|
else:
|
||||||
|
self.plan_stream = None
|
||||||
|
self.plan_stream_ctx = contextlib.nullcontext()
|
||||||
|
|
||||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||||
if model_worker_batch.forward_mode.is_decode():
|
if model_worker_batch.forward_mode.is_decode():
|
||||||
@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
batch: ModelWorkerBatch,
|
batch: ModelWorkerBatch,
|
||||||
pre_draft_allocate_lens: torch.Tensor,
|
pre_draft_allocate_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
# Since batch.seq_lens is allocated in another stream, we need
|
||||||
|
# record_stream() so PyTorch's caching allocator does not free and reuse
|
||||||
|
# its GPU memory while forward_stream is still running.
|
||||||
|
batch.seq_lens.record_stream(torch.cuda.current_stream())
|
||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
verify_input: EagleVerifyInput = batch.spec_info
|
verify_input: EagleVerifyInput = batch.spec_info
|
||||||
seq_lens_backup = batch.seq_lens
|
|
||||||
bs = len(batch.seq_lens)
|
bs = len(batch.seq_lens)
|
||||||
|
|
||||||
# Batch 1: Target verify
|
# Batch 1: Target verify
|
||||||
@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
accept_length,
|
accept_length,
|
||||||
accept_index,
|
accept_index,
|
||||||
) = verify_input.sample(batch, logits_output)
|
) = verify_input.sample(batch, logits_output)
|
||||||
new_seq_lens = seq_lens_backup + accept_length
|
new_seq_lens = batch.seq_lens + accept_length
|
||||||
verify_done = torch.cuda.Event()
|
verify_done = torch.cuda.Event()
|
||||||
|
|
||||||
# Move the accepted tokens to the target KV cache locations
|
|
||||||
batch.seq_lens = seq_lens_backup
|
|
||||||
self.move_accepted_tokens_to_target_kvcache(
|
|
||||||
batch,
|
|
||||||
accept_index,
|
|
||||||
accept_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
verify_done.record()
|
verify_done.record()
|
||||||
|
|
||||||
all_verified_id = predict[accept_index]
|
all_verified_id = predict[accept_index]
|
||||||
@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
|
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
ret_hidden_states = draft_logits_output.hidden_states
|
ret_hidden_states = draft_logits_output.hidden_states
|
||||||
|
|
||||||
# Since seq_lens_backup's tensor is allocated in another stream, we
|
|
||||||
# need record_stream() so PyTorch's caching allocator does not free and reuse
|
|
||||||
# its GPU memory while forward_stream is still running.
|
|
||||||
seq_lens_backup.record_stream(torch.cuda.current_stream())
|
|
||||||
|
|
||||||
# Construct the return values
|
# Construct the return values
|
||||||
next_draft_input = EagleDraftInput(
|
next_draft_input = EagleDraftInput(
|
||||||
topk_p=ret_topk_p,
|
topk_p=ret_topk_p,
|
||||||
|
|||||||
Reference in New Issue
Block a user