Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||
@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
@@ -65,6 +69,8 @@ global_server_args_dict = {
|
||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||
"device": ServerArgs.device,
|
||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
@@ -230,6 +236,7 @@ class Req:
|
||||
sampling_params: SamplingParams,
|
||||
return_logprob: bool = False,
|
||||
top_logprobs_num: int = 0,
|
||||
token_ids_logprob: List[int] = None,
|
||||
stream: bool = False,
|
||||
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
@@ -256,17 +263,24 @@ class Req:
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
# Sampling info
|
||||
if isinstance(sampling_params.custom_params, dict):
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
sampling_params.custom_params = sampling_params.custom_params | {
|
||||
"__req__": self
|
||||
}
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx = None
|
||||
self.req_pool_idx: Optional[int] = None
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
self.finished_reason = None
|
||||
# If we want to abort the request in the middle of the event loop, set this to true
|
||||
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
|
||||
self.to_abort = False
|
||||
self.stream = stream
|
||||
self.eos_token_ids = eos_token_ids
|
||||
@@ -289,38 +303,56 @@ class Req:
|
||||
self.image_inputs: Optional[ImageInputs] = None
|
||||
|
||||
# Prefix info
|
||||
# The indices to kv cache for the shared prefix.
|
||||
self.prefix_indices = []
|
||||
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
|
||||
# Updated if chunked.
|
||||
# Number of tokens to run prefill.
|
||||
self.extend_input_len = 0
|
||||
# The relative logprob_start_len in an extend batch
|
||||
self.extend_logprob_start_len = 0
|
||||
self.last_node = None
|
||||
|
||||
# Chunked prefill
|
||||
self.is_being_chunked = 0
|
||||
# Whether or not if it is chunked. It increments whenever
|
||||
# it is chunked, and decrement whenever chunked request is
|
||||
# processed.
|
||||
self.is_chunked = 0
|
||||
|
||||
# For retraction
|
||||
self.is_retracted = False
|
||||
|
||||
# Logprobs (arguments)
|
||||
self.return_logprob = return_logprob
|
||||
# Start index to compute logprob from.
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = top_logprobs_num
|
||||
self.token_ids_logprob = token_ids_logprob
|
||||
|
||||
# Logprobs (return values)
|
||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||
self.input_top_logprobs_idx: Optional[List[int]] = None
|
||||
self.input_token_ids_logprobs_val: Optional[List[float]] = None
|
||||
self.input_token_ids_logprobs_idx: Optional[List[int]] = None
|
||||
# Temporary holder to store input_token_logprobs.
|
||||
self.input_token_logprobs: Optional[List[Tuple[int]]] = None
|
||||
self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
|
||||
self.temp_input_top_logprobs_idx: Optional[List[int]] = None
|
||||
self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
|
||||
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
|
||||
|
||||
if return_logprob:
|
||||
self.output_token_logprobs_val = []
|
||||
self.output_token_logprobs_idx = []
|
||||
self.output_top_logprobs_val = []
|
||||
self.output_top_logprobs_idx = []
|
||||
self.output_token_ids_logprobs_val = []
|
||||
self.output_token_ids_logprobs_idx = []
|
||||
else:
|
||||
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
||||
self.output_top_logprobs_val
|
||||
) = self.output_top_logprobs_idx = None
|
||||
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
|
||||
self.output_token_ids_logprobs_idx
|
||||
) = None
|
||||
self.hidden_states = []
|
||||
|
||||
# Logprobs (internal values)
|
||||
@@ -345,6 +377,13 @@ class Req:
|
||||
self.spec_verify_ct = 0
|
||||
self.lora_path = lora_path
|
||||
|
||||
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
|
||||
self.to_abort_message: str = "Unknown error"
|
||||
|
||||
@property
|
||||
def seqlen(self):
|
||||
return len(self.origin_input_ids) + len(self.output_ids)
|
||||
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.image_inputs is None:
|
||||
self.image_inputs = image_inputs
|
||||
@@ -422,7 +461,9 @@ class Req:
|
||||
return
|
||||
|
||||
if self.to_abort:
|
||||
self.finished_reason = FINISH_ABORT()
|
||||
self.finished_reason = FINISH_ABORT(
|
||||
message=self.to_abort_message,
|
||||
)
|
||||
return
|
||||
|
||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
||||
@@ -517,6 +558,8 @@ class Req:
|
||||
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
|
||||
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
|
||||
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
|
||||
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
|
||||
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
|
||||
self.logprob_start_len = prompt_tokens + k
|
||||
self.last_update_decode_tokens = len(self.output_ids) - k
|
||||
|
||||
@@ -527,16 +570,19 @@ class Req:
|
||||
self.last_node = None
|
||||
self.extend_input_len = 0
|
||||
self.is_retracted = True
|
||||
self.input_token_logprobs = None
|
||||
self.temp_input_top_logprobs_val = None
|
||||
self.temp_input_top_logprobs_idx = None
|
||||
self.extend_logprob_start_len = 0
|
||||
self.is_chunked = 0
|
||||
self.req_pool_idx = None
|
||||
|
||||
# For incremental logprobs
|
||||
# TODO: Fix the `logprob_start_len`
|
||||
self.last_update_decode_tokens = 0
|
||||
self.logprob_start_len = 10**9
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"rid(n={self.rid}, "
|
||||
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
|
||||
f"Req(rid={self.rid}, "
|
||||
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
|
||||
)
|
||||
|
||||
|
||||
@@ -576,11 +622,13 @@ class ScheduleBatch:
|
||||
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||
|
||||
# For extend and mixed chunekd prefill
|
||||
prefix_lens: List[int] = None
|
||||
@@ -588,6 +636,8 @@ class ScheduleBatch:
|
||||
extend_num_tokens: int = None
|
||||
decoding_reqs: List[Req] = None
|
||||
extend_logprob_start_lens: List[int] = None
|
||||
# It comes empty list if logprob is not required.
|
||||
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
@@ -606,7 +656,7 @@ class ScheduleBatch:
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[SpecInfo] = None
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
|
||||
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
@@ -653,8 +703,10 @@ class ScheduleBatch:
|
||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||
if req_pool_indices is None:
|
||||
raise RuntimeError(
|
||||
"Out of memory. "
|
||||
"Please set a smaller number for `--max-running-requests`."
|
||||
"alloc_req_slots runs out of memory. "
|
||||
"Please set a smaller number for `--max-running-requests`. "
|
||||
f"{self.req_to_token_pool.available_size()=}, "
|
||||
f"{num_reqs=}, "
|
||||
)
|
||||
return req_pool_indices
|
||||
|
||||
@@ -765,6 +817,7 @@ class ScheduleBatch:
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
|
||||
input_embeds = []
|
||||
extend_input_logprob_token_ids = []
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
@@ -783,22 +836,64 @@ class ScheduleBatch:
|
||||
# If req.input_embeds is already a list, append its content directly
|
||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||
|
||||
if req.return_logprob:
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
extend_logprob_start_len = min(
|
||||
req.logprob_start_len - pre_len, req.extend_input_len - 1
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
|
||||
)
|
||||
req.extend_logprob_start_len = extend_logprob_start_len
|
||||
|
||||
req.cached_tokens += pre_len - req.already_computed
|
||||
req.already_computed = seq_len
|
||||
req.is_retracted = False
|
||||
pre_lens.append(pre_len)
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
req.extend_logprob_start_len = min(
|
||||
req.logprob_start_len - pre_len,
|
||||
req.extend_input_len,
|
||||
req.seqlen - 1,
|
||||
)
|
||||
else:
|
||||
req.extend_logprob_start_len = 0
|
||||
|
||||
if self.return_logprob:
|
||||
# Find input logprob token ids.
|
||||
# First, find a global index within origin_input_ids and slide it by 1
|
||||
# to compute input logprobs. It is because you need the next token
|
||||
# to compute input logprobs. E.g., (chunk size 2)
|
||||
#
|
||||
# input_logprobs = [1, 2, 3, 4]
|
||||
# fill_ids = [1, 2]
|
||||
# extend_input_logprob_token_id = [2, 3]
|
||||
#
|
||||
# Note that it can also overflow. In this case, we pad it with 0.
|
||||
# input_logprobs = [1, 2, 3, 4]
|
||||
# fill_ids = [3, 4]
|
||||
# extend_input_logprob_token_id = [4, 0]
|
||||
global_start_idx, global_end_idx = (
|
||||
len(req.prefix_indices),
|
||||
len(req.fill_ids),
|
||||
)
|
||||
# Apply logprob_start_len
|
||||
if global_start_idx < req.logprob_start_len:
|
||||
global_start_idx = req.logprob_start_len
|
||||
|
||||
logprob_token_ids = req.origin_input_ids[
|
||||
global_start_idx + 1 : global_end_idx + 1
|
||||
]
|
||||
extend_input_logprob_token_ids.extend(logprob_token_ids)
|
||||
|
||||
# We will need req.extend_input_len - req.extend_logprob_start_len number of
|
||||
# tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
|
||||
extend_input_logprob_token_ids.extend(
|
||||
[0]
|
||||
* (
|
||||
req.extend_input_len
|
||||
- req.extend_logprob_start_len
|
||||
- len(logprob_token_ids)
|
||||
)
|
||||
)
|
||||
|
||||
if self.return_logprob:
|
||||
extend_input_logprob_token_ids = torch.tensor(
|
||||
extend_input_logprob_token_ids
|
||||
)
|
||||
else:
|
||||
extend_input_logprob_token_ids = None
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
@@ -821,10 +916,12 @@ class ScheduleBatch:
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||
|
||||
# Write to req_to_token_pool
|
||||
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
||||
@@ -860,7 +957,6 @@ class ScheduleBatch:
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
enable_overlap_schedule=self.enable_overlap,
|
||||
)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
@@ -905,25 +1001,43 @@ class ScheduleBatch:
|
||||
|
||||
return False
|
||||
|
||||
def retract_decode(self):
|
||||
def retract_decode(self, server_args: ServerArgs):
|
||||
"""Retract the decoding requests when there is not enough memory."""
|
||||
sorted_indices = [i for i in range(len(self.reqs))]
|
||||
|
||||
# TODO(lsyin): improve retraction policy for radix cache
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (
|
||||
len(self.reqs[i].output_ids),
|
||||
-len(self.reqs[i].origin_input_ids),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
# For spec decoding, filter_batch API can only filter
|
||||
# requests from the back, so we can only retract from the back.
|
||||
# TODO(sang): Clean up finish path and support better retract
|
||||
# policy.
|
||||
if not server_args.speculative_algorithm:
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (
|
||||
len(self.reqs[i].output_ids),
|
||||
-len(self.reqs[i].origin_input_ids),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
def get_required_tokens(num_reqs: int):
|
||||
headroom_for_spec_decode = 0
|
||||
if server_args.speculative_algorithm:
|
||||
headroom_for_spec_decode += (
|
||||
num_reqs
|
||||
* server_args.speculative_eagle_topk
|
||||
* server_args.speculative_num_steps
|
||||
+ num_reqs * server_args.speculative_num_draft_tokens
|
||||
)
|
||||
return (
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool.available_size()
|
||||
< len(sorted_indices) * global_config.retract_decode_steps
|
||||
< get_required_tokens(len(sorted_indices))
|
||||
or first_iter
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
@@ -1048,17 +1162,40 @@ class ScheduleBatch:
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
enable_overlap_schedule=self.enable_overlap,
|
||||
)
|
||||
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
if self.spec_algorithm.is_eagle():
|
||||
# if spec decoding is used, the decode batch is prepared inside
|
||||
# `forward_batch_speculative_generation` after running draft models.
|
||||
return
|
||||
|
||||
if self.sampling_info.penalizer_orchestrator.is_required:
|
||||
if self.enable_overlap:
|
||||
# TODO: this can be slow, optimize this.
|
||||
delayed_output_ids = torch.tensor(
|
||||
[
|
||||
(
|
||||
req.output_ids[-1]
|
||||
if len(req.output_ids)
|
||||
else req.origin_input_ids[-1]
|
||||
)
|
||||
for req in self.reqs
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
delayed_output_ids
|
||||
)
|
||||
else:
|
||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
self.output_ids.to(torch.int64)
|
||||
)
|
||||
|
||||
self.input_ids = self.output_ids
|
||||
self.output_ids = None
|
||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
|
||||
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
@@ -1086,14 +1223,15 @@ class ScheduleBatch:
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
being_chunked_req: Optional[Req] = None,
|
||||
chunked_req_to_exclude: Optional[Req] = None,
|
||||
keep_indices: Optional[List[int]] = None,
|
||||
):
|
||||
if keep_indices is None:
|
||||
keep_indices = [
|
||||
i
|
||||
for i in range(len(self.reqs))
|
||||
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
|
||||
if not self.reqs[i].finished()
|
||||
and self.reqs[i] is not chunked_req_to_exclude
|
||||
]
|
||||
|
||||
if keep_indices is None or len(keep_indices) == 0:
|
||||
@@ -1105,31 +1243,34 @@ class ScheduleBatch:
|
||||
# No need to filter
|
||||
return
|
||||
|
||||
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = self.encoder_lens[keep_indices]
|
||||
self.encoder_lens = self.encoder_lens[keep_indices_device]
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
||||
self.seq_lens = self.seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
self.output_ids = self.output_ids[new_indices]
|
||||
self.output_ids = self.output_ids[keep_indices_device]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
|
||||
self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
|
||||
else:
|
||||
self.top_logprobs_nums = None
|
||||
self.token_ids_logprobs = None
|
||||
|
||||
self.has_stream = any(req.stream for req in self.reqs)
|
||||
self.has_grammar = any(req.grammar for req in self.reqs)
|
||||
|
||||
self.sampling_info.filter_batch(keep_indices, new_indices)
|
||||
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
||||
if self.spec_info:
|
||||
self.spec_info.filter_batch(new_indices)
|
||||
self.spec_info.filter_batch(keep_indices_device)
|
||||
|
||||
def merge_batch(self, other: "ScheduleBatch"):
|
||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||
@@ -1152,10 +1293,13 @@ class ScheduleBatch:
|
||||
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
||||
if self.return_logprob and other.return_logprob:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.token_ids_logprobs.extend(other.token_ids_logprobs)
|
||||
elif self.return_logprob:
|
||||
self.top_logprobs_nums.extend([0] * len(other.reqs))
|
||||
self.token_ids_logprobs.extend([None] * len(other.reqs))
|
||||
elif other.return_logprob:
|
||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
|
||||
self.reqs.extend(other.reqs)
|
||||
|
||||
self.return_logprob |= other.return_logprob
|
||||
@@ -1192,7 +1336,9 @@ class ScheduleBatch:
|
||||
seq_lens_sum=self.seq_lens_sum,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
token_ids_logprobs=self.token_ids_logprobs,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
extend_num_tokens=self.extend_num_tokens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
@@ -1219,6 +1365,7 @@ class ScheduleBatch:
|
||||
else CaptureHiddenMode.NULL
|
||||
)
|
||||
),
|
||||
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
@@ -1262,9 +1409,11 @@ class ModelWorkerBatch:
|
||||
# For logprob
|
||||
return_logprob: bool
|
||||
top_logprobs_nums: Optional[List[int]]
|
||||
token_ids_logprobs: Optional[List[List[int]]]
|
||||
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]]
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
can_run_dp_cuda_graph: bool
|
||||
|
||||
# For extend
|
||||
@@ -1272,6 +1421,7 @@ class ModelWorkerBatch:
|
||||
extend_seq_lens: Optional[List[int]]
|
||||
extend_prefix_lens: Optional[List[int]]
|
||||
extend_logprob_start_lens: Optional[List[int]]
|
||||
extend_input_logprob_token_ids: Optional[torch.Tensor]
|
||||
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]]
|
||||
@@ -1293,7 +1443,8 @@ class ModelWorkerBatch:
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[SpecInfo] = None
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
# If set, the output of the batch contains the hidden states of the run.
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user