724 lines
31 KiB
Python
724 lines
31 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|
|
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
|
|
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.managers.scheduler import (
|
|
EmbeddingBatchResult,
|
|
GenerationBatchResult,
|
|
ScheduleBatch,
|
|
Scheduler,
|
|
)
|
|
|
|
# Module-level logger; used below to report grammar `accept_token` failures.
logger = logging.getLogger(__name__)
|
|
|
|
# For unfinished requests with `req.stream` disabled, `stream_output_generation`
# still emits output whenever `len(req.output_ids)` is a multiple of this value
# (never for multimodal-gen models).
DEFAULT_FORCE_STREAM_INTERVAL = 50
|
|
|
|
|
|
class SchedulerOutputProcessorMixin:
|
|
"""
|
|
This class implements the output processing logic for Scheduler.
|
|
We put them into a separate file to make the `scheduler.py` shorter.
|
|
"""
|
|
|
|
def process_batch_result_prefill(
|
|
self: Scheduler,
|
|
batch: ScheduleBatch,
|
|
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
|
launch_done: Optional[threading.Event] = None,
|
|
):
|
|
skip_stream_req = None
|
|
|
|
if self.is_generation:
|
|
(
|
|
logits_output,
|
|
next_token_ids,
|
|
extend_input_len_per_req,
|
|
extend_logprob_start_len_per_req,
|
|
) = (
|
|
result.logits_output,
|
|
result.next_token_ids,
|
|
result.extend_input_len_per_req,
|
|
result.extend_logprob_start_len_per_req,
|
|
)
|
|
|
|
if self.enable_overlap:
|
|
logits_output, next_token_ids, _ = (
|
|
self.tp_worker.resolve_last_batch_result(launch_done)
|
|
)
|
|
else:
|
|
# Move next_token_ids and logprobs to cpu
|
|
next_token_ids = next_token_ids.tolist()
|
|
if batch.return_logprob:
|
|
if logits_output.next_token_logprobs is not None:
|
|
logits_output.next_token_logprobs = (
|
|
logits_output.next_token_logprobs.tolist()
|
|
)
|
|
if logits_output.input_token_logprobs is not None:
|
|
logits_output.input_token_logprobs = tuple(
|
|
logits_output.input_token_logprobs.tolist()
|
|
)
|
|
|
|
hidden_state_offset = 0
|
|
|
|
# Check finish conditions
|
|
logprob_pt = 0
|
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
|
if req.is_retracted:
|
|
continue
|
|
|
|
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
|
# Free the one delayed token for the mixed decode batch
|
|
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
|
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
|
continue
|
|
|
|
if req.is_chunked <= 0:
|
|
# req output_ids are set here
|
|
req.output_ids.append(next_token_id)
|
|
req.check_finished()
|
|
|
|
if req.finished():
|
|
self.tree_cache.cache_finished_req(req)
|
|
req.time_stats.completion_time = time.time()
|
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
|
# This updates radix so others can match
|
|
self.tree_cache.cache_unfinished_req(req)
|
|
|
|
if req.return_logprob:
|
|
assert extend_logprob_start_len_per_req is not None
|
|
assert extend_input_len_per_req is not None
|
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
|
extend_input_len = extend_input_len_per_req[i]
|
|
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
|
self.add_logprob_return_values(
|
|
i,
|
|
req,
|
|
logprob_pt,
|
|
next_token_ids,
|
|
num_input_logprobs,
|
|
logits_output,
|
|
)
|
|
logprob_pt += num_input_logprobs
|
|
|
|
if (
|
|
req.return_hidden_states
|
|
and logits_output.hidden_states is not None
|
|
):
|
|
req.hidden_states.append(
|
|
logits_output.hidden_states[
|
|
hidden_state_offset : (
|
|
hidden_state_offset := hidden_state_offset
|
|
+ len(req.origin_input_ids)
|
|
)
|
|
]
|
|
.cpu()
|
|
.clone()
|
|
.tolist()
|
|
)
|
|
|
|
if req.grammar is not None:
|
|
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
|
try:
|
|
req.grammar.accept_token(next_token_id)
|
|
except ValueError as e:
|
|
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
|
# This can happen if the grammar is not set correctly or the token is invalid.
|
|
logger.error(
|
|
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
|
)
|
|
self.abort_request(AbortReq(req.rid))
|
|
req.grammar.finished = req.finished()
|
|
else:
|
|
# being chunked reqs' prefill is not finished
|
|
req.is_chunked -= 1
|
|
# There is only at most one request being currently chunked.
|
|
# Because this request does not finish prefill,
|
|
# we don't want to stream the request currently being chunked.
|
|
skip_stream_req = req
|
|
|
|
# Incrementally update input logprobs.
|
|
if req.return_logprob:
|
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
|
extend_input_len = extend_input_len_per_req[i]
|
|
if extend_logprob_start_len < extend_input_len:
|
|
# Update input logprobs.
|
|
num_input_logprobs = (
|
|
extend_input_len - extend_logprob_start_len
|
|
)
|
|
self.add_input_logprob_return_values(
|
|
i,
|
|
req,
|
|
logits_output,
|
|
logprob_pt,
|
|
num_input_logprobs,
|
|
last_prefill_chunk=False,
|
|
)
|
|
logprob_pt += num_input_logprobs
|
|
|
|
self.set_next_batch_sampling_info_done(batch)
|
|
|
|
else: # embedding or reward model
|
|
embeddings, bid = result.embeddings, result.bid
|
|
embeddings = embeddings.tolist()
|
|
|
|
# Check finish conditions
|
|
for i, req in enumerate(batch.reqs):
|
|
if req.is_retracted:
|
|
continue
|
|
|
|
req.embedding = embeddings[i]
|
|
if req.is_chunked <= 0:
|
|
# Dummy output token for embedding models
|
|
req.output_ids.append(0)
|
|
req.check_finished()
|
|
|
|
if req.finished():
|
|
self.tree_cache.cache_finished_req(req)
|
|
else:
|
|
self.tree_cache.cache_unfinished_req(req)
|
|
else:
|
|
# being chunked reqs' prefill is not finished
|
|
req.is_chunked -= 1
|
|
|
|
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
|
|
|
def process_batch_result_decode(
|
|
self: Scheduler,
|
|
batch: ScheduleBatch,
|
|
result: GenerationBatchResult,
|
|
launch_done: Optional[threading.Event] = None,
|
|
):
|
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
|
result.logits_output,
|
|
result.next_token_ids,
|
|
result.can_run_cuda_graph,
|
|
)
|
|
self.num_generated_tokens += len(batch.reqs)
|
|
|
|
if self.enable_overlap:
|
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
|
self.tp_worker.resolve_last_batch_result(launch_done)
|
|
)
|
|
next_token_logprobs = logits_output.next_token_logprobs
|
|
elif batch.spec_algorithm.is_none():
|
|
# spec decoding handles output logprobs inside verify process.
|
|
next_token_ids = next_token_ids.tolist()
|
|
if batch.return_logprob:
|
|
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
|
|
|
self.token_to_kv_pool_allocator.free_group_begin()
|
|
|
|
# Check finish condition
|
|
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
|
|
# We should ignore using next_token_ids for spec decoding cases.
|
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
|
if req.is_retracted:
|
|
continue
|
|
|
|
if self.enable_overlap and req.finished():
|
|
# Free the one extra delayed token
|
|
if self.page_size == 1:
|
|
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
|
|
else:
|
|
# Only free when the extra token is in a new page
|
|
if (
|
|
len(req.origin_input_ids) + len(req.output_ids) - 1
|
|
) % self.page_size == 0:
|
|
self.token_to_kv_pool_allocator.free(
|
|
batch.out_cache_loc[i : i + 1]
|
|
)
|
|
continue
|
|
|
|
if batch.spec_algorithm.is_none():
|
|
# speculative worker will solve the output_ids in speculative decoding
|
|
req.output_ids.append(next_token_id)
|
|
|
|
req.check_finished()
|
|
if req.finished():
|
|
self.tree_cache.cache_finished_req(req)
|
|
req.time_stats.completion_time = time.time()
|
|
|
|
if req.return_logprob and batch.spec_algorithm.is_none():
|
|
# speculative worker handles logprob in speculative decoding
|
|
req.output_token_logprobs_val.append(next_token_logprobs[i])
|
|
req.output_token_logprobs_idx.append(next_token_id)
|
|
if req.top_logprobs_num > 0:
|
|
req.output_top_logprobs_val.append(
|
|
logits_output.next_token_top_logprobs_val[i]
|
|
)
|
|
req.output_top_logprobs_idx.append(
|
|
logits_output.next_token_top_logprobs_idx[i]
|
|
)
|
|
if req.token_ids_logprob is not None:
|
|
req.output_token_ids_logprobs_val.append(
|
|
logits_output.next_token_token_ids_logprobs_val[i]
|
|
)
|
|
req.output_token_ids_logprobs_idx.append(
|
|
logits_output.next_token_token_ids_logprobs_idx[i]
|
|
)
|
|
|
|
if req.return_hidden_states and logits_output.hidden_states is not None:
|
|
req.hidden_states.append(
|
|
logits_output.hidden_states[i].cpu().clone().tolist()
|
|
)
|
|
|
|
if req.grammar is not None and batch.spec_algorithm.is_none():
|
|
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
|
try:
|
|
req.grammar.accept_token(next_token_id)
|
|
except ValueError as e:
|
|
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
|
# This can happen if the grammar is not set correctly or the token is invalid.
|
|
logger.error(
|
|
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
|
)
|
|
self.abort_request(AbortReq(req.rid))
|
|
req.grammar.finished = req.finished()
|
|
|
|
self.set_next_batch_sampling_info_done(batch)
|
|
self.stream_output(batch.reqs, batch.return_logprob)
|
|
self.token_to_kv_pool_allocator.free_group_end()
|
|
|
|
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
|
if (
|
|
self.current_scheduler_metrics_enabled()
|
|
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
|
|
):
|
|
self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
|
|
|
|
def add_input_logprob_return_values(
|
|
self: Scheduler,
|
|
i: int,
|
|
req: Req,
|
|
output: LogitsProcessorOutput,
|
|
logprob_pt: int,
|
|
num_input_logprobs: int,
|
|
last_prefill_chunk: bool, # If True, it means prefill is finished.
|
|
):
|
|
"""Incrementally add input logprobs to `req`.
|
|
|
|
Args:
|
|
i: The request index in a batch.
|
|
req: The request. Input logprobs inside req are modified as a
|
|
consequence of the API
|
|
fill_ids: The prefill ids processed.
|
|
output: Logit processor output that's used to compute input logprobs
|
|
last_prefill_chunk: True if it is the last prefill (when chunked).
|
|
Some of input logprob operation should only happen at the last
|
|
prefill (e.g., computing input token logprobs).
|
|
"""
|
|
assert output.input_token_logprobs is not None
|
|
if req.input_token_logprobs is None:
|
|
req.input_token_logprobs = []
|
|
if req.temp_input_top_logprobs_val is None:
|
|
req.temp_input_top_logprobs_val = []
|
|
if req.temp_input_top_logprobs_idx is None:
|
|
req.temp_input_top_logprobs_idx = []
|
|
if req.temp_input_token_ids_logprobs_val is None:
|
|
req.temp_input_token_ids_logprobs_val = []
|
|
if req.temp_input_token_ids_logprobs_idx is None:
|
|
req.temp_input_token_ids_logprobs_idx = []
|
|
|
|
if req.input_token_logprobs_val is not None:
|
|
# The input logprob has been already computed. It only happens
|
|
# upon retract.
|
|
if req.top_logprobs_num > 0:
|
|
assert req.input_token_logprobs_val is not None
|
|
return
|
|
|
|
# Important for the performance.
|
|
assert isinstance(output.input_token_logprobs, tuple)
|
|
input_token_logprobs: Tuple[int] = output.input_token_logprobs
|
|
input_token_logprobs = input_token_logprobs[
|
|
logprob_pt : logprob_pt + num_input_logprobs
|
|
]
|
|
req.input_token_logprobs.extend(input_token_logprobs)
|
|
|
|
if req.top_logprobs_num > 0:
|
|
req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
|
|
req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
|
|
|
|
if req.token_ids_logprob is not None:
|
|
req.temp_input_token_ids_logprobs_val.append(
|
|
output.input_token_ids_logprobs_val[i]
|
|
)
|
|
req.temp_input_token_ids_logprobs_idx.append(
|
|
output.input_token_ids_logprobs_idx[i]
|
|
)
|
|
|
|
if last_prefill_chunk:
|
|
input_token_logprobs = req.input_token_logprobs
|
|
req.input_token_logprobs = None
|
|
assert req.input_token_logprobs_val is None
|
|
assert req.input_token_logprobs_idx is None
|
|
assert req.input_top_logprobs_val is None
|
|
assert req.input_top_logprobs_idx is None
|
|
|
|
# Compute input_token_logprobs_val
|
|
# Always pad the first one with None.
|
|
req.input_token_logprobs_val = [None]
|
|
req.input_token_logprobs_val.extend(input_token_logprobs)
|
|
# The last input logprob is for sampling, so just pop it out.
|
|
req.input_token_logprobs_val.pop()
|
|
|
|
# Compute input_token_logprobs_idx
|
|
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
|
|
# Clip the padded hash values from image tokens.
|
|
# Otherwise, it will lead to detokenization errors.
|
|
input_token_logprobs_idx = [
|
|
x if x < self.model_config.vocab_size - 1 else 0
|
|
for x in input_token_logprobs_idx
|
|
]
|
|
req.input_token_logprobs_idx = input_token_logprobs_idx
|
|
|
|
if req.top_logprobs_num > 0:
|
|
req.input_top_logprobs_val = [None]
|
|
req.input_top_logprobs_idx = [None]
|
|
assert len(req.temp_input_token_ids_logprobs_val) == len(
|
|
req.temp_input_token_ids_logprobs_idx
|
|
)
|
|
for val, idx in zip(
|
|
req.temp_input_top_logprobs_val,
|
|
req.temp_input_top_logprobs_idx,
|
|
strict=True,
|
|
):
|
|
req.input_top_logprobs_val.extend(val)
|
|
req.input_top_logprobs_idx.extend(idx)
|
|
|
|
# Last token is a sample token.
|
|
req.input_top_logprobs_val.pop()
|
|
req.input_top_logprobs_idx.pop()
|
|
req.temp_input_top_logprobs_idx = None
|
|
req.temp_input_top_logprobs_val = None
|
|
|
|
if req.token_ids_logprob is not None:
|
|
req.input_token_ids_logprobs_val = [None]
|
|
req.input_token_ids_logprobs_idx = [None]
|
|
|
|
for val, idx in zip(
|
|
req.temp_input_token_ids_logprobs_val,
|
|
req.temp_input_token_ids_logprobs_idx,
|
|
strict=True,
|
|
):
|
|
req.input_token_ids_logprobs_val.extend(val)
|
|
req.input_token_ids_logprobs_idx.extend(idx)
|
|
|
|
# Last token is a sample token.
|
|
req.input_token_ids_logprobs_val.pop()
|
|
req.input_token_ids_logprobs_idx.pop()
|
|
req.temp_input_token_ids_logprobs_idx = None
|
|
req.temp_input_token_ids_logprobs_val = None
|
|
|
|
if req.return_logprob:
|
|
relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
|
|
assert len(req.input_token_logprobs_val) == relevant_tokens_len
|
|
assert len(req.input_token_logprobs_idx) == relevant_tokens_len
|
|
if req.top_logprobs_num > 0:
|
|
assert len(req.input_top_logprobs_val) == relevant_tokens_len
|
|
assert len(req.input_top_logprobs_idx) == relevant_tokens_len
|
|
if req.token_ids_logprob is not None:
|
|
assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
|
|
assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
|
|
|
|
def add_logprob_return_values(
|
|
self: Scheduler,
|
|
i: int,
|
|
req: Req,
|
|
pt: int,
|
|
next_token_ids: List[int],
|
|
num_input_logprobs: int,
|
|
output: LogitsProcessorOutput,
|
|
):
|
|
"""Attach logprobs to the return values."""
|
|
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
|
|
req.output_token_logprobs_idx.append(next_token_ids[i])
|
|
|
|
self.add_input_logprob_return_values(
|
|
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
|
|
)
|
|
|
|
if req.top_logprobs_num > 0:
|
|
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
|
|
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
|
|
|
|
if req.token_ids_logprob is not None:
|
|
req.output_token_ids_logprobs_val.append(
|
|
output.next_token_token_ids_logprobs_val[i]
|
|
)
|
|
req.output_token_ids_logprobs_idx.append(
|
|
output.next_token_token_ids_logprobs_idx[i]
|
|
)
|
|
|
|
return num_input_logprobs
|
|
|
|
def stream_output(
|
|
self: Scheduler,
|
|
reqs: List[Req],
|
|
return_logprob: bool,
|
|
skip_req: Optional[Req] = None,
|
|
):
|
|
"""Stream the output to detokenizer."""
|
|
if self.is_generation:
|
|
self.stream_output_generation(reqs, return_logprob, skip_req)
|
|
else: # embedding or reward model
|
|
self.stream_output_embedding(reqs)
|
|
|
|
def stream_output_generation(
|
|
self: Scheduler,
|
|
reqs: List[Req],
|
|
return_logprob: bool,
|
|
skip_req: Optional[Req] = None,
|
|
):
|
|
rids = []
|
|
finished_reasons: List[BaseFinishReason] = []
|
|
|
|
decoded_texts = []
|
|
decode_ids_list = []
|
|
read_offsets = []
|
|
output_ids = []
|
|
|
|
skip_special_tokens = []
|
|
spaces_between_special_tokens = []
|
|
no_stop_trim = []
|
|
prompt_tokens = []
|
|
completion_tokens = []
|
|
cached_tokens = []
|
|
spec_verify_ct = []
|
|
output_hidden_states = None
|
|
|
|
if return_logprob:
|
|
input_token_logprobs_val = []
|
|
input_token_logprobs_idx = []
|
|
output_token_logprobs_val = []
|
|
output_token_logprobs_idx = []
|
|
input_top_logprobs_val = []
|
|
input_top_logprobs_idx = []
|
|
output_top_logprobs_val = []
|
|
output_top_logprobs_idx = []
|
|
input_token_ids_logprobs_val = []
|
|
input_token_ids_logprobs_idx = []
|
|
output_token_ids_logprobs_val = []
|
|
output_token_ids_logprobs_idx = []
|
|
else:
|
|
input_token_logprobs_val = input_token_logprobs_idx = (
|
|
output_token_logprobs_val
|
|
) = output_token_logprobs_idx = input_top_logprobs_val = (
|
|
input_top_logprobs_idx
|
|
) = output_top_logprobs_val = output_top_logprobs_idx = (
|
|
input_token_ids_logprobs_val
|
|
) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
|
|
output_token_ids_logprobs_idx
|
|
) = None
|
|
|
|
for req in reqs:
|
|
if req is skip_req:
|
|
continue
|
|
|
|
# Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
|
|
if self.model_config.is_multimodal_gen and req.to_abort:
|
|
continue
|
|
|
|
if req.finished():
|
|
if req.finished_output:
|
|
# With the overlap schedule, a request will try to output twice and hit this line twice
|
|
# because of the one additional delayed token. This "continue" prevented the dummy output.
|
|
continue
|
|
req.finished_output = True
|
|
should_output = True
|
|
else:
|
|
if req.stream:
|
|
stream_interval = (
|
|
req.sampling_params.stream_interval or self.stream_interval
|
|
)
|
|
should_output = (
|
|
len(req.output_ids) % stream_interval == 1
|
|
if not self.model_config.is_multimodal_gen
|
|
and stream_interval > 1
|
|
else len(req.output_ids) % stream_interval == 0
|
|
)
|
|
else:
|
|
should_output = (
|
|
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
|
|
if not self.model_config.is_multimodal_gen
|
|
else False
|
|
)
|
|
|
|
if should_output:
|
|
send_token_offset = req.send_token_offset
|
|
send_output_token_logprobs_offset = (
|
|
req.send_output_token_logprobs_offset
|
|
)
|
|
rids.append(req.rid)
|
|
finished_reasons.append(
|
|
req.finished_reason.to_json() if req.finished_reason else None
|
|
)
|
|
decoded_texts.append(req.decoded_text)
|
|
decode_ids, read_offset = req.init_incremental_detokenize()
|
|
|
|
if self.model_config.is_multimodal_gen:
|
|
decode_ids_list.append(decode_ids)
|
|
else:
|
|
decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
|
|
|
|
req.send_decode_id_offset = len(decode_ids)
|
|
read_offsets.append(read_offset)
|
|
if self.skip_tokenizer_init:
|
|
output_ids.append(req.output_ids[send_token_offset:])
|
|
req.send_token_offset = len(req.output_ids)
|
|
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
|
|
spaces_between_special_tokens.append(
|
|
req.sampling_params.spaces_between_special_tokens
|
|
)
|
|
no_stop_trim.append(req.sampling_params.no_stop_trim)
|
|
prompt_tokens.append(len(req.origin_input_ids))
|
|
completion_tokens.append(len(req.output_ids))
|
|
cached_tokens.append(req.cached_tokens)
|
|
|
|
if not self.spec_algorithm.is_none():
|
|
spec_verify_ct.append(req.spec_verify_ct)
|
|
|
|
if return_logprob:
|
|
if (
|
|
req.return_logprob
|
|
and not req.input_logprob_sent
|
|
# Decode server does not send input logprobs
|
|
and self.disaggregation_mode != DisaggregationMode.DECODE
|
|
):
|
|
input_token_logprobs_val.append(req.input_token_logprobs_val)
|
|
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
|
|
input_top_logprobs_val.append(req.input_top_logprobs_val)
|
|
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
|
|
input_token_ids_logprobs_val.append(
|
|
req.input_token_ids_logprobs_val
|
|
)
|
|
input_token_ids_logprobs_idx.append(
|
|
req.input_token_ids_logprobs_idx
|
|
)
|
|
req.input_logprob_sent = True
|
|
else:
|
|
input_token_logprobs_val.append([])
|
|
input_token_logprobs_idx.append([])
|
|
input_top_logprobs_val.append([])
|
|
input_top_logprobs_idx.append([])
|
|
input_token_ids_logprobs_val.append([])
|
|
input_token_ids_logprobs_idx.append([])
|
|
|
|
if req.return_logprob:
|
|
output_token_logprobs_val.append(
|
|
req.output_token_logprobs_val[
|
|
send_output_token_logprobs_offset:
|
|
]
|
|
)
|
|
output_token_logprobs_idx.append(
|
|
req.output_token_logprobs_idx[
|
|
send_output_token_logprobs_offset:
|
|
]
|
|
)
|
|
output_top_logprobs_val.append(
|
|
req.output_top_logprobs_val[
|
|
send_output_token_logprobs_offset:
|
|
]
|
|
)
|
|
output_top_logprobs_idx.append(
|
|
req.output_top_logprobs_idx[
|
|
send_output_token_logprobs_offset:
|
|
]
|
|
)
|
|
output_token_ids_logprobs_val.append(
|
|
req.output_token_ids_logprobs_val[
|
|
send_output_token_logprobs_offset:
|
|
]
|
|
)
|
|
output_token_ids_logprobs_idx.append(
|
|
req.output_token_ids_logprobs_idx[
|
|
send_output_token_logprobs_offset:
|
|
]
|
|
)
|
|
req.send_output_token_logprobs_offset = len(
|
|
req.output_token_logprobs_val
|
|
)
|
|
else:
|
|
output_token_logprobs_val.append([])
|
|
output_token_logprobs_idx.append([])
|
|
output_top_logprobs_val.append([])
|
|
output_top_logprobs_idx.append([])
|
|
output_token_ids_logprobs_val.append([])
|
|
output_token_ids_logprobs_idx.append([])
|
|
|
|
if req.return_hidden_states:
|
|
if output_hidden_states is None:
|
|
output_hidden_states = []
|
|
output_hidden_states.append(req.hidden_states)
|
|
|
|
if (
|
|
req.finished()
|
|
and self.tp_rank == 0
|
|
and self.server_args.enable_request_time_stats_logging
|
|
):
|
|
req.log_time_stats()
|
|
|
|
# Send to detokenizer
|
|
if rids:
|
|
if self.model_config.is_multimodal_gen:
|
|
return
|
|
|
|
self.send_to_detokenizer.send_pyobj(
|
|
BatchTokenIDOut(
|
|
rids,
|
|
finished_reasons,
|
|
decoded_texts,
|
|
decode_ids_list,
|
|
read_offsets,
|
|
output_ids,
|
|
skip_special_tokens,
|
|
spaces_between_special_tokens,
|
|
no_stop_trim,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
cached_tokens,
|
|
spec_verify_ct,
|
|
input_token_logprobs_val,
|
|
input_token_logprobs_idx,
|
|
output_token_logprobs_val,
|
|
output_token_logprobs_idx,
|
|
input_top_logprobs_val,
|
|
input_top_logprobs_idx,
|
|
output_top_logprobs_val,
|
|
output_top_logprobs_idx,
|
|
input_token_ids_logprobs_val,
|
|
input_token_ids_logprobs_idx,
|
|
output_token_ids_logprobs_val,
|
|
output_token_ids_logprobs_idx,
|
|
output_hidden_states,
|
|
)
|
|
)
|
|
|
|
def stream_output_embedding(self: Scheduler, reqs: List[Req]):
|
|
rids = []
|
|
finished_reasons: List[BaseFinishReason] = []
|
|
|
|
embeddings = []
|
|
prompt_tokens = []
|
|
cached_tokens = []
|
|
for req in reqs:
|
|
if req.finished():
|
|
rids.append(req.rid)
|
|
finished_reasons.append(req.finished_reason.to_json())
|
|
embeddings.append(req.embedding)
|
|
prompt_tokens.append(len(req.origin_input_ids))
|
|
cached_tokens.append(req.cached_tokens)
|
|
self.send_to_detokenizer.send_pyobj(
|
|
BatchEmbeddingOut(
|
|
rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
|
|
)
|
|
)
|