Remove overlap thread (#11210)
Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com> Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -25,12 +25,14 @@ from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Deque, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
import torch
|
||||
import zmq
|
||||
from torch.cuda import Stream as CudaStream
|
||||
from torch.cuda import StreamContext as CudaStreamContext
|
||||
from torch.distributed import barrier
|
||||
|
||||
from sglang.global_config import global_config
|
||||
@@ -112,8 +114,10 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||
from sglang.srt.managers.overlap_utils import FutureMap
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
ModelWorkerBatch,
|
||||
MultimodalInputs,
|
||||
Req,
|
||||
RequestStage,
|
||||
@@ -139,15 +143,13 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
|
||||
SchedulerUpdateWeightsMixin,
|
||||
)
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatchOutput,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
@@ -201,40 +203,48 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||
|
||||
@dataclass
|
||||
class GenerationBatchResult:
|
||||
logits_output: Optional[LogitsProcessorOutput]
|
||||
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
|
||||
next_token_ids: Optional[List[int]]
|
||||
can_run_cuda_graph: bool
|
||||
logits_output: Optional[LogitsProcessorOutput] = None
|
||||
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
|
||||
next_token_ids: Optional[torch.Tensor] = None
|
||||
num_accepted_tokens: Optional[int] = None
|
||||
can_run_cuda_graph: bool = False
|
||||
|
||||
# For output processing
|
||||
extend_input_len_per_req: List[int]
|
||||
extend_logprob_start_len_per_req: List[int]
|
||||
extend_input_len_per_req: Optional[List[int]] = None
|
||||
extend_logprob_start_len_per_req: Optional[List[int]] = None
|
||||
|
||||
@classmethod
|
||||
def from_forward_batch_output(
|
||||
cls,
|
||||
forward_batch_output: ForwardBatchOutput,
|
||||
extend_input_len_per_req: List[int],
|
||||
extend_logprob_start_len_per_req: List[int],
|
||||
):
|
||||
# TODO(lsyin): remove this workaround logic and try to unify output classes
|
||||
# For overlap scheduling
|
||||
copy_done: Optional[torch.cuda.Event] = None
|
||||
delay_sample_launch: bool = False
|
||||
forward_batch: Optional[ForwardBatch] = None
|
||||
future_map_ct: Optional[int] = None
|
||||
|
||||
return cls(
|
||||
logits_output=forward_batch_output.logits_output,
|
||||
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
|
||||
next_token_ids=forward_batch_output.next_token_ids,
|
||||
extend_input_len_per_req=extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
|
||||
)
|
||||
def copy_to_cpu(self, return_logprob: bool = False):
    """Enqueue non-blocking device-to-CPU copies of this result's tensors.

    Used by overlap scheduling. Only tensors needed for result processing
    are moved: ``next_token_ids`` always, ``hidden_states`` when present,
    and the logprob-related logits only when ``return_logprob`` is true.

    Args:
        return_logprob: When true, also copy ``next_token_logits`` and
            ``input_token_logprobs`` from ``self.logits_output``.

    The copies are issued with ``non_blocking=True`` and ``self.copy_done``
    is recorded afterwards, so ``copy_done`` must be set before calling.
    """
    logits = self.logits_output

    if return_logprob:
        # Both logprob tensors get the same treatment; skip the unset ones.
        for field_name in ("next_token_logits", "input_token_logprobs"):
            device_tensor = getattr(logits, field_name)
            if device_tensor is not None:
                setattr(
                    logits,
                    field_name,
                    device_tensor.to("cpu", non_blocking=True),
                )

    # NOTE(review): indentation was lost in this view; the hidden_states
    # copy is assumed to sit outside the return_logprob branch -- confirm.
    hidden_states = logits.hidden_states
    if hidden_states is not None:
        logits.hidden_states = hidden_states.to("cpu", non_blocking=True)

    self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)

    # Record the event after all copies were enqueued so that a later
    # copy_done.synchronize() covers every transfer issued above.
    self.copy_done.record()
|
||||
|
||||
@classmethod
|
||||
def from_pp_proxy(
|
||||
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
|
||||
):
|
||||
# TODO(lsyin): also simplify this logic
|
||||
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
|
||||
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
|
||||
# TODO(lsyin): refactor PP and avoid using dict
|
||||
proxy_dict = next_pp_outputs.tensors
|
||||
return cls(
|
||||
logits_output=logits_output,
|
||||
@@ -388,12 +398,10 @@ class Scheduler(
|
||||
logger.info("Overlap scheduler is disabled for embedding models.")
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
if self.enable_overlap:
|
||||
TpWorkerClass = TpModelWorkerClient
|
||||
else:
|
||||
TpWorkerClass = TpModelWorker
|
||||
|
||||
self.tp_worker = TpWorkerClass(
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
|
||||
self.tp_worker = TpModelWorker(
|
||||
server_args=server_args,
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
@@ -525,9 +533,11 @@ class Scheduler(
|
||||
self.kv_transfer_speed_gb_s: float = 0.0
|
||||
self.kv_transfer_latency_ms: float = 0.0
|
||||
self.sessions: Dict[str, Session] = {}
|
||||
self.current_stream = torch.get_device_module(self.device).current_stream()
|
||||
self.default_stream: CudaStream = torch.get_device_module(
|
||||
self.device
|
||||
).current_stream()
|
||||
if self.device == "cpu":
|
||||
self.current_stream.synchronize = lambda: None # No-op for CPU
|
||||
self.default_stream.synchronize = lambda: None # No-op for CPU
|
||||
self.forward_sleep_time = None
|
||||
|
||||
# Init chunked prefill
|
||||
@@ -618,6 +628,9 @@ class Scheduler(
|
||||
# Init prefill kv split size when deterministic inference is enabled with various attention backends
|
||||
self.init_deterministic_inference_config()
|
||||
|
||||
# Init overlap
|
||||
self.init_overlap()
|
||||
|
||||
# Init request dispatcher
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
@@ -932,6 +945,32 @@ class Scheduler(
|
||||
# The prefill requests that are in the middle of kv sending
|
||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||
|
||||
def init_overlap(self):
    """Create the streams and bookkeeping used by overlap scheduling.

    Does nothing when ``self.enable_overlap`` is false. Otherwise it sets
    up a forward stream and a copy stream (each with its stream context),
    the ``FutureMap`` sized by ``self.max_running_requests``, and the
    two-slot buffer used by ``record_batch_in_overlap``.
    """
    if not self.enable_overlap:
        return

    # Resolve the backend module (e.g. torch.cuda) once for this device.
    device_module = torch.get_device_module(self.device)

    self.forward_stream: CudaStream = device_module.Stream()
    self.forward_stream_ctx: CudaStreamContext = device_module.stream(
        self.forward_stream
    )
    self.copy_stream: CudaStream = device_module.Stream()
    self.copy_stream_ctx: CudaStreamContext = device_module.stream(
        self.copy_stream
    )

    self.future_map = FutureMap(self.max_running_requests, self.device)

    # Two-slot ring of recently launched batches and its write cursor.
    self.batch_record_buf = [None] * 2
    self.batch_record_ct = 0
|
||||
|
||||
def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
    """Keep a reference to a launched batch in the ring buffer.

    Advances ``self.batch_record_ct`` to the next slot (wrapping around)
    and stores ``model_worker_batch`` there, overwriting the oldest entry.

    Args:
        model_worker_batch: The batch being launched on the forward stream.
    """
    # FIXME(lsyin): hacky way to keep a reference so that GPU tensors are
    #   not freed by torch GC while the forward pass is still in flight.
    # NOTE: a more reliable approach is to record all tensors into the
    #   forward stream.
    # NOTE: - all future tensors shall always be read from the future map
    #       - all non-future tensors (produced only by the schedule stream)
    #         must keep a live reference so they are not released during
    #         the whole forward pass
    # Derive the ring size from the buffer itself rather than repeating the
    # constant used when the buffer was allocated.
    slot = (self.batch_record_ct + 1) % len(self.batch_record_buf)
    self.batch_record_ct = slot
    self.batch_record_buf[slot] = model_worker_batch
|
||||
|
||||
def init_moe_config(self):
|
||||
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
|
||||
initialize_moe_config(self.server_args)
|
||||
@@ -958,9 +997,11 @@ class Scheduler(
|
||||
@DynamicGradMode()
|
||||
def event_loop_overlap(self):
|
||||
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
||||
self.result_queue = deque()
|
||||
self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()
|
||||
|
||||
while True:
|
||||
self.launch_last_batch_sample_if_needed()
|
||||
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
|
||||
@@ -968,7 +1009,6 @@ class Scheduler(
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
batch.launch_done = threading.Event()
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
@@ -980,7 +1020,7 @@ class Scheduler(
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.process_batch_result(tmp_batch, None, batch.launch_done)
|
||||
self.process_batch_result(tmp_batch, None)
|
||||
|
||||
if self.last_batch:
|
||||
# Process the results of the last batch
|
||||
@@ -988,10 +1028,7 @@ class Scheduler(
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
# NOTE: we should use the currently launched batch's launch_done event instead of the last batch's
|
||||
self.process_batch_result(
|
||||
tmp_batch, tmp_result, batch.launch_done if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
elif batch is None:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.self_check_during_idle()
|
||||
@@ -2056,18 +2093,62 @@ class Scheduler(
|
||||
# FIXME(lsyin): remove this if and finally unify the abstraction
|
||||
batch_or_worker_batch = batch.get_model_worker_batch()
|
||||
|
||||
forward_batch_output = self.model_worker.forward_batch_generation(
|
||||
batch_or_worker_batch
|
||||
)
|
||||
if self.enable_overlap:
|
||||
# FIXME: remove this assert
|
||||
assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
|
||||
model_worker_batch = batch_or_worker_batch
|
||||
self.record_batch_in_overlap(model_worker_batch)
|
||||
|
||||
# Sampling info will be modified during forward
|
||||
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
|
||||
model_worker_batch.sampling_info.copy_for_forward()
|
||||
)
|
||||
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
cur_future_map_ct = self.future_map.update_ct(bs)
|
||||
|
||||
with self.forward_stream_ctx:
|
||||
self.forward_stream.wait_stream(self.default_stream)
|
||||
self.future_map.resolve_future(model_worker_batch)
|
||||
if batch.sampling_info.grammars is not None:
|
||||
model_worker_batch.delay_sample_launch = True
|
||||
batch_result = self.model_worker.forward_batch_generation(
|
||||
batch_or_worker_batch
|
||||
)
|
||||
# FIXME(lsyin): maybe move this to forward_batch_generation
|
||||
batch_result.copy_done = torch.get_device_module(
|
||||
self.device
|
||||
).Event()
|
||||
if not model_worker_batch.delay_sample_launch:
|
||||
self.future_map.store_to_map(
|
||||
cur_future_map_ct, bs, batch_result.next_token_ids
|
||||
)
|
||||
batch_result.copy_to_cpu()
|
||||
else:
|
||||
batch_result.future_map_ct = cur_future_map_ct
|
||||
|
||||
# FIXME(lsyin): move this assignment elsewhere
|
||||
maybe_future_next_token_ids = self.future_map.update_next_future(
|
||||
cur_future_map_ct, bs
|
||||
)
|
||||
else:
|
||||
batch_result = self.model_worker.forward_batch_generation(
|
||||
batch_or_worker_batch
|
||||
)
|
||||
maybe_future_next_token_ids = batch_result.next_token_ids
|
||||
copy_done = None
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
|
||||
self.udpate_spec_metrics(
|
||||
batch.batch_size(), forward_batch_output.num_accepted_tokens
|
||||
self.update_spec_metrics(
|
||||
batch.batch_size(), batch_result.num_accepted_tokens
|
||||
)
|
||||
|
||||
# update batch's output ids
|
||||
batch.output_ids = forward_batch_output.next_token_ids
|
||||
# NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
|
||||
# which can probably be replaced by future_indices later [TODO(lsyin)].
|
||||
# we shall still keep the original outputs, e.g. next_token_ids
|
||||
# in the GenerationBatchResult for processing after copy_done.
|
||||
batch.output_ids = maybe_future_next_token_ids
|
||||
|
||||
# These 2 values are needed for processing the output, but the values can be
|
||||
# modified by overlap schedule. So we have to copy them here so that
|
||||
@@ -2084,36 +2165,60 @@ class Scheduler(
|
||||
else:
|
||||
extend_logprob_start_len_per_req = None
|
||||
|
||||
return GenerationBatchResult.from_forward_batch_output(
|
||||
forward_batch_output=forward_batch_output,
|
||||
extend_input_len_per_req=extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||
batch_result.extend_input_len_per_req = extend_input_len_per_req
|
||||
batch_result.extend_logprob_start_len_per_req = (
|
||||
extend_logprob_start_len_per_req
|
||||
)
|
||||
return batch_result
|
||||
else: # embedding or reward model
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||
ret = EmbeddingBatchResult(embeddings=embeddings)
|
||||
return ret
|
||||
|
||||
def launch_last_batch_sample_if_needed(self) -> None:
    """Run the deferred sampling step for the result at the queue head.

    Looks at the first ``(batch, result)`` pair in ``self.result_queue``.
    If the queue is empty, or that result was not produced with
    ``delay_sample_launch`` set, nothing happens. Otherwise, on the forward
    stream, it samples the next token ids from the stored logits, stores
    them into the future map at the slot recorded in ``future_map_ct``, and
    starts the copy of the result to CPU.

    The queue entry is left in place at the head in every case; the result
    object is updated in place. Returns ``None``.
    """
    if not self.result_queue:
        return

    # Peek instead of popleft()/appendleft(): the entry never has to leave
    # the queue, so it also stays queued if sampling raises.
    tmp_batch, tmp_result = self.result_queue[0]
    tmp_result: GenerationBatchResult
    if not tmp_result.delay_sample_launch:
        return

    with self.forward_stream_ctx:
        # Order the forward stream after work already queued on the
        # default stream before sampling.
        self.forward_stream.wait_stream(self.default_stream)
        tmp_result.next_token_ids = self.model_worker.model_runner.sample(
            tmp_result.logits_output,
            tmp_result.forward_batch,
        )
        ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
        self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
        tmp_result.copy_to_cpu()
|
||||
|
||||
def process_batch_result(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result, launch_done)
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if self.enable_trace:
|
||||
trace_slice_batch("decode loop", batch.reqs)
|
||||
|
||||
elif batch.forward_mode.is_extend():
|
||||
self.process_batch_result_prefill(batch, result, launch_done)
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
if self.enable_trace:
|
||||
trace_slice_batch("prefill", batch.reqs)
|
||||
|
||||
elif batch.forward_mode.is_idle():
|
||||
if self.enable_overlap:
|
||||
self.tp_worker.resolve_last_batch_result(launch_done)
|
||||
if result.copy_done is not None:
|
||||
result.copy_done.synchronize()
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
@@ -2330,7 +2435,7 @@ class Scheduler(
|
||||
if batch.next_batch_sampling_info:
|
||||
if batch.next_batch_sampling_info.grammars is not None:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
self.default_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def watchdog_thread(self):
|
||||
|
||||
Reference in New Issue
Block a user