Make constrained decoding work for overlap scheduler (#2095)
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ class ImageInputs:
|
|||||||
image_embeds: Optional[List[torch.Tensor]] = None
|
image_embeds: Optional[List[torch.Tensor]] = None
|
||||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
# QWen2-VL related
|
# QWen2-VL related
|
||||||
image_grid_thws: List[Tuple[int, int, int]] = None
|
image_grid_thws: List[Tuple[int, int, int]] = None
|
||||||
mrope_position_delta: Optional[torch.Tensor] = None
|
mrope_position_delta: Optional[torch.Tensor] = None
|
||||||
@@ -187,11 +188,10 @@ class Req:
|
|||||||
self.origin_input_ids = origin_input_ids
|
self.origin_input_ids = origin_input_ids
|
||||||
self.output_ids = [] # Each decode stage's output ids
|
self.output_ids = [] # Each decode stage's output ids
|
||||||
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
||||||
|
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
self.lora_path = lora_path
|
self.lora_path = lora_path
|
||||||
|
|
||||||
# Memory info
|
# Memory pool info
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
|
|
||||||
# Check finish
|
# Check finish
|
||||||
@@ -428,7 +428,7 @@ bid = 0
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ScheduleBatch:
|
class ScheduleBatch:
|
||||||
"""Store all inforamtion of a batch."""
|
"""Store all inforamtion of a batch on the scheduler."""
|
||||||
|
|
||||||
# Request, memory pool, and cache
|
# Request, memory pool, and cache
|
||||||
reqs: List[Req]
|
reqs: List[Req]
|
||||||
@@ -438,9 +438,9 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
# For utility
|
# For utility
|
||||||
model_config: ModelConfig = None
|
model_config: ModelConfig = None
|
||||||
|
|
||||||
forward_mode: ForwardMode = None
|
forward_mode: ForwardMode = None
|
||||||
sampling_info: SamplingBatchInfo = None
|
sampling_info: SamplingBatchInfo = None
|
||||||
|
next_batch_sampling_info: SamplingBatchInfo = None
|
||||||
|
|
||||||
# Batched arguments to model runner
|
# Batched arguments to model runner
|
||||||
input_ids: torch.Tensor = None
|
input_ids: torch.Tensor = None
|
||||||
@@ -509,7 +509,7 @@ class ScheduleBatch:
|
|||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.reqs) == 0
|
return len(self.reqs) == 0
|
||||||
|
|
||||||
def alloc_req_slots(self, num_reqs):
|
def alloc_req_slots(self, num_reqs: int):
|
||||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||||
if req_pool_indices is None:
|
if req_pool_indices is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -610,7 +610,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
assert len(self.out_cache_loc) == self.extend_num_tokens
|
assert len(self.out_cache_loc) == self.extend_num_tokens
|
||||||
|
|
||||||
def prepare_for_extend(self):
|
def prepare_for_extend(self, enable_overlap_schedule: bool = False):
|
||||||
self.forward_mode = ForwardMode.EXTEND
|
self.forward_mode = ForwardMode.EXTEND
|
||||||
|
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
@@ -704,7 +704,7 @@ class ScheduleBatch:
|
|||||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||||
self,
|
self,
|
||||||
self.model_config.vocab_size,
|
self.model_config.vocab_size,
|
||||||
global_server_args_dict["disable_penalizer"],
|
enable_overlap_schedule=enable_overlap_schedule,
|
||||||
)
|
)
|
||||||
|
|
||||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||||
@@ -746,6 +746,7 @@ class ScheduleBatch:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def retract_decode(self):
|
def retract_decode(self):
|
||||||
|
"""Retract the decoding requests when there is not enough memory."""
|
||||||
sorted_indices = [i for i in range(len(self.reqs))]
|
sorted_indices = [i for i in range(len(self.reqs))]
|
||||||
|
|
||||||
# TODO(lsyin): improve retraction policy for radix cache
|
# TODO(lsyin): improve retraction policy for radix cache
|
||||||
@@ -886,18 +887,10 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
def prepare_for_idle(self):
|
def prepare_for_idle(self):
|
||||||
self.forward_mode = ForwardMode.IDLE
|
self.forward_mode = ForwardMode.IDLE
|
||||||
self.input_ids = torch.empty(0, dtype=torch.int32).to(
|
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||||
self.device, non_blocking=True
|
self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||||
)
|
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||||
self.seq_lens = torch.empty(0, dtype=torch.int32).to(
|
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||||
self.device, non_blocking=True
|
|
||||||
)
|
|
||||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
|
||||||
self.device, non_blocking=True
|
|
||||||
)
|
|
||||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
|
|
||||||
self.device, non_blocking=True
|
|
||||||
)
|
|
||||||
self.seq_lens_sum = 0
|
self.seq_lens_sum = 0
|
||||||
self.extend_num_tokens = 0
|
self.extend_num_tokens = 0
|
||||||
|
|
||||||
@@ -1063,7 +1056,6 @@ class ScheduleBatch:
|
|||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
decoding_reqs=self.decoding_reqs,
|
decoding_reqs=self.decoding_reqs,
|
||||||
sampling_info=self.sampling_info,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
|||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
@@ -220,8 +222,12 @@ class Scheduler:
|
|||||||
|
|
||||||
# Init running status
|
# Init running status
|
||||||
self.waiting_queue: List[Req] = []
|
self.waiting_queue: List[Req] = []
|
||||||
|
# The running decoding batch for continuous batching
|
||||||
self.running_batch: Optional[ScheduleBatch] = None
|
self.running_batch: Optional[ScheduleBatch] = None
|
||||||
|
# The current forward batch
|
||||||
self.cur_batch: Optional[ScheduleBatch] = None
|
self.cur_batch: Optional[ScheduleBatch] = None
|
||||||
|
# The current forward batch
|
||||||
|
self.last_batch: Optional[ScheduleBatch] = None
|
||||||
self.forward_ct = 0
|
self.forward_ct = 0
|
||||||
self.forward_ct_decode = 0
|
self.forward_ct_decode = 0
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
@@ -336,15 +342,12 @@ class Scheduler:
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def event_loop_normal(self):
|
def event_loop_normal(self):
|
||||||
"""A normal blocking scheduler loop."""
|
"""A normal scheduler loop."""
|
||||||
self.last_batch = None
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
recv_reqs = self.recv_requests()
|
recv_reqs = self.recv_requests()
|
||||||
self.process_input_requests(recv_reqs)
|
self.process_input_requests(recv_reqs)
|
||||||
|
|
||||||
batch = self.get_next_batch_to_run()
|
batch = self.get_next_batch_to_run()
|
||||||
|
|
||||||
if self.server_args.enable_dp_attention:
|
if self.server_args.enable_dp_attention:
|
||||||
batch = self.prepare_dp_attn_batch(batch)
|
batch = self.prepare_dp_attn_batch(batch)
|
||||||
|
|
||||||
@@ -353,20 +356,8 @@ class Scheduler:
|
|||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
|
|
||||||
# Decode multiple steps to reduce the overhead
|
|
||||||
if batch.forward_mode.is_decode():
|
|
||||||
for _ in range(self.server_args.num_continuous_decode_steps - 1):
|
|
||||||
if not self.running_batch:
|
|
||||||
break
|
|
||||||
self.update_running_batch()
|
|
||||||
if not self.running_batch:
|
|
||||||
break
|
|
||||||
if self.server_args.enable_dp_attention:
|
|
||||||
batch = self.prepare_dp_attn_batch(batch)
|
|
||||||
result = self.run_batch(batch)
|
|
||||||
self.process_batch_result(batch, result)
|
|
||||||
else:
|
else:
|
||||||
|
# Self-check and re-init some states when the server is idle
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
@@ -377,9 +368,6 @@ class Scheduler:
|
|||||||
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
||||||
result_queue = deque()
|
result_queue = deque()
|
||||||
|
|
||||||
self.last_batch = None
|
|
||||||
self.running_batch = None
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
recv_reqs = self.recv_requests()
|
recv_reqs = self.recv_requests()
|
||||||
self.process_input_requests(recv_reqs)
|
self.process_input_requests(recv_reqs)
|
||||||
@@ -390,10 +378,24 @@ class Scheduler:
|
|||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
result_queue.append((batch.copy(), result))
|
result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
|
if self.last_batch is None:
|
||||||
|
# A dummy first batch to start the pipeline for overlap scheduler.
|
||||||
|
# It is now used for triggering the sampling_info_done event.
|
||||||
|
tmp_batch = ScheduleBatch(
|
||||||
|
reqs=None,
|
||||||
|
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||||
|
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||||
|
)
|
||||||
|
self.process_batch_result(tmp_batch, None)
|
||||||
|
|
||||||
if self.last_batch:
|
if self.last_batch:
|
||||||
tmp_batch, tmp_result = result_queue.popleft()
|
tmp_batch, tmp_result = result_queue.popleft()
|
||||||
|
tmp_batch.next_batch_sampling_info = (
|
||||||
|
self.tp_worker.cur_sampling_info if batch else None
|
||||||
|
)
|
||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
elif batch is None:
|
elif batch is None:
|
||||||
|
# Self-check and re-init some states when the server is idle
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
@@ -806,7 +808,7 @@ class Scheduler:
|
|||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend(self.enable_overlap)
|
||||||
|
|
||||||
# Mixed-style chunked prefill
|
# Mixed-style chunked prefill
|
||||||
if self.is_mixed_chunk and self.running_batch is not None:
|
if self.is_mixed_chunk and self.running_batch is not None:
|
||||||
@@ -893,14 +895,15 @@ class Scheduler:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||||
if batch.forward_mode.is_idle():
|
|
||||||
return
|
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
self.process_batch_result_decode(batch, result)
|
self.process_batch_result_decode(batch, result)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
else:
|
elif batch.forward_mode.is_extend():
|
||||||
self.process_batch_result_prefill(batch, result)
|
self.process_batch_result_prefill(batch, result)
|
||||||
|
elif batch.forward_mode.is_dummy_first():
|
||||||
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||||
|
|
||||||
@@ -953,6 +956,10 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
req.is_being_chunked -= 1
|
req.is_being_chunked -= 1
|
||||||
|
|
||||||
|
if batch.next_batch_sampling_info:
|
||||||
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
embeddings, bid = result
|
embeddings, bid = result
|
||||||
embeddings = embeddings.tolist()
|
embeddings = embeddings.tolist()
|
||||||
@@ -1022,6 +1029,10 @@ class Scheduler:
|
|||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||||
|
|
||||||
|
if batch.next_batch_sampling_info:
|
||||||
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
self.stream_output(batch.reqs)
|
self.stream_output(batch.reqs)
|
||||||
|
|
||||||
self.token_to_kv_pool.free_group_end()
|
self.token_to_kv_pool.free_group_end()
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -96,9 +95,7 @@ class TpModelWorkerClient:
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward_thread_func_(self):
|
def forward_thread_func_(self):
|
||||||
while True:
|
while True:
|
||||||
model_worker_batch, future_token_ids_ct, compute_info_done = (
|
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||||
self.input_queue.get()
|
|
||||||
)
|
|
||||||
if not model_worker_batch:
|
if not model_worker_batch:
|
||||||
break
|
break
|
||||||
self.launch_done = threading.Event()
|
self.launch_done = threading.Event()
|
||||||
@@ -109,7 +106,6 @@ class TpModelWorkerClient:
|
|||||||
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
compute_info_done.wait()
|
|
||||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||||
model_worker_batch, self.launch_done
|
model_worker_batch, self.launch_done
|
||||||
)
|
)
|
||||||
@@ -160,15 +156,16 @@ class TpModelWorkerClient:
|
|||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids
|
||||||
|
|
||||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||||
|
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
||||||
|
_ = model_worker_batch.seq_lens[0].item()
|
||||||
|
|
||||||
# Push a new batch to the queue
|
# Push a new batch to the queue
|
||||||
model_worker_batch.sampling_info = dataclasses.replace(
|
model_worker_batch.sampling_info = dataclasses.replace(
|
||||||
model_worker_batch.sampling_info
|
model_worker_batch.sampling_info,
|
||||||
)
|
sampling_info_done=threading.Event(),
|
||||||
compute_info_done = torch.cuda.Event()
|
|
||||||
compute_info_done.record()
|
|
||||||
self.input_queue.put(
|
|
||||||
(model_worker_batch, self.future_token_ids_ct, compute_info_done)
|
|
||||||
)
|
)
|
||||||
|
self.cur_sampling_info = model_worker_batch.sampling_info
|
||||||
|
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
||||||
|
|
||||||
# Allocate output future objects
|
# Allocate output future objects
|
||||||
bs = len(model_worker_batch.seq_lens)
|
bs = len(model_worker_batch.seq_lens)
|
||||||
|
|||||||
@@ -52,15 +52,19 @@ if TYPE_CHECKING:
|
|||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||||
PREFILL = auto()
|
PREFILL = auto()
|
||||||
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
|
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
|
||||||
EXTEND = auto()
|
EXTEND = auto()
|
||||||
# Decode one token.
|
# Decode one token.
|
||||||
DECODE = auto()
|
DECODE = auto()
|
||||||
# Contains both EXTEND and DECODE.
|
# Contains both EXTEND and DECODE when doing chunked prefill.
|
||||||
MIXED = auto()
|
MIXED = auto()
|
||||||
# No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence allocated.
|
# No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
|
||||||
IDLE = auto()
|
IDLE = auto()
|
||||||
|
|
||||||
|
# A dummy first batch to start the pipeline for overlap scheduler.
|
||||||
|
# It is now used for triggering the sampling_info_done event for the first prefill batch.
|
||||||
|
DUMMY_FIRST = auto()
|
||||||
|
|
||||||
def is_prefill(self):
|
def is_prefill(self):
|
||||||
return self == ForwardMode.PREFILL
|
return self == ForwardMode.PREFILL
|
||||||
|
|
||||||
@@ -76,6 +80,9 @@ class ForwardMode(IntEnum):
|
|||||||
def is_idle(self):
|
def is_idle(self):
|
||||||
return self == ForwardMode.IDLE
|
return self == ForwardMode.IDLE
|
||||||
|
|
||||||
|
def is_dummy_first(self):
|
||||||
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardBatch:
|
class ForwardBatch:
|
||||||
|
|||||||
@@ -142,7 +142,6 @@ class ModelRunner:
|
|||||||
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
||||||
"disable_mla": server_args.disable_mla,
|
"disable_mla": server_args.disable_mla,
|
||||||
"torchao_config": server_args.torchao_config,
|
"torchao_config": server_args.torchao_config,
|
||||||
"disable_penalizer": server_args.disable_penalizer,
|
|
||||||
"enable_nan_detection": server_args.enable_nan_detection,
|
"enable_nan_detection": server_args.enable_nan_detection,
|
||||||
"enable_dp_attention": server_args.enable_dp_attention,
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
}
|
}
|
||||||
@@ -636,10 +635,18 @@ class ModelRunner:
|
|||||||
def sample(
|
def sample(
|
||||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
|
||||||
sampling_info = forward_batch.sampling_info
|
sampling_info = forward_batch.sampling_info
|
||||||
sampling_info.update_regex_vocab_mask()
|
|
||||||
sampling_info.update_penalties()
|
if sampling_info.sampling_info_done:
|
||||||
|
# Overlap mode: the function update_regex_vocab_mask was executed
|
||||||
|
# in process_batch_result of the last batch.
|
||||||
|
if sampling_info.grammars:
|
||||||
|
sampling_info.sampling_info_done.wait()
|
||||||
|
sampling_info.update_penalties()
|
||||||
|
else:
|
||||||
|
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||||
|
sampling_info.update_regex_vocab_mask()
|
||||||
|
sampling_info.update_penalties()
|
||||||
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
|
|
||||||
@@ -28,6 +33,7 @@ class SamplingBatchInfo:
|
|||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
grammars: Optional[List] = None
|
grammars: Optional[List] = None
|
||||||
|
sampling_info_done: Optional[threading.Event] = None
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: Optional[torch.Tensor] = None
|
vocab_mask: Optional[torch.Tensor] = None
|
||||||
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||||
@@ -42,10 +48,7 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(
|
def from_schedule_batch(
|
||||||
cls,
|
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
||||||
batch: ScheduleBatch,
|
|
||||||
vocab_size: int,
|
|
||||||
disable_penalizer: bool,
|
|
||||||
):
|
):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
device = batch.device
|
device = batch.device
|
||||||
@@ -79,6 +82,33 @@ class SamplingBatchInfo:
|
|||||||
)
|
)
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
|
|
||||||
|
if enable_overlap_schedule:
|
||||||
|
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
|
||||||
|
# so it is kind of tricky to make it work with overlap scheduler.
|
||||||
|
# It requires correcly updating the penalty logits before the sampling and syncing the events.
|
||||||
|
# We will support them later.
|
||||||
|
penalizers = {
|
||||||
|
penaltylib.BatchedMinNewTokensPenalizer,
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
|
||||||
|
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
|
||||||
|
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
|
||||||
|
"when using the default overlap scheduler. They will be ignored. "
|
||||||
|
"Please add `--disable-overlap` when launching the server if you need these features. "
|
||||||
|
"The speed will be slower in that case."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
penalizers = {
|
||||||
|
penaltylib.BatchedFrequencyPenalizer,
|
||||||
|
penaltylib.BatchedMinNewTokensPenalizer,
|
||||||
|
penaltylib.BatchedPresencePenalizer,
|
||||||
|
penaltylib.BatchedRepetitionPenalizer,
|
||||||
|
}
|
||||||
|
|
||||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||||
# should not add hefty computation overhead other than simple checks.
|
# should not add hefty computation overhead other than simple checks.
|
||||||
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
|
|||||||
# While we choose not to even create the class instances if they are not required, this
|
# While we choose not to even create the class instances if they are not required, this
|
||||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||||
if disable_penalizer:
|
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||||
ret.penalizer_orchestrator = None
|
vocab_size=vocab_size,
|
||||||
else:
|
batch=batch,
|
||||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
device=batch.device,
|
||||||
vocab_size=vocab_size,
|
Penalizers=penalizers,
|
||||||
batch=batch,
|
)
|
||||||
device=batch.device,
|
|
||||||
Penalizers={
|
|
||||||
penaltylib.BatchedFrequencyPenalizer,
|
|
||||||
penaltylib.BatchedMinNewTokensPenalizer,
|
|
||||||
penaltylib.BatchedPresencePenalizer,
|
|
||||||
penaltylib.BatchedRepetitionPenalizer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle logit bias but only allocate when needed
|
# Handle logit bias but only allocate when needed
|
||||||
ret.logit_bias = None
|
ret.logit_bias = None
|
||||||
@@ -133,13 +155,13 @@ class SamplingBatchInfo:
|
|||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
|
|
||||||
def update_regex_vocab_mask(self):
|
def update_regex_vocab_mask(self):
|
||||||
if not self.grammars or not any(grammar for grammar in self.grammars):
|
if not self.grammars:
|
||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
self.apply_mask = None
|
self.apply_mask = None
|
||||||
return
|
return
|
||||||
|
|
||||||
# find a grammar from the list
|
# find a grammar from the list
|
||||||
grammar = next(grammar for grammar in self.grammars if grammar is not None)
|
grammar = next(grammar for grammar in self.grammars if grammar)
|
||||||
|
|
||||||
# maybe we can reuse the existing mask?
|
# maybe we can reuse the existing mask?
|
||||||
self.vocab_mask = grammar.allocate_vocab_mask(
|
self.vocab_mask = grammar.allocate_vocab_mask(
|
||||||
|
|||||||
@@ -123,7 +123,6 @@ class ServerArgs:
|
|||||||
disable_disk_cache: bool = False
|
disable_disk_cache: bool = False
|
||||||
disable_custom_all_reduce: bool = False
|
disable_custom_all_reduce: bool = False
|
||||||
disable_mla: bool = False
|
disable_mla: bool = False
|
||||||
disable_penalizer: bool = False
|
|
||||||
enable_overlap_schedule: bool = False
|
enable_overlap_schedule: bool = False
|
||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
@@ -200,12 +199,7 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.enable_overlap_schedule:
|
if self.enable_overlap_schedule:
|
||||||
logger.warning(
|
self.disable_jump_forward = True
|
||||||
"Overlap scheduler mode is enabled. This is an experimental feature. "
|
|
||||||
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
|
|
||||||
"and embedding APIs are not supported and will lead to wrong results. "
|
|
||||||
)
|
|
||||||
self.disable_penalizer = True
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
@@ -622,11 +616,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--disable-penalizer",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-nan-detection",
|
"--disable-nan-detection",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user