Support incremental streaming of logprob/token_ids between scheduler and detokenizer (#6225)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
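Summary: instead of re-sending a request's full accumulated text, token ids, and logprobs on every scheduler step, each hop now tracks a per-request send offset and forwards only the unsent suffix. A minimal sketch of the pattern applied throughout this commit (illustrative names; the real offsets live in `Req`, `DecodeStatus`, and `ReqState`):

```python
# Minimal sketch of the incremental-send pattern used at each hop
# (scheduler -> detokenizer -> tokenizer manager). Names are illustrative.
class IncrementalSender:
    def __init__(self):
        self.items = []        # everything produced so far for one request
        self.sent_offset = 0   # how much has already been shipped downstream

    def produce(self, new_items):
        self.items.extend(new_items)

    def take_unsent(self):
        # Ship only the unsent suffix, then advance the offset.
        delta = self.items[self.sent_offset :]
        self.sent_offset = len(self.items)
        return delta

s = IncrementalSender()
s.produce([101, 102, 103])
assert s.take_unsent() == [101, 102, 103]
s.produce([104])
assert s.take_unsent() == [104]  # only the delta crosses the process boundary
```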
--- a/python/sglang/srt/constrained/base_grammar_backend.py
+++ b/python/sglang/srt/constrained/base_grammar_backend.py
@@ -41,7 +41,7 @@ class BaseGrammarObject:
         raise NotImplementedError()

     def is_terminated(self):
-        raise NotImplementedError()
+        return False

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalDecodeReq,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
 )
@@ -60,6 +61,8 @@ class DecodeStatus:
     decode_ids: List[int]
     surr_offset: int
     read_offset: int
+    # Offset that's sent to tokenizer for incremental update.
+    sent_offset: int = 0


 class DetokenizerManager:
@@ -151,7 +154,7 @@ class DetokenizerManager:
                 self.decode_status[rid] = s
             else:
                 s = self.decode_status[rid]
-                s.decode_ids = recv_obj.decode_ids[i]
+                s.decode_ids.extend(recv_obj.decode_ids[i])

             read_ids.append(
                 self.trim_matched_stop(
@@ -199,13 +202,15 @@ class DetokenizerManager:
             else:
                 new_text = find_printable_text(new_text)

-            output_strs.append(
-                self.trim_matched_stop(
-                    s.decoded_text + new_text,
-                    recv_obj.finished_reasons[i],
-                    recv_obj.no_stop_trim[i],
-                )
+            output_str = self.trim_matched_stop(
+                s.decoded_text + new_text,
+                recv_obj.finished_reasons[i],
+                recv_obj.no_stop_trim[i],
             )
+            # Incrementally send text.
+            incremental_output = output_str[s.sent_offset :]
+            s.sent_offset = len(output_str)
+            output_strs.append(incremental_output)

         return BatchStrOut(
             rids=recv_obj.rids,
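The contract change above is why `DecodeStatus` gains `sent_offset` and why the detokenizer now extends `decode_ids` instead of overwriting it: the scheduler ships only newly generated token ids each step. A toy illustration (simplified, not the sglang API):

```python
# Toy version of the new contract: each message carries only new token ids,
# so the receiver accumulates with extend() instead of assignment.
decode_ids = []

def on_recv(new_ids):
    decode_ids.extend(new_ids)  # was: decode_ids = new_ids (full list each time)

on_recv([101, 102])
on_recv([103])
assert decode_ids == [101, 102, 103]
```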
@@ -232,7 +237,15 @@ class DetokenizerManager:
         )

     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
-        raise NotImplementedError()
+        outputs = self.tokenizer.detokenize(recv_obj)
+        return BatchMultimodalOut(
+            rids=recv_obj.rids,
+            finished_reasons=recv_obj.finished_reasons,
+            outputs=outputs,
+            prompt_tokens=recv_obj.prompt_tokens,
+            completion_tokens=recv_obj.completion_tokens,
+            cached_tokens=recv_obj.cached_tokens,
+        )


 class LimitedCapacityDict(OrderedDict):
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -436,6 +437,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+        self.lora_path = lora_path

         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -487,6 +489,13 @@ class Req:
         # For retraction
         self.is_retracted = False

+        # Incremental streaming
+        self.send_token_offset: int = 0
+        self.send_decode_id_offset: int = 0
+        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+        # because the decode server does not have the first output token logprobs
+        self.send_output_token_logprobs_offset: int = 0
+
         # Logprobs (arguments)
         self.return_logprob = return_logprob
         # Start index to compute logprob from.
@@ -496,11 +505,9 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False

-        # Latency Breakdown
-        self.queue_time_start = None
-        self.queue_time_end = None
-
         # Logprobs (return values)
+        # True means the input logprob has been already sent to detokenizer.
+        self.input_logprob_sent: bool = False
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -515,8 +522,10 @@ class Req:
         self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
+            # shape: (bs, 1)
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
+            # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
             self.output_token_ids_logprobs_val = []
@@ -543,7 +552,12 @@ class Req:
         # The number of verification forward passes in the speculative decoding.
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-        self.lora_path = lora_path
+
+        # For metrics
+        self.time_stats: TimeStats = TimeStats()
+        self.has_log_time_stats: bool = False
+        self.queue_time_start = None
+        self.queue_time_end = None

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -562,8 +576,8 @@ class Req:
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1

         self.metadata_buffer_index: int = -1

         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
@@ -656,6 +670,11 @@ class Req:
             )
             return

+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
         last_token_id = self.output_ids[-1]

         if not self.sampling_params.ignore_eos:
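With constrained decoding, a grammar can reach an accepting state before the model emits EOS; the new `check_finished` branch terminates the request at that point. A toy grammar illustrating the idea (hypothetical class, not the sglang grammar backend):

```python
# Toy grammar that terminates once parentheses are balanced; when
# is_terminated() returns True, the scheduler can finish the request
# without waiting for an EOS token.
class BalancedParens:
    def __init__(self):
        self.depth = 0
        self.started = False

    def accept_token(self, tok: str):
        self.started = True
        self.depth += {"(": 1, ")": -1}.get(tok, 0)

    def is_terminated(self) -> bool:
        return self.started and self.depth == 0

g = BalancedParens()
for tok in ["(", "a", ")"]:
    g.accept_token(tok)
assert g.is_terminated()
```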
@@ -713,6 +732,18 @@ class Req:
         token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
         del self.kv_cache_cpu

+    def log_time_stats(self):
+        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+        if self.has_log_time_stats is True:
+            return
+
+        if self.bootstrap_room is not None:
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        else:
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+
+        logger.info(f"{prefix}: {self.time_stats}")
+        self.has_log_time_stats = True
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -530,10 +530,6 @@ class Scheduler(
         )

     def init_metrics(self):
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -1122,9 +1118,6 @@ class Scheduler(
                 self.token_to_kv_pool_allocator.available_size()
                 + self.tree_cache.evictable_size()
             )
-            self._largest_prefill_len = max(
-                self._largest_prefill_len, adder.log_input_tokens
-            )

         num_new_seq = len(can_run_list)
         f = (
@@ -1601,14 +1594,9 @@ class Scheduler(
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
-                if batch.next_batch_sampling_info:
-                    batch.next_batch_sampling_info.update_regex_vocab_mask()
-                    self.current_stream.synchronize()
-                    batch.next_batch_sampling_info.sampling_info_done.set()
+                self.set_next_batch_sampling_info_done(batch)
         elif batch.forward_mode.is_dummy_first():
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         if self.return_health_check_ct:
             # Return some signal for the health check.
@@ -1776,6 +1764,13 @@ class Scheduler(
             self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
             self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+        if batch.next_batch_sampling_info:
+            if batch.next_batch_sampling_info.grammars is not None:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -1,8 +1,11 @@
 from __future__ import annotations

+import logging
 import threading
+import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
@@ -15,6 +18,8 @@ if TYPE_CHECKING:
         Scheduler,
     )

+logger = logging.getLogger(__name__)
+
 DEFAULT_FORCE_STREAM_INTERVAL = 50

@@ -83,6 +88,7 @@ class SchedulerOutputProcessorMixin:

                     if req.finished():
                         self.tree_cache.cache_finished_req(req)
+                        req.time_stats.completion_time = time.time()
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                         # This updates radix so others can match
                         self.tree_cache.cache_unfinished_req(req)
@@ -149,10 +155,7 @@ class SchedulerOutputProcessorMixin:
                     )
                     logprob_pt += num_input_logprobs

-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         else:  # embedding or reward model
             embeddings, bid = result.embeddings, result.bid
@@ -233,6 +236,7 @@ class SchedulerOutputProcessorMixin:
             req.check_finished()
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                req.time_stats.completion_time = time.time()

             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
@@ -262,13 +266,8 @@ class SchedulerOutputProcessorMixin:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()

-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+        self.set_next_batch_sampling_info_done(batch)

         self.stream_output(batch.reqs, batch.return_logprob)

         self.token_to_kv_pool_allocator.free_group_end()

         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
@@ -530,16 +529,27 @@ class SchedulerOutputProcessorMixin:
                 )

             if should_output:
+                send_token_offset = req.send_token_offset
+                send_output_token_logprobs_offset = (
+                    req.send_output_token_logprobs_offset
+                )
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
                 )
                 decoded_texts.append(req.decoded_text)
                 decode_ids, read_offset = req.init_incremental_detokenize()
-                decode_ids_list.append(decode_ids)
+                if self.model_config.is_multimodal_gen:
+                    decode_ids_list.append(decode_ids)
+                else:
+                    decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
+
+                req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
                 if self.skip_tokenizer_init:
-                    output_ids.append(req.output_ids)
+                    output_ids.append(req.output_ids[send_token_offset:])
+                    req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
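`stream_output` now snapshots each request's send offsets, ships only the unsent suffix of each per-request list, and advances the offsets afterwards. A simplified sketch of that bookkeeping (field names mirror the new `Req` attributes; this is not the full method):

```python
from types import SimpleNamespace

def collect_increment(req):
    # Slice off only what has not been sent yet, then advance the offsets.
    chunk = {
        "output_ids": req.output_ids[req.send_token_offset :],
        "logprobs": req.output_token_logprobs_val[
            req.send_output_token_logprobs_offset :
        ],
    }
    req.send_token_offset = len(req.output_ids)
    req.send_output_token_logprobs_offset = len(req.output_token_logprobs_val)
    return chunk

req = SimpleNamespace(
    output_ids=[5, 7],
    output_token_logprobs_val=[-0.1, -0.2],
    send_token_offset=0,
    send_output_token_logprobs_offset=0,
)
assert collect_increment(req)["output_ids"] == [5, 7]
req.output_ids.append(9)
req.output_token_logprobs_val.append(-0.3)
assert collect_increment(req)["output_ids"] == [9]  # delta only
```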
@@ -553,36 +563,90 @@ class SchedulerOutputProcessorMixin:
                     spec_verify_ct.append(req.spec_verify_ct)

                 if return_logprob:
-                    input_token_logprobs_val.append(req.input_token_logprobs_val)
-                    input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                    output_token_logprobs_val.append(req.output_token_logprobs_val)
-                    output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                    input_top_logprobs_val.append(req.input_top_logprobs_val)
-                    input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                    output_top_logprobs_val.append(req.output_top_logprobs_val)
-                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                    input_token_ids_logprobs_val.append(
-                        req.input_token_ids_logprobs_val
-                    )
-                    input_token_ids_logprobs_idx.append(
-                        req.input_token_ids_logprobs_idx
-                    )
-                    output_token_ids_logprobs_val.append(
-                        req.output_token_ids_logprobs_val
-                    )
-                    output_token_ids_logprobs_idx.append(
-                        req.output_token_ids_logprobs_idx
-                    )
+                    if (
+                        req.return_logprob
+                        and not req.input_logprob_sent
+                        # Decode server does not send input logprobs
+                        and self.disaggregation_mode != DisaggregationMode.DECODE
+                    ):
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        input_token_ids_logprobs_val.append(
+                            req.input_token_ids_logprobs_val
+                        )
+                        input_token_ids_logprobs_idx.append(
+                            req.input_token_ids_logprobs_idx
+                        )
+                        req.input_logprob_sent = True
+                    else:
+                        input_token_logprobs_val.append([])
+                        input_token_logprobs_idx.append([])
+                        input_top_logprobs_val.append([])
+                        input_top_logprobs_idx.append([])
+                        input_token_ids_logprobs_val.append([])
+                        input_token_ids_logprobs_idx.append([])
+
+                    if req.return_logprob:
+                        output_token_logprobs_val.append(
+                            req.output_token_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_logprobs_idx.append(
+                            req.output_token_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_val.append(
+                            req.output_top_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_idx.append(
+                            req.output_top_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_val.append(
+                            req.output_token_ids_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_idx.append(
+                            req.output_token_ids_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        req.send_output_token_logprobs_offset = len(
+                            req.output_token_logprobs_val
+                        )
+                    else:
+                        output_token_logprobs_val.append([])
+                        output_token_logprobs_idx.append([])
+                        output_top_logprobs_val.append([])
+                        output_top_logprobs_idx.append([])
+                        output_token_ids_logprobs_val.append([])
+                        output_token_ids_logprobs_idx.append([])

                 if req.return_hidden_states:
                     if output_hidden_states is None:
                         output_hidden_states = []
                     output_hidden_states.append(req.hidden_states)

+                if (
+                    req.finished()
+                    and self.tp_rank == 0
+                    and self.server_args.enable_request_time_stats_logging
+                ):
+                    req.log_time_stats()
+
         # Send to detokenizer
         if rids:
             if self.model_config.is_multimodal_gen:
                 return

             self.send_to_detokenizer.send_pyobj(
                 BatchTokenIDOut(
                     rids,
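Input logprobs are immutable once prefill finishes, so they are attached to the first chunk only; `input_logprob_sent` marks them delivered, and later chunks carry empty placeholders so the per-request batch lists stay position-aligned. Decode servers in disaggregation mode never send them at all. A condensed sketch of that guard (simplified, illustrative names):

```python
from types import SimpleNamespace

def input_logprob_payload(req, is_decode_server=False):
    # Send the (bulky, immutable) input logprobs exactly once per request;
    # empty placeholders keep positions aligned across the batched lists.
    if req.return_logprob and not req.input_logprob_sent and not is_decode_server:
        req.input_logprob_sent = True
        return req.input_token_logprobs_val
    return []

req = SimpleNamespace(
    return_logprob=True,
    input_logprob_sent=False,
    input_token_logprobs_val=[-1.5, -0.4],
)
assert input_logprob_payload(req) == [-1.5, -0.4]  # first chunk
assert input_logprob_payload(req) == []            # every later chunk
```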
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -125,10 +125,10 @@ logger = logging.getLogger(__name__)
 class ReqState:
     """Store the state of a request."""

-    out_list: List
+    out_list: List[Dict[Any, Any]]
     finished: bool
     event: asyncio.Event
-    obj: Any
+    obj: Union[GenerateReqInput, EmbeddingReqInput]

     # For metrics
     created_time: float
@@ -139,6 +139,21 @@ class ReqState:

     # For streaming output
     last_output_offset: int = 0
+    # For incremental state update.
+    text: str = ""
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


 class TokenizerManager:
@@ -1065,9 +1080,11 @@ class TokenizerManager:
         if getattr(state.obj, "return_logprob", False):
             self.convert_logprob_style(
                 meta_info,
+                state,
                 state.obj.top_logprobs_num,
                 state.obj.token_ids_logprob,
-                state.obj.return_text_in_logprobs,
+                state.obj.return_text_in_logprobs
+                and not self.server_args.skip_tokenizer_init,
                 recv_obj,
                 i,
             )
@@ -1084,18 +1101,19 @@ class TokenizerManager:
             meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

         if isinstance(recv_obj, BatchStrOut):
+            state.text += recv_obj.output_strs[i]
             out_dict = {
-                "text": recv_obj.output_strs[i],
+                "text": state.text,
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
             if self.server_args.stream_output and state.obj.stream:
-                output_token_ids = recv_obj.output_ids[i][
-                    state.last_output_offset :
-                ]
-                state.last_output_offset = len(recv_obj.output_ids[i])
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids[state.last_output_offset :]
+                state.last_output_offset = len(state.output_ids)
             else:
-                output_token_ids = recv_obj.output_ids[i]
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids

             out_dict = {
                 "output_ids": output_token_ids,
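On the receiving side, `ReqState` now buffers the whole sequence so non-streaming callers still get the full result, while streaming callers get a slice starting at `last_output_offset`. A reduced model of that logic (not the sglang class):

```python
class StreamState:
    """Reduced model of the ReqState buffering added above."""

    def __init__(self):
        self.output_ids = []
        self.last_output_offset = 0

    def on_chunk(self, new_ids, stream):
        self.output_ids.extend(new_ids)  # chunks now carry only new ids
        if stream:
            out = self.output_ids[self.last_output_offset :]
            self.last_output_offset = len(self.output_ids)
            return out                   # delta for streaming clients
        return self.output_ids           # full list for non-streaming clients

s = StreamState()
assert s.on_chunk([1, 2], stream=True) == [1, 2]
assert s.on_chunk([3], stream=True) == [3]
assert s.on_chunk([4], stream=False) == [1, 2, 3, 4]
```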
@@ -1130,45 +1148,85 @@ class TokenizerManager:
     def convert_logprob_style(
         self,
         meta_info: dict,
+        state: ReqState,
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if len(recv_obj.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                recv_obj.input_token_logprobs_val[recv_obj_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                recv_obj.input_token_logprobs_idx[recv_obj_index]
+            )
+        state.output_token_logprobs_val.extend(
+            recv_obj.output_token_logprobs_val[recv_obj_index]
+        )
+        state.output_token_logprobs_idx.extend(
+            recv_obj.output_token_logprobs_idx[recv_obj_index]
+        )
         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.input_token_logprobs_val[recv_obj_index],
-            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
             return_text_in_logprobs,
         )
         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.output_token_logprobs_val[recv_obj_index],
-            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
             return_text_in_logprobs,
         )

         if top_logprobs_num > 0:
+            if len(recv_obj.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    recv_obj.input_top_logprobs_val[recv_obj_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    recv_obj.input_top_logprobs_idx[recv_obj_index]
+                )
+            state.output_top_logprobs_val.extend(
+                recv_obj.output_top_logprobs_val[recv_obj_index]
+            )
+            state.output_top_logprobs_idx.extend(
+                recv_obj.output_top_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_top_logprobs_val[recv_obj_index],
-                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.output_top_logprobs_val[recv_obj_index],
-                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
                 return_text_in_logprobs,
             )

         if token_ids_logprob is not None:
+            if len(recv_obj.input_token_ids_logprobs_val) > 0:
+                state.input_token_ids_logprobs_val.extend(
+                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.input_token_ids_logprobs_idx.extend(
+                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
+                )
+            state.output_token_ids_logprobs_val.extend(
+                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
+            )
+            state.output_token_ids_logprobs_idx.extend(
+                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
-                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_token_ids_logprobs"] = (
                 self.detokenize_top_logprobs_tokens(
-                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
-                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
                     return_text_in_logprobs,
                 )
             )
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -127,10 +127,12 @@ class TpModelWorkerClient:
         batch_lists = [None] * 2

         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break

+            sync_event.wait()
+
             # Keep a reference of model_worker_batch by storing it into a list.
             # Otherwise, the tensor members of model_worker_batch will be released
             # by pytorch and cause CUDA illegal memory access errors.
@@ -214,10 +216,11 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        self.scheduler_stream.synchronize()
+        sync_event = torch.get_device_module(self.device).Event()
+        sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
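The overlap worker previously blocked the scheduler thread on `scheduler_stream.synchronize()`; now the scheduler records an event and the worker thread waits on it, so the host-side handoff is non-blocking. A sketch of the pattern (assumes a CUDA-capable PyTorch build; the real code uses `torch.get_device_module(device).Event()` so the same logic also covers non-CUDA accelerators):

```python
# Event-based handoff between two host threads sharing GPU work.
import queue
import torch

work_queue = queue.Queue()
producer_stream = torch.cuda.Stream()

# Producer (scheduler thread): enqueue work plus an event, without blocking.
with torch.cuda.stream(producer_stream):
    batch = torch.randn(4, 4, device="cuda")
sync_event = torch.cuda.Event()
sync_event.record(producer_stream)
work_queue.put((batch, sync_event))

# Consumer (forward thread): make its stream wait for the producer's work.
batch, sync_event = work_queue.get()
sync_event.wait()  # queues the wait on the current stream; host not blocked
print(batch.sum())
```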
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -307,5 +307,5 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.cat([self_val, other_val]))

-        self.is_all_greedy |= other.is_all_greedy
+        self.is_all_greedy &= other.is_all_greedy
         self.need_min_p_sampling |= other.need_min_p_sampling
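The `is_all_greedy` change is a correctness fix rather than part of the streaming work: when two batches are merged, the result is all-greedy only if both halves are, so the flag must be AND-ed. With OR, merging a greedy batch into a sampling batch would wrongly keep the greedy fast path:

```python
# Merging batch flags: the merged batch is all-greedy only if both are.
a_all_greedy = True    # batch A: every request uses greedy decoding
b_all_greedy = False   # batch B: contains sampled requests
assert (a_all_greedy or b_all_greedy) is True    # old |=  -> wrong fast path
assert (a_all_greedy and b_all_greedy) is False  # new &=  -> correct
```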
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -98,6 +98,7 @@ class ServerArgs:
     show_time_cost: bool = False
     enable_metrics: bool = False
     decode_log_interval: int = 40
+    enable_request_time_stats_logging: bool = False

     # API related
     api_key: Optional[str] = None
@@ -785,6 +786,12 @@ class ServerArgs:
             default=ServerArgs.decode_log_interval,
             help="The log interval of decode batch.",
         )
+        parser.add_argument(
+            "--enable-request-time-stats-logging",
+            action="store_true",
+            default=ServerArgs.enable_request_time_stats_logging,
+            help="Enable per request time stats logging",
+        )

         # API related
         parser.add_argument(