Make constrained decoding work for overlap scheduler (#2095)

This commit is contained in:
Lianmin Zheng
2024-11-19 15:04:43 -08:00
committed by GitHub
parent 55bd97f3e5
commit ffd20fcd03
8 changed files with 119 additions and 95 deletions

View File

@@ -15,6 +15,7 @@ limitations under the License.
"""A scheduler that manages a tensor parallel GPU worker."""
import dataclasses
import logging
import os
import threading
@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -220,8 +222,12 @@ class Scheduler:
# Init running status
self.waiting_queue: List[Req] = []
# The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None
# The current forward batch
self.cur_batch: Optional[ScheduleBatch] = None
# The current forward batch
self.last_batch: Optional[ScheduleBatch] = None
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
@@ -336,15 +342,12 @@ class Scheduler:
@torch.no_grad()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
self.last_batch = None
"""A normal scheduler loop."""
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)
@@ -353,20 +356,8 @@ class Scheduler:
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
# Decode multiple steps to reduce the overhead
if batch.forward_mode.is_decode():
for _ in range(self.server_args.num_continuous_decode_steps - 1):
if not self.running_batch:
break
self.update_running_batch()
if not self.running_batch:
break
if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
# Self-check and re-init some states when the server is idle
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
@@ -377,9 +368,6 @@ class Scheduler:
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
self.last_batch = None
self.running_batch = None
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
@@ -390,10 +378,24 @@ class Scheduler:
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
if self.last_batch is None:
# A dummy first batch to start the pipeline for overlap scheduler.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None)
if self.last_batch:
tmp_batch, tmp_result = result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# Self-check and re-init some states when the server is idle
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
@@ -806,7 +808,7 @@ class Scheduler:
self.tree_cache,
self.model_config,
)
new_batch.prepare_for_extend()
new_batch.prepare_for_extend(self.enable_overlap)
# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:
@@ -893,14 +895,15 @@ class Scheduler:
return ret
def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_idle():
return
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
else:
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -953,6 +956,10 @@ class Scheduler:
else:
req.is_being_chunked -= 1
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
embeddings, bid = result
embeddings = embeddings.tolist()
@@ -1022,6 +1029,10 @@ class Scheduler:
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)
self.token_to_kv_pool.free_group_end()