diff --git a/examples/usage/json_logprobs.py b/examples/usage/json_logprobs.py
new file mode 100644
index 000000000..6b5b9c8fc
--- /dev/null
+++ b/examples/usage/json_logprobs.py
@@ -0,0 +1,104 @@
+# NOTE: Currently this can only be run through HTTP requests.
+import json
+from concurrent.futures import ThreadPoolExecutor
+
+from json_decode import character_regex
+
+from sglang.utils import http_request
+
+character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]
+
+base_url = "http://localhost:30000"
+
+prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"
+
+
+def openai_api_request(name):
+    data = {
+        "model": "",
+        "prompt": name + prompt,
+        "temperature": 0,
+        "max_tokens": 128,
+        "regex": character_regex,
+        "logprobs": 3,
+    }
+    res = http_request(base_url + "/v1/completions", json=data).json()
+
+    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
+    #     fout.write(json.dumps(res, indent=4))
+
+    logprobs = res["choices"][0]["logprobs"]
+    usage = res["usage"]
+    assert len(logprobs["token_logprobs"]) == len(logprobs["tokens"])
+    assert len(logprobs["token_logprobs"]) == len(logprobs["top_logprobs"])
+    assert len(logprobs["token_logprobs"]) == usage["completion_tokens"] - 1
+
+    return res
+
+
+def srt_api_request(name):
+    data = {
+        "text": name + prompt,
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": 128,
+            "regex": character_regex,
+        },
+        "return_logprob": True,
+        "logprob_start_len": 0,
+        "top_logprobs_num": 3,
+        "return_text_in_logprobs": True,
+    }
+
+    res = http_request(base_url + "/generate", json=data).json()
+
+    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
+    #     fout.write(json.dumps(res, indent=4))
+
+    meta_info = res["meta_info"]
+    assert len(meta_info["prefill_token_logprobs"]) == len(
+        meta_info["prefill_top_logprobs"]
+    )
+    assert len(meta_info["decode_token_logprobs"]) == len(
+        meta_info["decode_top_logprobs"]
+    )
+    assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
+    assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
+
+    return res
+
+
+def pretty_print(res):
+    meta_info = res["meta_info"]
+
+    print("\n\n", "=" * 30, "Prefill", "=" * 30)
+    for i in range(len(meta_info["prefill_token_logprobs"])):
+        print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
+        top_ks = (
+            [str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
+            if meta_info["prefill_top_logprobs"][i]
+            else []
+        )
+        for top_k in top_ks:
+            print(f"{top_k: <15}", end="")
+        print()
+
+    print("\n\n", "=" * 30, "Decode", "=" * 30)
+    for i in range(len(meta_info["decode_token_logprobs"])):
+        print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
+        top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
+        for top_k in top_ks:
+            print(f"{top_k: <15}", end="")
+        print()
+
+    print(res["text"])
+
+
+if __name__ == "__main__":
+    with ThreadPoolExecutor() as executor:
+        ress = executor.map(srt_api_request, character_names)
+
+    for res in ress:
+        pretty_print(res)
+
+    openai_api_request("Hermione Granger")
diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 062628bd3..452412bec 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -28,5 +28,11 @@ class GlobalConfig:
         # Request dependency time due to network delay
         self.request_dependency_time = 0.03
 
+        # New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
 
 global_config = GlobalConfig()
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 53d6620e9..e47a286eb 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
             # NOTE: the GPU-CPU overhead can be reduced
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
-            for i in range(len(extend_seq_lens_cpu)):
-                if extend_seq_lens_cpu[i] == 0:
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
                 k = input_metadata.top_logprobs_nums[i]
-                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
                 prefill_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
                 decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
-                pt += extend_seq_lens_cpu[i]
+                pt += extend_seq_len
+
             return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
             )
 
 
-if __name__ == "__main__":
+def test():
     all_logprobs = torch.tensor(
         #       s                     s                      s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +174,7 @@ class LogitsProcessor(nn.Module):
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index eeefbe0ba..4774dba33 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -51,11 +51,6 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             for i in range(len(output_strs)):
-                if recv_obj.hit_stop_str[i] is not None:
-                    pos = output_strs[i].find(recv_obj.hit_stop_str[i])
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
-
                 if len(output_tokens[i]) > 0:
                     first_token = self.tokenizer.convert_ids_to_tokens(
                         int(output_tokens[i][0])
@@ -65,9 +60,12 @@ class DetokenizerManager:
                     if first_token.startswith("▁"):
                         output_strs[i] = " " + output_strs[i]
 
-                output_strs[i] = (
-                    recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                )
+                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+
+                if recv_obj.hit_stop_str[i] is not None:
+                    pos = output_strs[i].find(recv_obj.hit_stop_str[i])
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 8da2317c1..4e8d6d74a 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -106,8 +106,8 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    prev_output_strs : List[str]
     output_tokens: List[List[int]]
-    output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index dbe94371b..20cc662a0 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -36,15 +36,15 @@ class FinishReason(IntEnum):
 
 
 class Req:
-    def __init__(self, rid, input_text, input_ids):
+    def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
-        self.input_text = input_text
-        self.input_ids = input_ids
+        self.origin_input_text = origin_input_text
+        self.origin_input_ids = origin_input_ids
+        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
+        self.prev_output_str = ""
+        self.prev_output_ids = []
         self.output_ids = []
-
-        # Since jump forward may retokenize the prompt with partial outputs,
-        # we maintain the original prompt length to report the correct usage.
-        self.prompt_tokens = len(input_ids)
+        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
 
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -76,15 +76,24 @@ class Req:
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
-        self.decode_token_logprobs = None
+        self.decode_token_logprobs = []
         self.prefill_top_logprobs = None
-        self.decode_top_logprobs = None
+        self.decode_top_logprobs = []
+        # The tokens is prefilled but need to be considered as decode tokens
+        # and should be updated for the decode logprobs
+        self.last_update_decode_tokens = 0
 
         # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
-        self.output_and_jump_forward_str = ""
+
+    def partial_decode(self, ids):
+        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
+        first_token = (
+            first_token.decode() if isinstance(first_token, bytes) else first_token
+        )
+        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
 
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -93,7 +102,10 @@ class Req:
         if self.finished:
             return
 
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+        if (
+            len(self.prev_output_ids) + len(self.output_ids)
+            >= self.sampling_params.max_new_tokens
+        ):
             self.finished = True
             self.finish_reason = FinishReason.LENGTH
             return
@@ -112,60 +124,66 @@ class Req:
         )
 
         for stop_str in self.sampling_params.stop_strs:
-            if stop_str in tail_str:
+            # FIXME: (minor) try incremental match in prev_output_str
+            if stop_str in tail_str or stop_str in self.prev_output_str:
                 self.finished = True
                 self.finish_reason = FinishReason.STOP_STR
                 self.hit_stop_str = stop_str
                 return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        old_output_str = self.tokenizer.decode(self.output_ids)
         # FIXME: This logic does not really solve the problem of determining whether
         # there should be a leading space.
-        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        if first_token.startswith("▁"):
-            old_output_str = " " + old_output_str
-        if self.input_text is None:
-            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
-            self.input_text = self.tokenizer.decode(self.input_ids)
-        new_input_string = (
-            self.input_text
-            + self.output_and_jump_forward_str
-            + old_output_str
+        cur_output_str = self.partial_decode(self.output_ids)
+
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
+        if self.origin_input_text is None:
+            # Recovering text can only use unpadded ids
+            self.origin_input_text = self.tokenizer.decode(
+                self.origin_input_ids_unpadded
+            )
+
+        all_text = (
+            self.origin_input_text
+            + self.prev_output_str
+            + cur_output_str
             + jump_forward_str
         )
-        new_input_ids = self.tokenizer.encode(new_input_string)
-        if self.pixel_values is not None:
-            # NOTE: This is a hack because the old input_ids contains the image padding
-            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-        else:
-            jump_forward_tokens_len = (
-                len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-            )
+        all_ids = self.tokenizer.encode(all_text)
+        prompt_tokens = len(self.origin_input_ids_unpadded)
+        self.origin_input_ids = all_ids[:prompt_tokens]
+        self.origin_input_ids_unpadded = self.origin_input_ids
+        # NOTE: the output ids may not strictly correspond to the output text
+        old_prev_output_ids = self.prev_output_ids
+        self.prev_output_ids = all_ids[prompt_tokens:]
+        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
+        self.output_ids = []
+
+        self.regex_fsm_state = next_state
+
+        if self.return_logprob:
+            # For fast-forward part's logprobs
+            k = 0
+            for i, old_id in enumerate(old_prev_output_ids):
+                if old_id == self.prev_output_ids[i]:
+                    k = k + 1
+                else:
+                    break
+            self.decode_token_logprobs = self.decode_token_logprobs[:k]
+            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.logprob_start_len = prompt_tokens + k
+            self.last_update_decode_tokens = len(self.prev_output_ids) - k
 
         # print("=" * 100)
         # print(f"Catch jump forward:\n{jump_forward_str}")
         # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
         # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
 
-        self.input_ids = new_input_ids
-        self.output_ids = []
-        self.sampling_params.max_new_tokens = max(
-            self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
-        )
-        self.regex_fsm_state = next_state
-        self.output_and_jump_forward_str = (
-            self.output_and_jump_forward_str + old_output_str + jump_forward_str
-        )
 
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
         # print("*" * 100)
 
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
+        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
 @dataclass
@@ -336,6 +354,7 @@ class Batch:
 
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
+        # TODO(lsyin): improve the priority of retraction
        sorted_indices.sort(
             key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
             reverse=True,
@@ -356,18 +375,27 @@ class Batch:
             ][last_uncached_pos : seq_lens_cpu[idx]]
             self.token_to_kv_pool.dec_refs(token_indices)
 
+            # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
+
+            cur_output_str = req.partial_decode(req.output_ids)
+            req.prev_output_str = req.prev_output_str + cur_output_str
+            req.prev_output_ids.extend(req.output_ids)
+
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
-            req.regex_fsm_state = 0
+
+            # For incremental logprobs
+            req.last_update_decode_tokens = 0
+            req.logprob_start_len = 10**9
 
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
 
-    def check_for_jump_forward(self):
+    def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
@@ -397,6 +425,18 @@ class Batch:
 
                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
+                # re-applying image padding
+                if req.pixel_values is not None:
+                    (
+                        req.origin_input_ids,
+                        req.image_offset,
+                    ) = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded,
+                        req.pad_value,
+                        req.pixel_values.shape,
+                        req.image_size,
+                    )
+
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)
 
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 6abb20b25..d52b3767d 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -4,7 +4,7 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional
 
 import rpyc
 import torch
@@ -16,6 +16,7 @@ try:
 except ImportError:
     from vllm.logger import logger as vllm_default_logger
 
+from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -106,7 +107,8 @@ class ModelRpcServer:
         set_random_seed(server_args.random_seed)
 
         # Print info
-        logger.info(f"[rank={self.tp_rank}] "
+        logger.info(
+            f"[rank={self.tp_rank}] "
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
@@ -151,9 +153,20 @@ class ModelRpcServer:
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
-        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
-        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
-        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
@@ -256,8 +269,13 @@ class ModelRpcServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
+            req.origin_input_ids, req.image_offset = (
+                self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
+                )
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -275,11 +293,11 @@ class ModelRpcServer:
                 )
 
         # Truncate prompts that are too long
-        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
+        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.input_ids),
-            self.max_total_num_token - 128 - len(req.input_ids),
+            self.model_config.context_len - 1 - len(req.origin_input_ids),
+            self.max_total_num_token - 128 - len(req.origin_input_ids),
         )
         self.forward_queue.append(req)
 
@@ -292,6 +310,10 @@ class ModelRpcServer:
 
         # Compute matched prefix length
         for req in self.forward_queue:
+            assert (
+                len(req.output_ids) == 0
+            ), "The output ids should be empty when prefilling"
+            req.input_ids = req.origin_input_ids + req.prev_output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -319,7 +341,7 @@ class ModelRpcServer:
             )
 
         for req in self.forward_queue:
-            if req.return_logprob:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
@@ -441,28 +463,53 @@ class ModelRpcServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+                if req.normalized_prompt_logprob is None:
+                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
 
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                req.prefill_token_logprobs = list(
-                    zip(
-                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                        req.input_ids[-req.extend_input_len + 1 :],
+                if req.prefill_token_logprobs is None:
+                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                    req.prefill_token_logprobs = list(
+                        zip(
+                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                            req.input_ids[-req.extend_input_len + 1 :],
+                        )
                     )
-                )
-                if req.logprob_start_len == 0:
-                    req.prefill_token_logprobs = [
-                        (None, req.input_ids[0])
-                    ] + req.prefill_token_logprobs
-                req.decode_token_logprobs = [
+                    if req.logprob_start_len == 0:
+                        req.prefill_token_logprobs = [
+                            (None, req.input_ids[0])
+                        ] + req.prefill_token_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_token_logprobs.extend(
+                        list(
+                            zip(
+                                prefill_token_logprobs[
+                                    pt
+                                    + req.extend_input_len
+                                    - req.last_update_decode_tokens : pt
+                                    + req.extend_input_len
+                                    - 1
+                                ],
+                                req.input_ids[-req.last_update_decode_tokens + 1 :],
+                            )
+                        )
+                    )
+
+                req.decode_token_logprobs.append(
                     (last_token_logprobs[i], next_token_ids[i])
-                ]
+                )
 
             if req.top_logprobs_num > 0:
-                req.prefill_top_logprobs = prefill_top_logprobs[i]
-                if req.logprob_start_len == 0:
-                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-                req.decode_top_logprobs = [decode_top_logprobs[i]]
+                if req.prefill_top_logprobs is None:
+                    req.prefill_top_logprobs = prefill_top_logprobs[i]
+                    if req.logprob_start_len == 0:
+                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_top_logprobs.extend(
+                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                    )
+                req.decode_top_logprobs.append(decode_top_logprobs[i])
 
             pt += req.extend_input_len
 
@@ -484,7 +531,7 @@ class ModelRpcServer:
         # check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
             retracted_reqs = batch.retract_decode()
             logger.info(
@@ -495,26 +542,13 @@ class ModelRpcServer:
             self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_step[0],
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
 
         if not self.disable_regex_jump_forward:
             # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
             self.forward_queue.extend(jump_forward_reqs)
 
         if batch.is_empty():
@@ -557,8 +591,8 @@ class ModelRpcServer:
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        prev_output_strs = []
         output_tokens = []
-        output_and_jump_forward_strs = []
        output_hit_stop_str = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
@@ -582,8 +616,8 @@ class ModelRpcServer:
                 )
             ):
                 output_rids.append(req.rid)
+                prev_output_strs.append(req.prev_output_str)
                 output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -593,10 +627,8 @@ class ModelRpcServer:
                 )
 
                 meta_info = {
-                    "prompt_tokens": req.prompt_tokens,
-                    "completion_tokens": len(req.input_ids)
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": FinishReason.to_str(req.finish_reason),
                     "hit_stop_str": req.hit_stop_str,
@@ -623,8 +655,8 @@ class ModelRpcServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    prev_output_strs,
                     output_tokens,
-                    output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,