Make constrained decoding work for overlap scheduler (#2095)
This commit is contained in:
@@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
broadcast_pyobj,
|
||||
@@ -220,8 +222,12 @@ class Scheduler:
|
||||
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
# The running decoding batch for continuous batching
|
||||
self.running_batch: Optional[ScheduleBatch] = None
|
||||
# The current forward batch
|
||||
self.cur_batch: Optional[ScheduleBatch] = None
|
||||
# The current forward batch
|
||||
self.last_batch: Optional[ScheduleBatch] = None
|
||||
self.forward_ct = 0
|
||||
self.forward_ct_decode = 0
|
||||
self.num_generated_tokens = 0
|
||||
@@ -336,15 +342,12 @@ class Scheduler:
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_normal(self):
|
||||
"""A normal blocking scheduler loop."""
|
||||
self.last_batch = None
|
||||
|
||||
"""A normal scheduler loop."""
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
|
||||
batch = self.get_next_batch_to_run()
|
||||
|
||||
if self.server_args.enable_dp_attention:
|
||||
batch = self.prepare_dp_attn_batch(batch)
|
||||
|
||||
@@ -353,20 +356,8 @@ class Scheduler:
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
# Decode multiple steps to reduce the overhead
|
||||
if batch.forward_mode.is_decode():
|
||||
for _ in range(self.server_args.num_continuous_decode_steps - 1):
|
||||
if not self.running_batch:
|
||||
break
|
||||
self.update_running_batch()
|
||||
if not self.running_batch:
|
||||
break
|
||||
if self.server_args.enable_dp_attention:
|
||||
batch = self.prepare_dp_attn_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
else:
|
||||
# Self-check and re-init some states when the server is idle
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
@@ -377,9 +368,6 @@ class Scheduler:
|
||||
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
||||
result_queue = deque()
|
||||
|
||||
self.last_batch = None
|
||||
self.running_batch = None
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
@@ -390,10 +378,24 @@ class Scheduler:
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch is None:
|
||||
# A dummy first batch to start the pipeline for overlap scheduler.
|
||||
# It is now used for triggering the sampling_info_done event.
|
||||
tmp_batch = ScheduleBatch(
|
||||
reqs=None,
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.process_batch_result(tmp_batch, None)
|
||||
|
||||
if self.last_batch:
|
||||
tmp_batch, tmp_result = result_queue.popleft()
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
elif batch is None:
|
||||
# Self-check and re-init some states when the server is idle
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
@@ -806,7 +808,7 @@ class Scheduler:
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
)
|
||||
new_batch.prepare_for_extend()
|
||||
new_batch.prepare_for_extend(self.enable_overlap)
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
@@ -893,14 +895,15 @@ class Scheduler:
|
||||
return ret
|
||||
|
||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
self.running_batch = None
|
||||
else:
|
||||
elif batch.forward_mode.is_extend():
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
|
||||
@@ -953,6 +956,10 @@ class Scheduler:
|
||||
else:
|
||||
req.is_being_chunked -= 1
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
else: # embedding or reward model
|
||||
embeddings, bid = result
|
||||
embeddings = embeddings.tolist()
|
||||
@@ -1022,6 +1029,10 @@ class Scheduler:
|
||||
if req.top_logprobs_num > 0:
|
||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
self.stream_output(batch.reqs)
|
||||
|
||||
self.token_to_kv_pool.free_group_end()
|
||||
|
||||
Reference in New Issue
Block a user