Remove overlap thread (#11210)
Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
@@ -747,11 +747,13 @@ class SchedulerDisaggregationDecodeMixin:
     @torch.no_grad()
     def event_loop_overlap_disagg_decode(self: Scheduler):
-        result_queue = deque()
+        self.result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             # polling and allocating kv cache
@@ -774,13 +776,13 @@ class SchedulerDisaggregationDecodeMixin:
                     None, delay_process=True
                 )
                 if batch_:
-                    result_queue.append((batch_.copy(), result))
+                    self.result_queue.append((batch_.copy(), result))
                     last_batch_in_queue = True
             else:
                 if prepare_mlp_sync_flag:
                     self.prepare_mlp_sync_batch(batch)
                 result = self.run_batch(batch)
-                result_queue.append((batch.copy(), result))
+                self.result_queue.append((batch.copy(), result))

             if (self.last_batch is None) or (not self.last_batch_in_queue):
                 # Create a dummy first batch to start the pipeline for overlap schedule.
@@ -798,12 +800,12 @@ class SchedulerDisaggregationDecodeMixin:
                     None, delay_process=True
                 )
                 if batch:
-                    result_queue.append((batch.copy(), result))
+                    self.result_queue.append((batch.copy(), result))
                     last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend
             if self.last_batch and self.last_batch_in_queue:
-                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
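The three hunks above all make the same change: the loop-local `result_queue` becomes `self.result_queue` so that `launch_last_batch_sample_if_needed()` can reach pending results between iterations. A minimal sketch of the underlying overlap pattern, with toy class and method names standing in for the real scheduler API:

```python
from collections import deque

class ToyScheduler:
    """Toy stand-in for the overlap event loop; not the real Scheduler API."""

    def __init__(self):
        self.result_queue = deque()  # shared attribute, like self.result_queue above
        self.last_batch = None

    def run_batch(self, batch):
        return f"result-of-{batch}"  # stand-in for launching the GPU forward

    def process_batch_result(self, batch, result):
        print("processed", batch, result)

    def event_loop(self, batches):
        for batch in batches:
            # Launch the current batch first, queueing its (batch, result) pair.
            self.result_queue.append((batch, self.run_batch(batch)))
            # Then process the *previous* batch, overlapping CPU-side result
            # handling with the (conceptual) in-flight GPU work.
            if self.last_batch is not None:
                self.process_batch_result(*self.result_queue.popleft())
            self.last_batch = batch
        # Drain the final pending result.
        if self.result_queue:
            self.process_batch_result(*self.result_queue.popleft())

ToyScheduler().event_loop(["b0", "b1", "b2"])
```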
@@ -321,6 +321,8 @@ class SchedulerDisaggregationPrefillMixin:
         self.result_queue = deque()

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
@@ -368,7 +370,6 @@ class SchedulerDisaggregationPrefillMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -379,31 +380,30 @@ class SchedulerDisaggregationPrefillMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
+            copy_done,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
+            result.copy_done,
         )

+        if copy_done is not None:
+            copy_done.synchronize()
+
         logprob_pt = 0
         # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
-        if self.enable_overlap:
-            # wait
-            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
-                launch_done
-            )
-        else:
-            next_token_ids = result.next_token_ids.tolist()
-            if batch.return_logprob:
-                if logits_output.next_token_logprobs is not None:
-                    logits_output.next_token_logprobs = (
-                        logits_output.next_token_logprobs.tolist()
-                    )
-                if logits_output.input_token_logprobs is not None:
-                    logits_output.input_token_logprobs = tuple(
-                        logits_output.input_token_logprobs.tolist()
-                    )
+        next_token_ids = result.next_token_ids.tolist()
+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )

         hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
@@ -37,8 +37,7 @@ class FutureMap:
         return cur_future_ct

     def resolve_future(self, model_worker_batch: ModelWorkerBatch):
-        input_ids = model_worker_batch.input_ids
-        _resolve_future_token_ids(input_ids, self.token_ids_buf)
+        _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)

     def update_next_future(self, future_ct: int, bs: int):
         return torch.arange(
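`resolve_future()` patches placeholder token ids in a batch with values produced by a still-in-flight sample step. A toy illustration of how such a future map can work, assuming the common encoding where a negative input id is an index into a shared token-id buffer (the real `_resolve_future_token_ids` may differ in detail):

```python
import torch

# Buffer that the sampler fills in once the previous batch finishes.
token_ids_buf = torch.zeros(16, dtype=torch.long)
token_ids_buf[1:4] = torch.tensor([11, 12, 13])

# Negative ids mark "future" tokens that were unknown at schedule time.
input_ids = torch.tensor([5, -1, -2, -3])
mask = input_ids < 0
input_ids[mask] = token_ids_buf[-input_ids[mask]]  # resolve in place
print(input_ids)  # tensor([ 5, 11, 12, 13])
```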
@@ -886,9 +886,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False

-    # Events
-    launch_done: Optional[threading.Event] = None
-
     # For chunked prefill in PP
     chunked_req: Optional[Req] = None
@@ -1877,7 +1874,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            launch_done=self.launch_done,
             is_prefill_only=self.is_prefill_only,
         )
@@ -2018,8 +2014,8 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1

-    # Overlap event
-    launch_done: Optional[threading.Event] = None
+    # Overlap scheduler related
+    delay_sample_launch: bool = False

     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
@@ -25,12 +25,14 @@ from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Deque, Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
 import torch
 import zmq
+from torch.cuda import Stream as CudaStream
+from torch.cuda import StreamContext as CudaStreamContext
 from torch.distributed import barrier

 from sglang.global_config import global_config
@@ -112,8 +114,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
+    ModelWorkerBatch,
     MultimodalInputs,
     Req,
     RequestStage,
@@ -139,15 +143,13 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
 from sglang.srt.managers.session_controller import Session
-from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatchOutput,
+    ForwardBatch,
     ForwardMode,
     PPProxyTensors,
 )
@@ -201,40 +203,48 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

 @dataclass
 class GenerationBatchResult:
-    logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
-    next_token_ids: Optional[List[int]]
-    can_run_cuda_graph: bool
+    logits_output: Optional[LogitsProcessorOutput] = None
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    can_run_cuda_graph: bool = False

     # For output processing
-    extend_input_len_per_req: List[int]
-    extend_logprob_start_len_per_req: List[int]
+    extend_input_len_per_req: Optional[List[int]] = None
+    extend_logprob_start_len_per_req: Optional[List[int]] = None

-    @classmethod
-    def from_forward_batch_output(
-        cls,
-        forward_batch_output: ForwardBatchOutput,
-        extend_input_len_per_req: List[int],
-        extend_logprob_start_len_per_req: List[int],
-    ):
-        # TODO(lsyin): remove this workaround logic and try to unify output classes
+    # For overlap scheduling
+    copy_done: Optional[torch.cuda.Event] = None
+    delay_sample_launch: bool = False
+    forward_batch: Optional[ForwardBatch] = None
+    future_map_ct: Optional[int] = None

-        return cls(
-            logits_output=forward_batch_output.logits_output,
-            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
-            next_token_ids=forward_batch_output.next_token_ids,
-            extend_input_len_per_req=extend_input_len_per_req,
-            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
-        )
+    def copy_to_cpu(self, return_logprob: bool = False):
+        """Copy tensors to CPU in overlap scheduling.
+        Only the tensors which are needed for processing results are copied,
+        e.g., next_token_ids, logits outputs
+        """
+        if return_logprob:
+            if self.logits_output.next_token_logits is not None:
+                self.logits_output.next_token_logits = (
+                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.input_token_logprobs is not None:
+                self.logits_output.input_token_logprobs = (
+                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                )
+        if self.logits_output.hidden_states is not None:
+            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                "cpu", non_blocking=True
+            )
+        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+        self.copy_done.record()

     @classmethod
     def from_pp_proxy(
         cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
     ):
         # TODO(lsyin): also simplify this logic
         # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
         # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
         # TODO(lsyin): refactor PP and avoid using dict
         proxy_dict = next_pp_outputs.tensors
         return cls(
             logits_output=logits_output,
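The new `copy_to_cpu()` pairs non-blocking device-to-host copies with a recorded CUDA event, so the consumer synchronizes only when it actually reads the results. A self-contained sketch of that idiom (illustrative names, guarded so it only runs when CUDA is available):

```python
import torch

if torch.cuda.is_available():
    next_token_ids = torch.randint(0, 100, (8,), device="cuda")
    copy_done = torch.cuda.Event()

    # Enqueue the async copy on the current stream, then mark its completion.
    next_token_ids = next_token_ids.to("cpu", non_blocking=True)
    copy_done.record()

    # ... CPU-side scheduling work can proceed here without blocking ...

    copy_done.synchronize()  # only now is the host tensor safe to read
    print(next_token_ids.tolist())
```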
@@ -388,12 +398,10 @@ class Scheduler(
             logger.info("Overlap scheduler is disabled for embedding models.")

         # Launch a tensor parallel worker
-        if self.enable_overlap:
-            TpWorkerClass = TpModelWorkerClient
-        else:
-            TpWorkerClass = TpModelWorker
-
-        self.tp_worker = TpWorkerClass(
+        from sglang.srt.managers.tp_worker import TpModelWorker
+
+        self.tp_worker = TpModelWorker(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -525,9 +533,11 @@ class Scheduler(
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
         self.sessions: Dict[str, Session] = {}
-        self.current_stream = torch.get_device_module(self.device).current_stream()
+        self.default_stream: CudaStream = torch.get_device_module(
+            self.device
+        ).current_stream()
         if self.device == "cpu":
-            self.current_stream.synchronize = lambda: None  # No-op for CPU
+            self.default_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

         # Init chunked prefill
@@ -618,6 +628,9 @@ class Scheduler(
         # Init prefill kv split size when deterministic inference is enabled with various attention backends
         self.init_deterministic_inference_config()

+        # Init overlap
+        self.init_overlap()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -932,6 +945,32 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_overlap(self):
+        if not self.enable_overlap:
+            return
+
+        self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.forward_stream)
+        self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.copy_stream)
+
+        self.future_map = FutureMap(self.max_running_requests, self.device)
+        self.batch_record_buf = [None] * 2
+        self.batch_record_ct = 0
+
+    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
+        # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
+        # NOTE: More Reliable: record all tensors into the forward stream
+        # NOTE: - for all future tensors, we shall always read from future map
+        #       - for all non-future tensors (produced only by schedule stream),
+        #         we shall keep its reference not being release during all the forwarding pass
+        self.batch_record_ct = (self.batch_record_ct + 1) % 2
+        self.batch_record_buf[self.batch_record_ct] = model_worker_batch
+
     def init_moe_config(self):
         if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
             initialize_moe_config(self.server_args)
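`init_overlap()` above puts model forwards on a dedicated side stream. The essential ordering rule is that the side stream must wait on the stream that produced its inputs before touching them. A small stand-alone sketch of that setup (assumes a CUDA device; not the scheduler's actual code path):

```python
import torch

if torch.cuda.is_available():
    default_stream = torch.cuda.current_stream()
    forward_stream = torch.cuda.Stream()

    x = torch.randn(4, 4, device="cuda")  # produced on the default stream
    with torch.cuda.stream(forward_stream):
        forward_stream.wait_stream(default_stream)  # order the two streams
        y = x @ x                                   # now safe to read x here
    default_stream.wait_stream(forward_stream)      # before default reads y
    print(y.shape)
```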
@@ -958,9 +997,11 @@ class Scheduler(
     @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        self.result_queue = deque()
+        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -968,7 +1009,6 @@ class Scheduler(
             self.cur_batch = batch

             if batch:
-                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
@@ -980,7 +1020,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None, batch.launch_done)
+                    self.process_batch_result(tmp_batch, None)

             if self.last_batch:
                 # Process the results of the last batch
@@ -988,10 +1028,7 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
-                self.process_batch_result(
-                    tmp_batch, tmp_result, batch.launch_done if batch else None
-                )
+                self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()
@@ -2056,18 +2093,62 @@ class Scheduler(
                 # FIXME(lsyin): remove this if and finally unify the abstraction
                 batch_or_worker_batch = batch.get_model_worker_batch()

-            forward_batch_output = self.model_worker.forward_batch_generation(
-                batch_or_worker_batch
-            )
+            if self.enable_overlap:
+                # FIXME: remove this assert
+                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
+                model_worker_batch = batch_or_worker_batch
+                self.record_batch_in_overlap(model_worker_batch)
+
+                # Sampling info will be modified during forward
+                model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
+                    model_worker_batch.sampling_info.copy_for_forward()
+                )
+
+                bs = len(model_worker_batch.seq_lens)
+                cur_future_map_ct = self.future_map.update_ct(bs)
+
+                with self.forward_stream_ctx:
+                    self.forward_stream.wait_stream(self.default_stream)
+                    self.future_map.resolve_future(model_worker_batch)
+                    if batch.sampling_info.grammars is not None:
+                        model_worker_batch.delay_sample_launch = True
+                    batch_result = self.model_worker.forward_batch_generation(
+                        batch_or_worker_batch
+                    )
+                    # FIXME(lsyin): maybe move this to forward_batch_generation
+                    batch_result.copy_done = torch.get_device_module(
+                        self.device
+                    ).Event()
+                    if not model_worker_batch.delay_sample_launch:
+                        self.future_map.store_to_map(
+                            cur_future_map_ct, bs, batch_result.next_token_ids
+                        )
+                        batch_result.copy_to_cpu()
+                    else:
+                        batch_result.future_map_ct = cur_future_map_ct
+
+                # FIXME(lsyin): move this assignment elsewhere
+                maybe_future_next_token_ids = self.future_map.update_next_future(
+                    cur_future_map_ct, bs
+                )
+            else:
+                batch_result = self.model_worker.forward_batch_generation(
+                    batch_or_worker_batch
+                )
+                maybe_future_next_token_ids = batch_result.next_token_ids
+                copy_done = None

             if not self.spec_algorithm.is_none():
                 # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
-                self.udpate_spec_metrics(
-                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                self.update_spec_metrics(
+                    batch.batch_size(), batch_result.num_accepted_tokens
                 )

-            # update batch's output ids
-            batch.output_ids = forward_batch_output.next_token_ids
+            # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
+            # which can probably be replaced by future_indices later [TODO(lsyin)].
+            # we shall still keep the original outputs, e.g. next_token_ids
+            # in the GenerationBatchOutput for processing after copy_done.
+            batch.output_ids = maybe_future_next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
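In the overlap branch of `run_batch()`, `update_ct()` reserves a slot range in the future map at launch time and `store_to_map()` fills it once sampling finishes. A toy ring-buffer version of that bookkeeping (illustrative only; the real `FutureMap` in `overlap_utils` is more involved):

```python
import torch

class ToyFutureMap:
    """Illustrative ring counter; slot 0 is unused so futures can be negative indices."""

    def __init__(self, size):
        self.token_ids_buf = torch.zeros(size, dtype=torch.long)
        self.ct = 0
        self.size = size

    def update_ct(self, bs):
        # Reserve bs slots at launch time and return the range start.
        cur = self.ct
        self.ct = (self.ct + bs) % self.size
        return cur

    def store_to_map(self, ct, bs, next_token_ids):
        # Fill the reserved slots once the sampled ids exist.
        self.token_ids_buf[ct + 1 : ct + bs + 1] = next_token_ids

fm = ToyFutureMap(64)
ct = fm.update_ct(bs=4)
fm.store_to_map(ct, 4, torch.tensor([7, 8, 9, 10]))
print(fm.token_ids_buf[:6])  # tensor([ 0,  7,  8,  9, 10,  0])
```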
@@ -2084,36 +2165,60 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-            return GenerationBatchResult.from_forward_batch_output(
-                forward_batch_output=forward_batch_output,
-                extend_input_len_per_req=extend_input_len_per_req,
-                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            batch_result.extend_input_len_per_req = extend_input_len_per_req
+            batch_result.extend_logprob_start_len_per_req = (
+                extend_logprob_start_len_per_req
             )
+            return batch_result
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

+    def launch_last_batch_sample_if_needed(
+        self,
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        if len(self.result_queue) == 0:
+            return
+
+        tmp_batch, tmp_result = self.result_queue.popleft()
+
+        tmp_result: GenerationBatchResult
+        if not tmp_result.delay_sample_launch:
+            self.result_queue.appendleft((tmp_batch, tmp_result))
+            return
+
+        with self.forward_stream_ctx:
+            self.forward_stream.wait_stream(self.default_stream)
+            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
+                tmp_result.logits_output,
+                tmp_result.forward_batch,
+            )
+            ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
+            self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
+            tmp_result.copy_to_cpu()
+        self.result_queue.appendleft((tmp_batch, tmp_result))
+
     def process_batch_result(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result, launch_done)
+            self.process_batch_result_decode(batch, result)
             if self.enable_trace:
                 trace_slice_batch("decode loop", batch.reqs)

         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result, launch_done)
+            self.process_batch_result_prefill(batch, result)
             if self.enable_trace:
                 trace_slice_batch("prefill", batch.reqs)

         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_last_batch_result(launch_done)
+                if result.copy_done is not None:
+                    result.copy_done.synchronize()
                 self.set_next_batch_sampling_info_done(batch)
         elif batch.forward_mode.is_dummy_first():
             self.set_next_batch_sampling_info_done(batch)
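`launch_last_batch_sample_if_needed()` peeks at the head of the result queue, completes the delayed sampling when one is pending, and always pushes the pair back so `process_batch_result()` still consumes it in order. A toy rendition of that peek-and-requeue pattern, with plain dicts standing in for `GenerationBatchResult`:

```python
from collections import deque

result_queue = deque()
result_queue.append(("batch0", {"delay_sample_launch": True, "next_token_ids": None}))

def launch_last_batch_sample_if_needed():
    if not result_queue:
        return
    tmp_batch, tmp_result = result_queue.popleft()  # peek at the head
    if tmp_result["delay_sample_launch"]:
        # Stand-in for model_runner.sample(...) + future_map.store_to_map(...)
        tmp_result["next_token_ids"] = [1, 2, 3]
        tmp_result["delay_sample_launch"] = False
    result_queue.appendleft((tmp_batch, tmp_result))  # always requeue in order

launch_last_batch_sample_if_needed()
print(result_queue[0])
```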
@@ -2330,7 +2435,7 @@ class Scheduler(
         if batch.next_batch_sampling_info:
             if batch.next_batch_sampling_info.grammars is not None:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
+                self.default_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

     def watchdog_thread(self):
@@ -69,7 +69,7 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )

-    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+    def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
         self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
         self.spec_num_total_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens
@@ -39,7 +39,6 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
@@ -49,29 +48,29 @@ class SchedulerOutputProcessorMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
+            copy_done,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
+            result.copy_done,
         )

-        if self.enable_overlap:
-            logits_output, next_token_ids, _ = (
-                self.tp_worker.resolve_last_batch_result(launch_done)
-            )
-        else:
-            # Move next_token_ids and logprobs to cpu
-            next_token_ids = next_token_ids.tolist()
-            if batch.return_logprob:
-                if logits_output.next_token_logprobs is not None:
-                    logits_output.next_token_logprobs = (
-                        logits_output.next_token_logprobs.tolist()
-                    )
-                if logits_output.input_token_logprobs is not None:
-                    logits_output.input_token_logprobs = tuple(
-                        logits_output.input_token_logprobs.tolist()
-                    )
+        if copy_done is not None:
+            copy_done.synchronize()
+
+        # Move next_token_ids and logprobs to cpu
+        next_token_ids = next_token_ids.tolist()
+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )

         hidden_state_offset = 0
@@ -204,22 +203,19 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, can_run_cuda_graph = (
+        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
             result.logits_output,
             result.next_token_ids,
             result.can_run_cuda_graph,
+            result.copy_done,
         )
         self.num_generated_tokens += len(batch.reqs)

-        if self.enable_overlap:
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.tp_worker.resolve_last_batch_result(launch_done)
-            )
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
+        if copy_done is not None:
+            copy_done.synchronize()
+
+        if batch.spec_algorithm.is_none():
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
@@ -15,14 +15,12 @@
 from __future__ import annotations

 import logging
-import threading
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
@@ -36,13 +34,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardBatchOutput,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -236,9 +231,8 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         is_verify: bool = False,
-    ) -> ForwardBatchOutput:
+    ) -> GenerationBatchResult:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
@@ -256,32 +250,43 @@ class TpModelWorker:
             logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if launch_done is not None:
-                launch_done.set()
-
-            skip_sample = is_verify or model_worker_batch.is_prefill_only
-            next_token_ids = None
-
-            if not skip_sample:
-                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-            elif model_worker_batch.return_logprob and not is_verify:
-                # NOTE: Compute logprobs without full sampling
-                self.model_runner.compute_logprobs_only(
-                    logits_output, model_worker_batch
-                )

-            return ForwardBatchOutput(
+            batch_result = GenerationBatchResult(
                 logits_output=logits_output,
-                next_token_ids=next_token_ids,
                 can_run_cuda_graph=can_run_cuda_graph,
             )
+
+            if is_verify:
+                # Skip sampling and return logits for target forward
+                return batch_result
+
+            if model_worker_batch.delay_sample_launch:
+                batch_result.delay_sample_launch = True
+                batch_result.forward_batch = forward_batch
+                return batch_result
+
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                batch_result.next_token_ids = torch.zeros_like(
+                    model_worker_batch.input_ids, dtype=torch.long
+                )
+                if model_worker_batch.return_logprob:
+                    # NOTE: Compute logprobs without full sampling
+                    self.model_runner.compute_logprobs_only(
+                        logits_output, model_worker_batch
+                    )
+            else:
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+
+            return batch_result
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
            )
-            return ForwardBatchOutput(
-                pp_proxy_tensors=pp_proxy_tensors,
+            return GenerationBatchResult(
+                pp_hidden_states_proxy_tensors=pp_proxy_tensors,
                 can_run_cuda_graph=can_run_cuda_graph,
             )
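The `is_prefill_only` branch above returns placeholder token ids so that downstream result processing stays uniform even when nothing was sampled. A tiny CPU-only illustration of the dummy-id trick:

```python
import torch

# Prefill-only requests never generate tokens, but bookkeeping still expects
# one "next token" per request; zeros_like mirrors the batch's shape.
input_ids = torch.tensor([101, 102, 103], dtype=torch.int32)
dummy_next_token_ids = torch.zeros_like(input_ids, dtype=torch.long)
print(dummy_next_token_ids)  # tensor([0, 0, 0])
```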
@@ -232,12 +232,8 @@ class TpModelWorkerClient:
         self, model_worker_batch: ModelWorkerBatch
     ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
-        sampling_info = model_worker_batch.sampling_info
-        sampling_info.update_penalties()
-        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
-            sampling_info,
-            sampling_info_done=threading.Event(),
-            penalizer_orchestrator=None,
+        model_worker_batch.sampling_info = self.cur_sampling_info = (
+            model_worker_batch.sampling_info.copy_for_forward()
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
@@ -902,17 +902,6 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None


-@dataclass
-class ForwardBatchOutput:
-    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
-    # need to be more organized
-    logits_output: Optional[torch.Tensor] = None
-    next_token_ids: Optional[torch.Tensor] = None
-    num_accepted_tokens: Optional[int] = None
-    pp_proxy_tensors: Optional[PPProxyTensors] = None
-    can_run_cuda_graph: bool = False
-
-
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
@@ -370,6 +370,15 @@ class SamplingBatchInfo:
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling

+    def copy_for_forward(self):
+        # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
+        self.update_penalties()
+        return dataclasses.replace(
+            self,
+            sampling_info_done=threading.Event(),
+            penalizer_orchestrator=None,
+        )
+

 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
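`copy_for_forward()` snapshots the sampling state for the in-flight forward pass while dropping the fields that must not be shared across batches. A minimal stand-alone sketch with a toy dataclass (the real `SamplingBatchInfo` carries tensors and a penalizer orchestrator):

```python
import dataclasses
import threading
from typing import Optional

@dataclasses.dataclass
class ToySamplingInfo:
    temperature: float
    penalizer_orchestrator: Optional[object] = None
    sampling_info_done: Optional[threading.Event] = None

    def copy_for_forward(self):
        return dataclasses.replace(
            self,
            sampling_info_done=threading.Event(),  # fresh event per forward
            penalizer_orchestrator=None,           # not shared across batches
        )

info = ToySamplingInfo(temperature=0.7, penalizer_orchestrator=object())
snapshot = info.copy_for_forward()
print(snapshot.penalizer_orchestrator)  # None
```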
@@ -19,11 +19,11 @@ from sglang.srt.managers.schedule_batch import (
     get_last_loc,
     global_server_args_dict,
 )
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
-    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
@@ -429,7 +429,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
         return self.model_runner

-    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         """Run speculative decoding forward.

         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -449,7 +449,7 @@ class EAGLEWorker(TpModelWorker):
             self.forward_draft_extend(
                 batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
             )
-            return ForwardBatchOutput(
+            return GenerationBatchResult(
                 logits_output=logits_output,
                 next_token_ids=next_token_ids,
                 num_accepted_tokens=0,
@@ -472,7 +472,7 @@ class EAGLEWorker(TpModelWorker):
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)

-            return ForwardBatchOutput(
+            return GenerationBatchResult(
                 logits_output=logits_output,
                 next_token_ids=verify_output.verified_id,
                 num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
@@ -513,12 +513,10 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        forward_batch_output = self.target_worker.forward_batch_generation(
-            model_worker_batch
-        )
+        batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
         logits_output, next_token_ids = (
-            forward_batch_output.logits_output,
-            forward_batch_output.next_token_ids,
+            batch_result.logits_output,
+            batch_result.next_token_ids,
         )
         return (
             logits_output,
@@ -822,12 +820,12 @@ class EAGLEWorker(TpModelWorker):
         ).cpu()

         # Forward
-        forward_batch_output = self.target_worker.forward_batch_generation(
+        batch_result = self.target_worker.forward_batch_generation(
             model_worker_batch, is_verify=True
         )
         logits_output, can_run_cuda_graph = (
-            forward_batch_output.logits_output,
-            forward_batch_output.can_run_cuda_graph,
+            batch_result.logits_output,
+            batch_result.can_run_cuda_graph,
         )

         vocab_mask = None
@@ -6,8 +6,9 @@ import torch
 from sgl_kernel.speculative import reconstruct_indices_from_tree_mask

 from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput
@@ -207,18 +208,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)

-    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
         num_accepted_tokens = 0

         if model_worker_batch.forward_mode.is_target_verify():
-            forward_batch_output = self.target_worker.forward_batch_generation(
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch, is_verify=True
             )
             logits_output, can_run_cuda_graph = (
-                forward_batch_output.logits_output,
-                forward_batch_output.can_run_cuda_graph,
+                batch_result.logits_output,
+                batch_result.can_run_cuda_graph,
             )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
@@ -228,16 +229,16 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE

         else:
-            forward_batch_output = self.target_worker.forward_batch_generation(
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
             logits_output, next_token_ids, can_run_cuda_graph = (
-                forward_batch_output.logits_output,
-                forward_batch_output.next_token_ids,
-                forward_batch_output.can_run_cuda_graph,
+                batch_result.logits_output,
+                batch_result.next_token_ids,
+                batch_result.can_run_cuda_graph,
             )

-        return ForwardBatchOutput(
+        return GenerationBatchResult(
             logits_output=logits_output,
             next_token_ids=next_token_ids,
             num_accepted_tokens=num_accepted_tokens,
@@ -1160,7 +1160,7 @@ def run_bench_offline_throughput(model, other_args):
         *[str(x) for x in other_args],
     ]

-    print(f"{command=}")
+    print(f"command={' '.join(command)}")
     process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

     try: