Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -65,6 +69,8 @@ global_server_args_dict = {
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -230,6 +236,7 @@ class Req:
sampling_params: SamplingParams,
return_logprob: bool = False,
top_logprobs_num: int = 0,
token_ids_logprob: List[int] = None,
stream: bool = False,
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
@@ -256,17 +263,24 @@ class Req:
self.input_embeds = input_embeds
# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
sampling_params.custom_params = sampling_params.custom_params | {
"__req__": self
}
self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
# Memory pool info
self.req_pool_idx = None
self.req_pool_idx: Optional[int] = None
# Check finish
self.tokenizer = None
self.finished_reason = None
# If we want to abort the request in the middle of the event loop, set this to true
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
self.to_abort = False
self.stream = stream
self.eos_token_ids = eos_token_ids
@@ -289,38 +303,56 @@ class Req:
self.image_inputs: Optional[ImageInputs] = None
# Prefix info
# The indices to kv cache for the shared prefix.
self.prefix_indices = []
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
# Updated if chunked.
# Number of tokens to run prefill.
self.extend_input_len = 0
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
self.last_node = None
# Chunked prefill
self.is_being_chunked = 0
# Whether or not if it is chunked. It increments whenever
# it is chunked, and decrement whenever chunked request is
# processed.
self.is_chunked = 0
# For retraction
self.is_retracted = False
# Logprobs (arguments)
self.return_logprob = return_logprob
# Start index to compute logprob from.
self.logprob_start_len = 0
self.top_logprobs_num = top_logprobs_num
self.token_ids_logprob = token_ids_logprob
# Logprobs (return values)
self.input_token_logprobs_val: Optional[List[float]] = None
self.input_token_logprobs_idx: Optional[List[int]] = None
self.input_top_logprobs_val: Optional[List[float]] = None
self.input_top_logprobs_idx: Optional[List[int]] = None
self.input_token_ids_logprobs_val: Optional[List[float]] = None
self.input_token_ids_logprobs_idx: Optional[List[int]] = None
# Temporary holder to store input_token_logprobs.
self.input_token_logprobs: Optional[List[Tuple[int]]] = None
self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
self.temp_input_top_logprobs_idx: Optional[List[int]] = None
self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
if return_logprob:
self.output_token_logprobs_val = []
self.output_token_logprobs_idx = []
self.output_top_logprobs_val = []
self.output_top_logprobs_idx = []
self.output_token_ids_logprobs_val = []
self.output_token_ids_logprobs_idx = []
else:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
self.output_token_ids_logprobs_idx
) = None
self.hidden_states = []
# Logprobs (internal values)
@@ -345,6 +377,13 @@ class Req:
self.spec_verify_ct = 0
self.lora_path = lora_path
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
self.to_abort_message: str = "Unknown error"
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
self.image_inputs = image_inputs
@@ -422,7 +461,9 @@ class Req:
return
if self.to_abort:
self.finished_reason = FINISH_ABORT()
self.finished_reason = FINISH_ABORT(
message=self.to_abort_message,
)
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -517,6 +558,8 @@ class Req:
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
@@ -527,16 +570,19 @@ class Req:
self.last_node = None
self.extend_input_len = 0
self.is_retracted = True
self.input_token_logprobs = None
self.temp_input_top_logprobs_val = None
self.temp_input_top_logprobs_idx = None
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
# For incremental logprobs
# TODO: Fix the `logprob_start_len`
self.last_update_decode_tokens = 0
self.logprob_start_len = 10**9
def __repr__(self):
return (
f"rid(n={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
f"Req(rid={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
)
@@ -576,11 +622,13 @@ class ScheduleBatch:
# For DP attention
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
can_run_dp_cuda_graph: bool = False
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# For extend and mixed chunekd prefill
prefix_lens: List[int] = None
@@ -588,6 +636,8 @@ class ScheduleBatch:
extend_num_tokens: int = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None
# It comes empty list if logprob is not required.
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
@@ -606,7 +656,7 @@ class ScheduleBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
# Enable custom logit processor
enable_custom_logit_processor: bool = False
@@ -653,8 +703,10 @@ class ScheduleBatch:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"Out of memory. "
"Please set a smaller number for `--max-running-requests`."
"alloc_req_slots runs out of memory. "
"Please set a smaller number for `--max-running-requests`. "
f"{self.req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
)
return req_pool_indices
@@ -765,6 +817,7 @@ class ScheduleBatch:
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
input_embeds = []
extend_input_logprob_token_ids = []
pt = 0
for i, req in enumerate(reqs):
@@ -783,22 +836,64 @@ class ScheduleBatch:
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
if req.return_logprob:
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
req.logprob_start_len - pre_len, req.extend_input_len - 1
)
else:
raise RuntimeError(
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
)
req.extend_logprob_start_len = extend_logprob_start_len
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
pre_lens.append(pre_len)
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
else:
req.extend_logprob_start_len = 0
if self.return_logprob:
# Find input logprob token ids.
# First, find a global index within origin_input_ids and slide it by 1
# to compute input logprobs. It is because you need the next token
# to compute input logprobs. E.g., (chunk size 2)
#
# input_logprobs = [1, 2, 3, 4]
# fill_ids = [1, 2]
# extend_input_logprob_token_id = [2, 3]
#
# Note that it can also overflow. In this case, we pad it with 0.
# input_logprobs = [1, 2, 3, 4]
# fill_ids = [3, 4]
# extend_input_logprob_token_id = [4, 0]
global_start_idx, global_end_idx = (
len(req.prefix_indices),
len(req.fill_ids),
)
# Apply logprob_start_len
if global_start_idx < req.logprob_start_len:
global_start_idx = req.logprob_start_len
logprob_token_ids = req.origin_input_ids[
global_start_idx + 1 : global_end_idx + 1
]
extend_input_logprob_token_ids.extend(logprob_token_ids)
# We will need req.extend_input_len - req.extend_logprob_start_len number of
# tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
extend_input_logprob_token_ids.extend(
[0]
* (
req.extend_input_len
- req.extend_logprob_start_len
- len(logprob_token_ids)
)
)
if self.return_logprob:
extend_input_logprob_token_ids = torch.tensor(
extend_input_logprob_token_ids
)
else:
extend_input_logprob_token_ids = None
# Set fields
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -821,10 +916,12 @@ class ScheduleBatch:
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
# Write to req_to_token_pool
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
@@ -860,7 +957,6 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
)
def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -905,25 +1001,43 @@ class ScheduleBatch:
return False
def retract_decode(self):
def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory."""
sorted_indices = [i for i in range(len(self.reqs))]
# TODO(lsyin): improve retraction policy for radix cache
sorted_indices.sort(
key=lambda i: (
len(self.reqs[i].output_ids),
-len(self.reqs[i].origin_input_ids),
),
reverse=True,
)
# For spec decoding, filter_batch API can only filter
# requests from the back, so we can only retract from the back.
# TODO(sang): Clean up finish path and support better retract
# policy.
if not server_args.speculative_algorithm:
sorted_indices.sort(
key=lambda i: (
len(self.reqs[i].output_ids),
-len(self.reqs[i].origin_input_ids),
),
reverse=True,
)
def get_required_tokens(num_reqs: int):
headroom_for_spec_decode = 0
if server_args.speculative_algorithm:
headroom_for_spec_decode += (
num_reqs
* server_args.speculative_eagle_topk
* server_args.speculative_num_steps
+ num_reqs * server_args.speculative_num_draft_tokens
)
return (
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool.available_size()
< len(sorted_indices) * global_config.retract_decode_steps
< get_required_tokens(len(sorted_indices))
or first_iter
):
if len(sorted_indices) == 1:
@@ -1048,17 +1162,40 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
)
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
if self.spec_algorithm.is_eagle():
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
return
if self.sampling_info.penalizer_orchestrator.is_required:
if self.enable_overlap:
# TODO: this can be slow, optimize this.
delayed_output_ids = torch.tensor(
[
(
req.output_ids[-1]
if len(req.output_ids)
else req.origin_input_ids[-1]
)
for req in self.reqs
],
dtype=torch.int64,
device=self.device,
)
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
delayed_output_ids
)
else:
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
self.output_ids.to(torch.int64)
)
self.input_ids = self.output_ids
self.output_ids = None
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
# Alloc mem
bs = len(self.reqs)
@@ -1086,14 +1223,15 @@ class ScheduleBatch:
def filter_batch(
self,
being_chunked_req: Optional[Req] = None,
chunked_req_to_exclude: Optional[Req] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
if not self.reqs[i].finished()
and self.reqs[i] is not chunked_req_to_exclude
]
if keep_indices is None or len(keep_indices) == 0:
@@ -1105,31 +1243,34 @@ class ScheduleBatch:
# No need to filter
return
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
if self.model_config.is_encoder_decoder:
self.encoder_lens = self.encoder_lens[keep_indices]
self.encoder_lens = self.encoder_lens[keep_indices_device]
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
self.req_pool_indices = self.req_pool_indices[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[new_indices]
self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
else:
self.top_logprobs_nums = None
self.token_ids_logprobs = None
self.has_stream = any(req.stream for req in self.reqs)
self.has_grammar = any(req.grammar for req in self.reqs)
self.sampling_info.filter_batch(keep_indices, new_indices)
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info:
self.spec_info.filter_batch(new_indices)
self.spec_info.filter_batch(keep_indices_device)
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1152,10 +1293,13 @@ class ScheduleBatch:
self.output_ids = torch.concat([self.output_ids, other.output_ids])
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.token_ids_logprobs.extend(other.token_ids_logprobs)
elif self.return_logprob:
self.top_logprobs_nums.extend([0] * len(other.reqs))
self.token_ids_logprobs.extend([None] * len(other.reqs))
elif other.return_logprob:
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
self.reqs.extend(other.reqs)
self.return_logprob |= other.return_logprob
@@ -1192,7 +1336,9 @@ class ScheduleBatch:
seq_lens_sum=self.seq_lens_sum,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
token_ids_logprobs=self.token_ids_logprobs,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
@@ -1219,6 +1365,7 @@ class ScheduleBatch:
else CaptureHiddenMode.NULL
)
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
)
def copy(self):
@@ -1262,9 +1409,11 @@ class ModelWorkerBatch:
# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
token_ids_logprobs: Optional[List[List[int]]]
# For DP attention
global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
can_run_dp_cuda_graph: bool
# For extend
@@ -1272,6 +1421,7 @@ class ModelWorkerBatch:
extend_seq_lens: Optional[List[int]]
extend_prefix_lens: Optional[List[int]]
extend_logprob_start_lens: Optional[List[int]]
extend_input_logprob_token_ids: Optional[torch.Tensor]
# For multimodal
image_inputs: Optional[List[ImageInputs]]
@@ -1293,7 +1443,8 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None