Support incremental streaming of logprob/token_ids between scheduler and detokenizer (#6225)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -41,7 +41,7 @@ class BaseGrammarObject:
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_terminated(self):
|
||||
raise NotImplementedError()
|
||||
return False
|
||||
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
|
||||
@@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalDecodeReq,
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
)
|
||||
@@ -60,6 +61,8 @@ class DecodeStatus:
|
||||
decode_ids: List[int]
|
||||
surr_offset: int
|
||||
read_offset: int
|
||||
# Offset that's sent to tokenizer for incremental update.
|
||||
sent_offset: int = 0
|
||||
|
||||
|
||||
class DetokenizerManager:
|
||||
@@ -151,7 +154,7 @@ class DetokenizerManager:
|
||||
self.decode_status[rid] = s
|
||||
else:
|
||||
s = self.decode_status[rid]
|
||||
s.decode_ids = recv_obj.decode_ids[i]
|
||||
s.decode_ids.extend(recv_obj.decode_ids[i])
|
||||
|
||||
read_ids.append(
|
||||
self.trim_matched_stop(
|
||||
@@ -199,13 +202,15 @@ class DetokenizerManager:
|
||||
else:
|
||||
new_text = find_printable_text(new_text)
|
||||
|
||||
output_strs.append(
|
||||
self.trim_matched_stop(
|
||||
s.decoded_text + new_text,
|
||||
recv_obj.finished_reasons[i],
|
||||
recv_obj.no_stop_trim[i],
|
||||
)
|
||||
output_str = self.trim_matched_stop(
|
||||
s.decoded_text + new_text,
|
||||
recv_obj.finished_reasons[i],
|
||||
recv_obj.no_stop_trim[i],
|
||||
)
|
||||
# Incrementally send text.
|
||||
incremental_output = output_str[s.sent_offset :]
|
||||
s.sent_offset = len(output_str)
|
||||
output_strs.append(incremental_output)
|
||||
|
||||
return BatchStrOut(
|
||||
rids=recv_obj.rids,
|
||||
@@ -232,7 +237,15 @@ class DetokenizerManager:
|
||||
)
|
||||
|
||||
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
||||
raise NotImplementedError()
|
||||
outputs = self.tokenizer.detokenize(recv_obj)
|
||||
return BatchMultimodalOut(
|
||||
rids=recv_obj.rids,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
outputs=outputs,
|
||||
prompt_tokens=recv_obj.prompt_tokens,
|
||||
completion_tokens=recv_obj.completion_tokens,
|
||||
cached_tokens=recv_obj.cached_tokens,
|
||||
)
|
||||
|
||||
|
||||
class LimitedCapacityDict(OrderedDict):
|
||||
|
||||
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMi
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.metrics.collector import TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -436,6 +437,7 @@ class Req:
|
||||
self.sampling_params = sampling_params
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
self.lora_path = lora_path
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx: Optional[int] = None
|
||||
@@ -487,6 +489,13 @@ class Req:
|
||||
# For retraction
|
||||
self.is_retracted = False
|
||||
|
||||
# Incremental streamining
|
||||
self.send_token_offset: int = 0
|
||||
self.send_decode_id_offset: int = 0
|
||||
# TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
|
||||
# because the decode server does not have the first output token logprobs
|
||||
self.send_output_token_logprobs_offset: int = 0
|
||||
|
||||
# Logprobs (arguments)
|
||||
self.return_logprob = return_logprob
|
||||
# Start index to compute logprob from.
|
||||
@@ -496,11 +505,9 @@ class Req:
|
||||
self.temp_scaled_logprobs = False
|
||||
self.top_p_normalized_logprobs = False
|
||||
|
||||
# Latency Breakdown
|
||||
self.queue_time_start = None
|
||||
self.queue_time_end = None
|
||||
|
||||
# Logprobs (return values)
|
||||
# True means the input logprob has been already sent to detokenizer.
|
||||
self.input_logprob_sent: bool = False
|
||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||
@@ -515,8 +522,10 @@ class Req:
|
||||
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
|
||||
|
||||
if return_logprob:
|
||||
# shape: (bs, 1)
|
||||
self.output_token_logprobs_val = []
|
||||
self.output_token_logprobs_idx = []
|
||||
# shape: (bs, k)
|
||||
self.output_top_logprobs_val = []
|
||||
self.output_top_logprobs_idx = []
|
||||
self.output_token_ids_logprobs_val = []
|
||||
@@ -543,7 +552,12 @@ class Req:
|
||||
# The number of verification forward passes in the speculative decoding.
|
||||
# This is used to compute the average acceptance length per request.
|
||||
self.spec_verify_ct = 0
|
||||
self.lora_path = lora_path
|
||||
|
||||
# For metrics
|
||||
self.time_stats: TimeStats = TimeStats()
|
||||
self.has_log_time_stats: bool = False
|
||||
self.queue_time_start = None
|
||||
self.queue_time_end = None
|
||||
|
||||
# For disaggregation
|
||||
self.bootstrap_host: str = bootstrap_host
|
||||
@@ -562,8 +576,8 @@ class Req:
|
||||
# This is because kv is not ready in `process_prefill_chunk`.
|
||||
# We use `tmp_end_idx` to store the end index of the kv cache to send.
|
||||
self.tmp_end_idx: int = -1
|
||||
|
||||
self.metadata_buffer_index: int = -1
|
||||
|
||||
# The first output_id transferred from prefill instance.
|
||||
self.transferred_output_id: Optional[int] = None
|
||||
|
||||
@@ -656,6 +670,11 @@ class Req:
|
||||
)
|
||||
return
|
||||
|
||||
if self.grammar is not None:
|
||||
if self.grammar.is_terminated():
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
|
||||
return
|
||||
|
||||
last_token_id = self.output_ids[-1]
|
||||
|
||||
if not self.sampling_params.ignore_eos:
|
||||
@@ -713,6 +732,18 @@ class Req:
|
||||
token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
|
||||
del self.kv_cache_cpu
|
||||
|
||||
def log_time_stats(self):
|
||||
# If overlap schedule, we schedule one decode batch ahead so this gets called twice.
|
||||
if self.has_log_time_stats is True:
|
||||
return
|
||||
|
||||
if self.bootstrap_room is not None:
|
||||
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
|
||||
else:
|
||||
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
|
||||
logger.info(f"{prefix}: {self.time_stats}")
|
||||
self.has_log_time_stats = True
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"Req(rid={self.rid}, "
|
||||
|
||||
@@ -530,10 +530,6 @@ class Scheduler(
|
||||
)
|
||||
|
||||
def init_metrics(self):
|
||||
# The largest prefill length of a single request
|
||||
self._largest_prefill_len: int = 0
|
||||
# The largest context length (prefill + generation) of a single request
|
||||
self._largest_prefill_decode_len: int = 0
|
||||
self.last_gen_throughput: float = 0.0
|
||||
self.last_input_throughput: float = 0.0
|
||||
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
||||
@@ -1122,9 +1118,6 @@ class Scheduler(
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
self._largest_prefill_len = max(
|
||||
self._largest_prefill_len, adder.log_input_tokens
|
||||
)
|
||||
|
||||
num_new_seq = len(can_run_list)
|
||||
f = (
|
||||
@@ -1601,14 +1594,9 @@ class Scheduler(
|
||||
elif batch.forward_mode.is_idle():
|
||||
if self.enable_overlap:
|
||||
self.tp_worker.resolve_last_batch_result(launch_done)
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
|
||||
if self.return_health_check_ct:
|
||||
# Return some signal for the health check.
|
||||
@@ -1776,6 +1764,13 @@ class Scheduler(
|
||||
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
|
||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||
|
||||
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
|
||||
if batch.next_batch_sampling_info:
|
||||
if batch.next_batch_sampling_info.grammars is not None:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def watchdog_thread(self):
|
||||
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
|
||||
self.watchdog_last_forward_ct = 0
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
||||
@@ -15,6 +18,8 @@ if TYPE_CHECKING:
|
||||
Scheduler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_FORCE_STREAM_INTERVAL = 50
|
||||
|
||||
|
||||
@@ -83,6 +88,7 @@ class SchedulerOutputProcessorMixin:
|
||||
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
req.time_stats.completion_time = time.time()
|
||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||
# This updates radix so others can match
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
@@ -149,10 +155,7 @@ class SchedulerOutputProcessorMixin:
|
||||
)
|
||||
logprob_pt += num_input_logprobs
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
|
||||
else: # embedding or reward model
|
||||
embeddings, bid = result.embeddings, result.bid
|
||||
@@ -233,6 +236,7 @@ class SchedulerOutputProcessorMixin:
|
||||
req.check_finished()
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
req.time_stats.completion_time = time.time()
|
||||
|
||||
if req.return_logprob and batch.spec_algorithm.is_none():
|
||||
# speculative worker handles logprob in speculative decoding
|
||||
@@ -262,13 +266,8 @@ class SchedulerOutputProcessorMixin:
|
||||
req.grammar.accept_token(next_token_id)
|
||||
req.grammar.finished = req.finished()
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
self.stream_output(batch.reqs, batch.return_logprob)
|
||||
|
||||
self.token_to_kv_pool_allocator.free_group_end()
|
||||
|
||||
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
||||
@@ -530,16 +529,27 @@ class SchedulerOutputProcessorMixin:
|
||||
)
|
||||
|
||||
if should_output:
|
||||
send_token_offset = req.send_token_offset
|
||||
send_output_token_logprobs_offset = (
|
||||
req.send_output_token_logprobs_offset
|
||||
)
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(
|
||||
req.finished_reason.to_json() if req.finished_reason else None
|
||||
)
|
||||
decoded_texts.append(req.decoded_text)
|
||||
decode_ids, read_offset = req.init_incremental_detokenize()
|
||||
decode_ids_list.append(decode_ids)
|
||||
|
||||
if self.model_config.is_multimodal_gen:
|
||||
decode_ids_list.append(decode_ids)
|
||||
else:
|
||||
decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
|
||||
|
||||
req.send_decode_id_offset = len(decode_ids)
|
||||
read_offsets.append(read_offset)
|
||||
if self.skip_tokenizer_init:
|
||||
output_ids.append(req.output_ids)
|
||||
output_ids.append(req.output_ids[send_token_offset:])
|
||||
req.send_token_offset = len(req.output_ids)
|
||||
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
|
||||
spaces_between_special_tokens.append(
|
||||
req.sampling_params.spaces_between_special_tokens
|
||||
@@ -553,36 +563,90 @@ class SchedulerOutputProcessorMixin:
|
||||
spec_verify_ct.append(req.spec_verify_ct)
|
||||
|
||||
if return_logprob:
|
||||
input_token_logprobs_val.append(req.input_token_logprobs_val)
|
||||
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
|
||||
output_token_logprobs_val.append(req.output_token_logprobs_val)
|
||||
output_token_logprobs_idx.append(req.output_token_logprobs_idx)
|
||||
input_top_logprobs_val.append(req.input_top_logprobs_val)
|
||||
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
|
||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||
input_token_ids_logprobs_val.append(
|
||||
req.input_token_ids_logprobs_val
|
||||
)
|
||||
input_token_ids_logprobs_idx.append(
|
||||
req.input_token_ids_logprobs_idx
|
||||
)
|
||||
output_token_ids_logprobs_val.append(
|
||||
req.output_token_ids_logprobs_val
|
||||
)
|
||||
output_token_ids_logprobs_idx.append(
|
||||
req.output_token_ids_logprobs_idx
|
||||
)
|
||||
if (
|
||||
req.return_logprob
|
||||
and not req.input_logprob_sent
|
||||
# Decode server does not send input logprobs
|
||||
and self.disaggregation_mode != DisaggregationMode.DECODE
|
||||
):
|
||||
input_token_logprobs_val.append(req.input_token_logprobs_val)
|
||||
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
|
||||
input_top_logprobs_val.append(req.input_top_logprobs_val)
|
||||
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
|
||||
input_token_ids_logprobs_val.append(
|
||||
req.input_token_ids_logprobs_val
|
||||
)
|
||||
input_token_ids_logprobs_idx.append(
|
||||
req.input_token_ids_logprobs_idx
|
||||
)
|
||||
req.input_logprob_sent = True
|
||||
else:
|
||||
input_token_logprobs_val.append([])
|
||||
input_token_logprobs_idx.append([])
|
||||
input_top_logprobs_val.append([])
|
||||
input_top_logprobs_idx.append([])
|
||||
input_token_ids_logprobs_val.append([])
|
||||
input_token_ids_logprobs_idx.append([])
|
||||
|
||||
if req.return_logprob:
|
||||
output_token_logprobs_val.append(
|
||||
req.output_token_logprobs_val[
|
||||
send_output_token_logprobs_offset:
|
||||
]
|
||||
)
|
||||
output_token_logprobs_idx.append(
|
||||
req.output_token_logprobs_idx[
|
||||
send_output_token_logprobs_offset:
|
||||
]
|
||||
)
|
||||
output_top_logprobs_val.append(
|
||||
req.output_top_logprobs_val[
|
||||
send_output_token_logprobs_offset:
|
||||
]
|
||||
)
|
||||
output_top_logprobs_idx.append(
|
||||
req.output_top_logprobs_idx[
|
||||
send_output_token_logprobs_offset:
|
||||
]
|
||||
)
|
||||
output_token_ids_logprobs_val.append(
|
||||
req.output_token_ids_logprobs_val[
|
||||
send_output_token_logprobs_offset:
|
||||
]
|
||||
)
|
||||
output_token_ids_logprobs_idx.append(
|
||||
req.output_token_ids_logprobs_idx[
|
||||
send_output_token_logprobs_offset:
|
||||
]
|
||||
)
|
||||
req.send_output_token_logprobs_offset = len(
|
||||
req.output_token_logprobs_val
|
||||
)
|
||||
else:
|
||||
output_token_logprobs_val.append([])
|
||||
output_token_logprobs_idx.append([])
|
||||
output_top_logprobs_val.append([])
|
||||
output_top_logprobs_idx.append([])
|
||||
output_token_ids_logprobs_val.append([])
|
||||
output_token_ids_logprobs_idx.append([])
|
||||
|
||||
if req.return_hidden_states:
|
||||
if output_hidden_states is None:
|
||||
output_hidden_states = []
|
||||
output_hidden_states.append(req.hidden_states)
|
||||
|
||||
if (
|
||||
req.finished()
|
||||
and self.tp_rank == 0
|
||||
and self.server_args.enable_request_time_stats_logging
|
||||
):
|
||||
req.log_time_stats()
|
||||
|
||||
# Send to detokenizer
|
||||
if rids:
|
||||
if self.model_config.is_multimodal_gen:
|
||||
return
|
||||
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchTokenIDOut(
|
||||
rids,
|
||||
|
||||
@@ -125,10 +125,10 @@ logger = logging.getLogger(__name__)
|
||||
class ReqState:
|
||||
"""Store the state a request."""
|
||||
|
||||
out_list: List
|
||||
out_list: List[Dict[Any, Any]]
|
||||
finished: bool
|
||||
event: asyncio.Event
|
||||
obj: Any
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput]
|
||||
|
||||
# For metrics
|
||||
created_time: float
|
||||
@@ -139,6 +139,21 @@ class ReqState:
|
||||
|
||||
# For streaming output
|
||||
last_output_offset: int = 0
|
||||
# For incremental state update.
|
||||
text: str = ""
|
||||
output_ids: List[int] = dataclasses.field(default_factory=list)
|
||||
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
||||
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
||||
output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
||||
output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
||||
input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
|
||||
input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
|
||||
output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
|
||||
output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
|
||||
input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
|
||||
input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
||||
output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
|
||||
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
class TokenizerManager:
|
||||
@@ -1065,9 +1080,11 @@ class TokenizerManager:
|
||||
if getattr(state.obj, "return_logprob", False):
|
||||
self.convert_logprob_style(
|
||||
meta_info,
|
||||
state,
|
||||
state.obj.top_logprobs_num,
|
||||
state.obj.token_ids_logprob,
|
||||
state.obj.return_text_in_logprobs,
|
||||
state.obj.return_text_in_logprobs
|
||||
and not self.server_args.skip_tokenizer_init,
|
||||
recv_obj,
|
||||
i,
|
||||
)
|
||||
@@ -1084,18 +1101,19 @@ class TokenizerManager:
|
||||
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
state.text += recv_obj.output_strs[i]
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"text": state.text,
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
if self.server_args.stream_output and state.obj.stream:
|
||||
output_token_ids = recv_obj.output_ids[i][
|
||||
state.last_output_offset :
|
||||
]
|
||||
state.last_output_offset = len(recv_obj.output_ids[i])
|
||||
state.output_ids.extend(recv_obj.output_ids[i])
|
||||
output_token_ids = state.output_ids[state.last_output_offset :]
|
||||
state.last_output_offset = len(state.output_ids)
|
||||
else:
|
||||
output_token_ids = recv_obj.output_ids[i]
|
||||
state.output_ids.extend(recv_obj.output_ids[i])
|
||||
output_token_ids = state.output_ids
|
||||
|
||||
out_dict = {
|
||||
"output_ids": output_token_ids,
|
||||
@@ -1130,45 +1148,85 @@ class TokenizerManager:
|
||||
def convert_logprob_style(
|
||||
self,
|
||||
meta_info: dict,
|
||||
state: ReqState,
|
||||
top_logprobs_num: int,
|
||||
token_ids_logprob: List[int],
|
||||
return_text_in_logprobs: bool,
|
||||
recv_obj: BatchStrOut,
|
||||
recv_obj_index: int,
|
||||
):
|
||||
if len(recv_obj.input_token_logprobs_val) > 0:
|
||||
state.input_token_logprobs_val.extend(
|
||||
recv_obj.input_token_logprobs_val[recv_obj_index]
|
||||
)
|
||||
state.input_token_logprobs_idx.extend(
|
||||
recv_obj.input_token_logprobs_idx[recv_obj_index]
|
||||
)
|
||||
state.output_token_logprobs_val.extend(
|
||||
recv_obj.output_token_logprobs_val[recv_obj_index]
|
||||
)
|
||||
state.output_token_logprobs_idx.extend(
|
||||
recv_obj.output_token_logprobs_idx[recv_obj_index]
|
||||
)
|
||||
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||
recv_obj.input_token_logprobs_val[recv_obj_index],
|
||||
recv_obj.input_token_logprobs_idx[recv_obj_index],
|
||||
state.input_token_logprobs_val,
|
||||
state.input_token_logprobs_idx,
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||
recv_obj.output_token_logprobs_val[recv_obj_index],
|
||||
recv_obj.output_token_logprobs_idx[recv_obj_index],
|
||||
state.output_token_logprobs_val,
|
||||
state.output_token_logprobs_idx,
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
|
||||
if top_logprobs_num > 0:
|
||||
if len(recv_obj.input_top_logprobs_val) > 0:
|
||||
state.input_top_logprobs_val.extend(
|
||||
recv_obj.input_top_logprobs_val[recv_obj_index]
|
||||
)
|
||||
state.input_top_logprobs_idx.extend(
|
||||
recv_obj.input_top_logprobs_idx[recv_obj_index]
|
||||
)
|
||||
state.output_top_logprobs_val.extend(
|
||||
recv_obj.output_top_logprobs_val[recv_obj_index]
|
||||
)
|
||||
state.output_top_logprobs_idx.extend(
|
||||
recv_obj.output_top_logprobs_idx[recv_obj_index]
|
||||
)
|
||||
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
recv_obj.input_top_logprobs_val[recv_obj_index],
|
||||
recv_obj.input_top_logprobs_idx[recv_obj_index],
|
||||
state.input_top_logprobs_val,
|
||||
state.input_top_logprobs_idx,
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
recv_obj.output_top_logprobs_val[recv_obj_index],
|
||||
recv_obj.output_top_logprobs_idx[recv_obj_index],
|
||||
state.output_top_logprobs_val,
|
||||
state.output_top_logprobs_idx,
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
|
||||
if token_ids_logprob is not None:
|
||||
if len(recv_obj.input_token_ids_logprobs_val) > 0:
|
||||
state.input_token_ids_logprobs_val.extend(
|
||||
recv_obj.input_token_ids_logprobs_val[recv_obj_index]
|
||||
)
|
||||
state.input_token_ids_logprobs_idx.extend(
|
||||
recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
|
||||
)
|
||||
state.output_token_ids_logprobs_val.extend(
|
||||
recv_obj.output_token_ids_logprobs_val[recv_obj_index]
|
||||
)
|
||||
state.output_token_ids_logprobs_idx.extend(
|
||||
recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
|
||||
)
|
||||
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
recv_obj.input_token_ids_logprobs_val[recv_obj_index],
|
||||
recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
|
||||
state.input_token_ids_logprobs_val,
|
||||
state.input_token_ids_logprobs_idx,
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
meta_info["output_token_ids_logprobs"] = (
|
||||
self.detokenize_top_logprobs_tokens(
|
||||
recv_obj.output_token_ids_logprobs_val[recv_obj_index],
|
||||
recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
|
||||
state.output_token_ids_logprobs_val,
|
||||
state.output_token_ids_logprobs_idx,
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -127,10 +127,12 @@ class TpModelWorkerClient:
|
||||
batch_lists = [None] * 2
|
||||
|
||||
while True:
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
|
||||
if not model_worker_batch:
|
||||
break
|
||||
|
||||
sync_event.wait()
|
||||
|
||||
# Keep a reference of model_worker_batch by storing it into a list.
|
||||
# Otherwise, the tensor members of model_worker_batch will be released
|
||||
# by pytorch and cause CUDA illegal memory access errors.
|
||||
@@ -214,10 +216,11 @@ class TpModelWorkerClient:
|
||||
)
|
||||
|
||||
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
||||
self.scheduler_stream.synchronize()
|
||||
sync_event = torch.get_device_module(self.device).Event()
|
||||
sync_event.record(self.scheduler_stream)
|
||||
|
||||
# Push a new batch to the queue
|
||||
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
||||
self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
|
||||
|
||||
# Allocate output future objects
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
|
||||
@@ -307,5 +307,5 @@ class SamplingBatchInfo:
|
||||
other_val = getattr(other, item, None)
|
||||
setattr(self, item, torch.cat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy |= other.is_all_greedy
|
||||
self.is_all_greedy &= other.is_all_greedy
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
@@ -98,6 +98,7 @@ class ServerArgs:
|
||||
show_time_cost: bool = False
|
||||
enable_metrics: bool = False
|
||||
decode_log_interval: int = 40
|
||||
enable_request_time_stats_logging: bool = False
|
||||
|
||||
# API related
|
||||
api_key: Optional[str] = None
|
||||
@@ -785,6 +786,12 @@ class ServerArgs:
|
||||
default=ServerArgs.decode_log_interval,
|
||||
help="The log interval of decode batch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-request-time-stats-logging",
|
||||
action="store_true",
|
||||
default=ServerArgs.enable_request_time_stats_logging,
|
||||
help="Enable per request time stats logging",
|
||||
)
|
||||
|
||||
# API related
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user