Remove overlap thread (#11210)

Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
Authored by Liangsheng Yin on 2025-10-07 20:12:12 +08:00, committed by GitHub
parent 24bc3fb0f9
commit 1519a89cfd
14 changed files with 280 additions and 184 deletions


@@ -25,12 +25,14 @@ from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
from typing import Deque, Dict, List, Optional, Tuple, Union
import psutil
import setproctitle
import torch
import zmq
from torch.cuda import Stream as CudaStream
from torch.cuda import StreamContext as CudaStreamContext
from torch.distributed import barrier
from sglang.global_config import global_config
@@ -112,8 +114,10 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ModelWorkerBatch,
MultimodalInputs,
Req,
RequestStage,
@@ -139,15 +143,13 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatchOutput,
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
@@ -201,40 +203,48 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@dataclass
class GenerationBatchResult:
logits_output: Optional[LogitsProcessorOutput]
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
next_token_ids: Optional[List[int]]
can_run_cuda_graph: bool
logits_output: Optional[LogitsProcessorOutput] = None
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
next_token_ids: Optional[torch.Tensor] = None
num_accepted_tokens: Optional[int] = None
can_run_cuda_graph: bool = False
# For output processing
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
extend_input_len_per_req: Optional[List[int]] = None
extend_logprob_start_len_per_req: Optional[List[int]] = None
@classmethod
def from_forward_batch_output(
cls,
forward_batch_output: ForwardBatchOutput,
extend_input_len_per_req: List[int],
extend_logprob_start_len_per_req: List[int],
):
# TODO(lsyin): remove this workaround logic and try to unify output classes
# For overlap scheduling
copy_done: Optional[torch.cuda.Event] = None
delay_sample_launch: bool = False
forward_batch: Optional[ForwardBatch] = None
future_map_ct: Optional[int] = None
return cls(
logits_output=forward_batch_output.logits_output,
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
next_token_ids=forward_batch_output.next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
)
def copy_to_cpu(self, return_logprob: bool = False):
"""Copy tensors to CPU in overlap scheduling.
Only the tensors that are needed for processing the results are copied,
e.g., next_token_ids and the logits outputs.
"""
if return_logprob:
if self.logits_output.next_token_logits is not None:
self.logits_output.next_token_logits = (
self.logits_output.next_token_logits.to("cpu", non_blocking=True)
)
if self.logits_output.input_token_logprobs is not None:
self.logits_output.input_token_logprobs = (
self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
if self.logits_output.hidden_states is not None:
self.logits_output.hidden_states = self.logits_output.hidden_states.to(
"cpu", non_blocking=True
)
self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
self.copy_done.record()
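The copy_to_cpu method above uses a standard CUDA pattern: enqueue device-to-host copies with non_blocking=True, record an event, and only synchronize that event once the host actually needs the values. A minimal sketch of that pattern (the helper name is illustrative, not part of this diff):

```python
import torch

def async_copy_to_cpu(gpu_tensor: torch.Tensor):
    """Enqueue a device-to-host copy and return (cpu_tensor, completion_event)."""
    done = torch.cuda.Event()
    # non_blocking=True lets the copy proceed asynchronously on the current stream;
    # the CPU tensor must not be read until `done` has completed.
    cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
    done.record()  # recorded on the current stream, after the copy was enqueued
    return cpu_tensor, done

# Usage sketch (mirrors GenerationBatchResult.copy_done):
#   ids_cpu, done = async_copy_to_cpu(next_token_ids)
#   ... other CPU work overlaps with the copy ...
#   done.synchronize()  # now ids_cpu is safe to read
```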
@classmethod
def from_pp_proxy(
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
):
# TODO(lsyin): also simplify this logic
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
# TODO(lsyin): refactor PP and avoid using dict
proxy_dict = next_pp_outputs.tensors
return cls(
logits_output=logits_output,
@@ -388,12 +398,10 @@ class Scheduler(
logger.info("Overlap scheduler is disabled for embedding models.")
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
else:
TpWorkerClass = TpModelWorker
self.tp_worker = TpWorkerClass(
from sglang.srt.managers.tp_worker import TpModelWorker
self.tp_worker = TpModelWorker(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
@@ -525,9 +533,11 @@ class Scheduler(
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.sessions: Dict[str, Session] = {}
self.current_stream = torch.get_device_module(self.device).current_stream()
self.default_stream: CudaStream = torch.get_device_module(
self.device
).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
self.default_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None
# Init chunked prefill
@@ -618,6 +628,9 @@ class Scheduler(
# Init prefill kv split size when deterministic inference is enabled with various attention backends
self.init_deterministic_inference_config()
# Init overlap
self.init_overlap()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
@@ -932,6 +945,32 @@ class Scheduler(
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []
def init_overlap(self):
if not self.enable_overlap:
return
self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
self.device
).stream(self.forward_stream)
self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
self.device
).stream(self.copy_stream)
self.future_map = FutureMap(self.max_running_requests, self.device)
self.batch_record_buf = [None] * 2
self.batch_record_ct = 0
def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
# FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
# NOTE: a more reliable approach: record all tensors on the forward stream
# NOTE: - for all future tensors, always read them from the future map
# - for all non-future tensors (produced only by the schedule stream),
# we must keep their references alive for the whole forward pass
self.batch_record_ct = (self.batch_record_ct + 1) % 2
self.batch_record_buf[self.batch_record_ct] = model_worker_batch
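init_overlap above creates two side streams plus a FutureMap. Judging from how the map is used in run_batch further down, the idea is that the schedule (CPU) side can hand out placeholder indices for next-token ids before sampling has finished; the forward stream later writes the real ids into a preallocated buffer, and the next batch resolves its placeholders from that buffer. A hypothetical, heavily simplified stand-in for such a map (this is not the actual sglang.srt.managers.overlap_utils.FutureMap API, and wraparound handling is omitted):

```python
import torch

class TinyFutureMap:
    """Toy future map: negative indices act as futures for not-yet-sampled token ids."""

    def __init__(self, max_running_requests: int, device: str):
        self.buf = torch.zeros(max_running_requests * 2, dtype=torch.int64, device=device)
        self.ct = 0  # circular write cursor

    def update_ct(self, bs: int) -> int:
        start = self.ct
        self.ct = (self.ct + bs) % self.buf.numel()
        return start

    def update_next_future(self, start: int, bs: int) -> torch.Tensor:
        # Placeholders handed back to the scheduler: slot i is encoded as -(i + 1).
        return -(torch.arange(start, start + bs, device=self.buf.device) + 1)

    def store_to_map(self, start: int, bs: int, next_token_ids: torch.Tensor):
        # Runs on the forward stream once sampling has produced the real ids.
        self.buf[start : start + bs] = next_token_ids

    def resolve_future(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Replace negative placeholders with the real ids stored by store_to_map.
        mask = input_ids < 0
        resolved = input_ids.clone()
        resolved[mask] = self.buf[(-input_ids[mask]) - 1]
        return resolved
```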
def init_moe_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)
@@ -958,9 +997,11 @@ class Scheduler(
@DynamicGradMode()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
self.result_queue = deque()
self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()
while True:
self.launch_last_batch_sample_if_needed()
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
@@ -968,7 +1009,6 @@ class Scheduler(
self.cur_batch = batch
if batch:
batch.launch_done = threading.Event()
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
@@ -980,7 +1020,7 @@ class Scheduler(
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None, batch.launch_done)
self.process_batch_result(tmp_batch, None)
if self.last_batch:
# Process the results of the last batch
@@ -988,10 +1028,7 @@ class Scheduler(
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
# NOTE: we should use the currently launched batch's launch_done event instead of the last batch's
self.process_batch_result(
tmp_batch, tmp_result, batch.launch_done if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.self_check_during_idle()
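With the dedicated overlap thread gone, the overlap now lives entirely inside event_loop_overlap: the scheduler launches batch N, queues its (batch, result) pair, and only then finishes the CPU-side processing of batch N-1, so output processing for one step overlaps with GPU compute for the next. A schematic of that loop (heavily simplified; PP, idle handling, and the delayed-sample path are omitted):

```python
from collections import deque

def event_loop_overlap_sketch(scheduler):
    """Schematic only: one-step pipelining between batch launch and result processing."""
    result_queue = deque()
    last_batch = None
    while True:
        reqs = scheduler.recv_requests()
        scheduler.process_input_requests(reqs)

        batch = scheduler.get_next_batch_to_run()
        if batch is not None:
            # Launch GPU work for the current step; its results stay on device for now.
            result = scheduler.run_batch(batch)
            result_queue.append((batch.copy(), result))

        if last_batch is not None:
            # While the GPU executes the batch launched above, finish the CPU-side
            # processing (detokenization, logprobs, streaming) of the previous batch.
            prev_batch, prev_result = result_queue.popleft()
            scheduler.process_batch_result(prev_batch, prev_result)

        last_batch = batch
```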
@@ -2056,18 +2093,62 @@ class Scheduler(
# FIXME(lsyin): remove this if and finally unify the abstraction
batch_or_worker_batch = batch.get_model_worker_batch()
forward_batch_output = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
if self.enable_overlap:
# FIXME: remove this assert
assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
model_worker_batch = batch_or_worker_batch
self.record_batch_in_overlap(model_worker_batch)
# Sampling info will be modified during forward
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)
bs = len(model_worker_batch.seq_lens)
cur_future_map_ct = self.future_map.update_ct(bs)
with self.forward_stream_ctx:
self.forward_stream.wait_stream(self.default_stream)
self.future_map.resolve_future(model_worker_batch)
if batch.sampling_info.grammars is not None:
model_worker_batch.delay_sample_launch = True
batch_result = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
# FIXME(lsyin): maybe move this to forward_batch_generation
batch_result.copy_done = torch.get_device_module(
self.device
).Event()
if not model_worker_batch.delay_sample_launch:
self.future_map.store_to_map(
cur_future_map_ct, bs, batch_result.next_token_ids
)
batch_result.copy_to_cpu()
else:
batch_result.future_map_ct = cur_future_map_ct
# FIXME(lsyin): move this assignment elsewhere
maybe_future_next_token_ids = self.future_map.update_next_future(
cur_future_map_ct, bs
)
else:
batch_result = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
maybe_future_next_token_ids = batch_result.next_token_ids
copy_done = None
if not self.spec_algorithm.is_none():
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
self.udpate_spec_metrics(
batch.batch_size(), forward_batch_output.num_accepted_tokens
self.update_spec_metrics(
batch.batch_size(), batch_result.num_accepted_tokens
)
# update batch's output ids
batch.output_ids = forward_batch_output.next_token_ids
# NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
# which can probably be replaced by future_indices later [TODO(lsyin)].
# We still keep the original outputs, e.g. next_token_ids,
# in the GenerationBatchResult for processing after copy_done.
batch.output_ids = maybe_future_next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
@@ -2084,36 +2165,60 @@ class Scheduler(
else:
extend_logprob_start_len_per_req = None
return GenerationBatchResult.from_forward_batch_output(
forward_batch_output=forward_batch_output,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
batch_result.extend_input_len_per_req = extend_input_len_per_req
batch_result.extend_logprob_start_len_per_req = (
extend_logprob_start_len_per_req
)
return batch_result
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(embeddings=embeddings)
return ret
def launch_last_batch_sample_if_needed(
self,
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
if len(self.result_queue) == 0:
return
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_result: GenerationBatchResult
if not tmp_result.delay_sample_launch:
self.result_queue.appendleft((tmp_batch, tmp_result))
return
with self.forward_stream_ctx:
self.forward_stream.wait_stream(self.default_stream)
tmp_result.next_token_ids = self.model_worker.model_runner.sample(
tmp_result.logits_output,
tmp_result.forward_batch,
)
ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
tmp_result.copy_to_cpu()
self.result_queue.appendleft((tmp_batch, tmp_result))
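Both the overlap branch of run_batch and launch_last_batch_sample_if_needed above enter forward_stream_ctx only after forward_stream.wait_stream(self.default_stream), so the side stream observes everything the schedule (default) stream has already enqueued. The reference kept by record_batch_in_overlap matters for the same reason: if the batch's Python objects were garbage collected, the caching allocator could hand their GPU memory to a later allocation while the forward stream is still reading it. A minimal, self-contained illustration of the ordering pattern, independent of sglang's classes:

```python
import torch

def launch_on_side_stream(inputs: torch.Tensor, weight: torch.Tensor):
    """Enqueue work on a side stream that consumes tensors produced on the default stream."""
    assert inputs.is_cuda and weight.is_cuda
    default = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    side.wait_stream(default)        # side stream waits for producers on the default stream
    with torch.cuda.stream(side):
        out = inputs @ weight        # runs on the side stream
        done = torch.cuda.Event()
        done.record()                # marks completion of the side-stream work

    # The caller must keep `inputs` / `weight` (and `out`) referenced until `done`
    # has completed; otherwise their memory could be reused by later allocations.
    return out, done

# Usage sketch:
#   x = torch.randn(8, 16, device="cuda"); w = torch.randn(16, 4, device="cuda")
#   out, done = launch_on_side_stream(x, w)
#   done.synchronize()  # before reading `out` or releasing the inputs
```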
def process_batch_result(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
launch_done: Optional[threading.Event] = None,
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result, launch_done)
self.process_batch_result_decode(batch, result)
if self.enable_trace:
trace_slice_batch("decode loop", batch.reqs)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result, launch_done)
self.process_batch_result_prefill(batch, result)
if self.enable_trace:
trace_slice_batch("prefill", batch.reqs)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done)
if result.copy_done is not None:
result.copy_done.synchronize()
self.set_next_batch_sampling_info_done(batch)
elif batch.forward_mode.is_dummy_first():
self.set_next_batch_sampling_info_done(batch)
@@ -2330,7 +2435,7 @@ class Scheduler(
if batch.next_batch_sampling_info:
if batch.next_batch_sampling_info.grammars is not None:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
self.default_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def watchdog_thread(self):
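The last hunk keeps the sampling_info_done handshake: the grammar vocab mask is updated, the default stream is synchronized so the write is actually complete, and only then is the event set for whoever samples with the mask. The general shape of that handshake, as a standalone sketch (names and the argmax sampler here are illustrative, not sglang's implementation):

```python
import threading
import torch

def publish_vocab_mask(mask_buffer: torch.Tensor, new_mask: torch.Tensor, done: threading.Event):
    """Producer: write the mask, make sure the write has finished, then signal consumers."""
    mask_buffer.copy_(new_mask)                # enqueue the update on the current stream
    torch.cuda.current_stream().synchronize()  # ensure the write is complete before signaling
    done.set()                                 # consumers may now sample with the mask

def sample_with_mask(logits: torch.Tensor, mask_buffer: torch.Tensor, done: threading.Event):
    """Consumer: wait for the signal before applying the mask (True = token allowed)."""
    done.wait()
    masked = logits.masked_fill(~mask_buffer, float("-inf"))
    return torch.argmax(masked, dim=-1)
```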