From 26f0bedc8f351ed9b67d9b85ee30aa0c5f2aef45 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 5 Feb 2024 16:50:37 +0800 Subject: [PATCH] jump-forward rename (#144) --- .../README.md | 0 .../bench_other.py | 2 +- .../bench_sglang.py | 2 +- .../build_dataset.py | 0 .../dataset.txt | 0 .../{fast_forward.py => jump_forward.py} | 42 +++++++++--------- .../srt/managers/detokenizer_manager.py | 2 +- python/sglang/srt/managers/io_struct.py | 2 +- .../sglang/srt/managers/router/infer_batch.py | 44 +++++++++---------- .../sglang/srt/managers/router/model_rpc.py | 32 +++++++------- python/sglang/srt/server_args.py | 6 +-- ...t_fast_forward.py => test_jump_forward.py} | 8 ++-- 12 files changed, 70 insertions(+), 70 deletions(-) rename benchmark/{json_fast_forward => json_jump_forward}/README.md (100%) rename benchmark/{json_fast_forward => json_jump_forward}/bench_other.py (99%) rename benchmark/{json_fast_forward => json_jump_forward}/bench_sglang.py (99%) rename benchmark/{json_fast_forward => json_jump_forward}/build_dataset.py (100%) rename benchmark/{json_fast_forward => json_jump_forward}/dataset.txt (100%) rename python/sglang/srt/constrained/{fast_forward.py => jump_forward.py} (60%) rename test/srt/{test_fast_forward.py => test_jump_forward.py} (96%) diff --git a/benchmark/json_fast_forward/README.md b/benchmark/json_jump_forward/README.md similarity index 100% rename from benchmark/json_fast_forward/README.md rename to benchmark/json_jump_forward/README.md diff --git a/benchmark/json_fast_forward/bench_other.py b/benchmark/json_jump_forward/bench_other.py similarity index 99% rename from benchmark/json_fast_forward/bench_other.py rename to benchmark/json_jump_forward/bench_other.py index 042ede64b..f43974b15 100644 --- a/benchmark/json_fast_forward/bench_other.py +++ b/benchmark/json_jump_forward/bench_other.py @@ -219,7 +219,7 @@ def main(args): with open(args.result_file, "a") as fout: value = { - "task": "json_fast_forward", + "task": "json_jump_forward", "backend": args.backend, "latency": round(latency, 3), "num_jsons": args.num_jsons, diff --git a/benchmark/json_fast_forward/bench_sglang.py b/benchmark/json_jump_forward/bench_sglang.py similarity index 99% rename from benchmark/json_fast_forward/bench_sglang.py rename to benchmark/json_jump_forward/bench_sglang.py index 7f92a704a..cc22dd8c3 100644 --- a/benchmark/json_fast_forward/bench_sglang.py +++ b/benchmark/json_jump_forward/bench_sglang.py @@ -122,7 +122,7 @@ def main(args): with open(args.result_file, "a") as fout: value = { - "task": "json_fast_forward", + "task": "json_jump_forward", "backend": args.backend, "latency": round(latency, 3), "num_jsons": args.num_jsons, diff --git a/benchmark/json_fast_forward/build_dataset.py b/benchmark/json_jump_forward/build_dataset.py similarity index 100% rename from benchmark/json_fast_forward/build_dataset.py rename to benchmark/json_jump_forward/build_dataset.py diff --git a/benchmark/json_fast_forward/dataset.txt b/benchmark/json_jump_forward/dataset.txt similarity index 100% rename from benchmark/json_fast_forward/dataset.txt rename to benchmark/json_jump_forward/dataset.txt diff --git a/python/sglang/srt/constrained/fast_forward.py b/python/sglang/srt/constrained/jump_forward.py similarity index 60% rename from python/sglang/srt/constrained/fast_forward.py rename to python/sglang/srt/constrained/jump_forward.py index d6bb94cb9..022ec3197 100644 --- a/python/sglang/srt/constrained/fast_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -6,10 +6,10 @@ from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" -class FastForwardMap: +class JumpForwardMap: def __init__(self, regex_string): @disk_cache() - def _init_state_to_fast_forward(regex_string): + def _init_state_to_jump_forward(regex_string): regex_pattern = interegular.parse_pattern(regex_string) regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) @@ -22,54 +22,54 @@ class FastForwardMap: transitions = fsm_info.transitions dirty_states = set() - state_to_fast_forward = {} + state_to_jump_forward = {} for (state, id_), next_state in transitions.items(): if state in dirty_states: continue - if state in state_to_fast_forward: + if state in state_to_jump_forward: dirty_states.add(state) - del state_to_fast_forward[state] + del state_to_jump_forward[state] continue if len(id_to_symbol[id_]) > 1: dirty_states.add(state) continue - state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state) + state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state) - return state_to_fast_forward + return state_to_jump_forward - self.state_to_fast_forward = _init_state_to_fast_forward(regex_string) + self.state_to_jump_forward = _init_state_to_jump_forward(regex_string) def valid_states(self): - return self.state_to_fast_forward.keys() + return self.state_to_jump_forward.keys() - def fast_forward(self, state): - if state not in self.state_to_fast_forward: + def jump_forward(self, state): + if state not in self.state_to_jump_forward: return None - fast_forward_str = "" + jump_forward_str = "" next_state = None - while state in self.state_to_fast_forward: - symbol, next_state = self.state_to_fast_forward[state] - fast_forward_str += symbol + while state in self.state_to_jump_forward: + symbol, next_state = self.state_to_jump_forward[state] + jump_forward_str += symbol state = next_state - return fast_forward_str, next_state + return jump_forward_str, next_state -class FastForwardCache(BaseCache): +class JumpForwardCache(BaseCache): def __init__(self): super().__init__() def init_value(self, regex): - return FastForwardMap(regex) + return JumpForwardMap(regex) def test_main(): regex_string = r"The google's DNS sever address is " + IP_REGEX - fast_forward_map = FastForwardMap(regex_string) - for state in fast_forward_map.valid_states(): - print(state, f'"{fast_forward_map.fast_forward(state)}"') + jump_forward_map = JumpForwardMap(regex_string) + for state in jump_forward_map.valid_states(): + print(state, f'"{jump_forward_map.jump_forward(state)}"') if __name__ == "__main__": diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index f66f42342..566d40d13 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -61,7 +61,7 @@ class DetokenizerManager: output_strs[i] = " " + output_strs[i] output_strs[i] = ( - recv_obj.output_and_fast_forward_strs[i] + output_strs[i] + recv_obj.output_and_jump_forward_strs[i] + output_strs[i] ) self.send_to_tokenizer.send_pyobj( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c1ce125ff..4f2f4522a 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -81,7 +81,7 @@ class TokenizedGenerateReqInput: class BatchTokenIDOut: rids: List[str] output_tokens: List[List[int]] - output_and_fast_forward_strs: List[str] + output_and_jump_forward_strs: List[str] hit_stop_str: List[Optional[str]] skip_special_tokens: List[bool] meta_info: List[Dict] diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 5a3cc0897..339e003de 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -53,13 +53,13 @@ class Req: # For constrained decoding self.regex_fsm = None self.regex_fsm_state = 0 - self.fast_forward_map = None - self.output_and_fast_forward_str = "" + self.jump_forward_map = None + self.output_and_jump_forward_str = "" def max_new_tokens(self): return self.sampling_params.max_new_tokens - def fast_forward_and_retokenize(self, fast_forward_str, next_state): + def jump_forward_and_retokenize(self, jump_forward_str, next_state): old_output_str = self.tokenizer.decode(self.output_ids) # FIXME: This logic does not really solve the problem of determining whether # there should be a leading space. @@ -71,35 +71,35 @@ class Req: old_output_str = " " + old_output_str new_input_string = ( self.input_text - + self.output_and_fast_forward_str + + self.output_and_jump_forward_str + old_output_str - + fast_forward_str + + jump_forward_str ) new_input_ids = self.tokenizer.encode(new_input_string) if self.pixel_values is not None: # NOTE: This is a hack because the old input_ids contains the image padding - fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str)) + jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str)) else: - fast_forward_tokens_len = ( + jump_forward_tokens_len = ( len(new_input_ids) - len(self.input_ids) - len(self.output_ids) ) # print("=" * 100) - # print(f"Catch fast forward:\n{fast_forward_str}") + # print(f"Catch jump forward:\n{jump_forward_str}") # print(self.tokenizer.convert_ids_to_tokens(self.input_ids)) # print(self.tokenizer.convert_ids_to_tokens(new_input_ids)) self.input_ids = new_input_ids self.output_ids = [] self.sampling_params.max_new_tokens = max( - self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0 + self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0 ) self.regex_fsm_state = next_state - self.output_and_fast_forward_str = ( - self.output_and_fast_forward_str + old_output_str + fast_forward_str + self.output_and_jump_forward_str = ( + self.output_and_jump_forward_str + old_output_str + jump_forward_str ) - # print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}") + # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}") # print("*" * 100) def check_finished(self): @@ -327,18 +327,18 @@ class Batch: return retracted_reqs - def check_for_fast_forward(self): - fast_forward_reqs = [] + def check_for_jump_forward(self): + jump_forward_reqs = [] filter_indices = [i for i in range(len(self.reqs))] req_pool_indices_cpu = None for i, req in enumerate(self.reqs): - if req.fast_forward_map is not None: - res = req.fast_forward_map.fast_forward(req.regex_fsm_state) + if req.jump_forward_map is not None: + res = req.jump_forward_map.jump_forward(req.regex_fsm_state) if res is not None: - fast_forward_str, next_state = res - if len(fast_forward_str) <= 1: + jump_forward_str, next_state = res + if len(jump_forward_str) <= 1: continue # insert the old request into tree_cache @@ -356,16 +356,16 @@ class Batch: self.req_to_token_pool.free(req_pool_idx) self.tree_cache.dec_ref_counter(req.last_node) - # fast forward - req.fast_forward_and_retokenize(fast_forward_str, next_state) + # jump-forward + req.jump_forward_and_retokenize(jump_forward_str, next_state) - fast_forward_reqs.append(req) + jump_forward_reqs.append(req) filter_indices.remove(i) if len(filter_indices) < len(self.reqs): self.filter_batch(filter_indices) - return fast_forward_reqs + return jump_forward_reqs def prepare_for_decode(self, input_ids=None): if input_ids is None: diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 49da99d96..4700b6311 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -11,7 +11,7 @@ import rpyc import torch from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer -from sglang.srt.constrained.fast_forward import FastForwardCache +from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.io_struct import ( @@ -49,7 +49,7 @@ class ModelRpcServer(rpyc.Service): self.tp_rank = tp_rank self.tp_size = server_args.tp_size self.schedule_heuristic = server_args.schedule_heuristic - self.no_regex_fast_forward = server_args.no_regex_fast_forward + self.no_regex_jump_forward = server_args.no_regex_jump_forward # Init model and tokenizer self.model_config = ModelConfig( @@ -127,7 +127,7 @@ class ModelRpcServer(rpyc.Service): "trust_remote_code": server_args.trust_remote_code, }, ) - self.fast_forward_cache = FastForwardCache() + self.jump_forward_cache = JumpForwardCache() # Init new token estimation self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0) @@ -254,8 +254,8 @@ class ModelRpcServer(rpyc.Service): # Init regex fsm if req.sampling_params.regex is not None: req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) - if not self.no_regex_fast_forward: - req.fast_forward_map = self.fast_forward_cache.query( + if not self.no_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( req.sampling_params.regex ) @@ -369,8 +369,8 @@ class ModelRpcServer(rpyc.Service): logger.debug( f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " - f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. " - f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. " + f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " + f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. " ) new_batch = Batch.init_new( @@ -437,12 +437,12 @@ class ModelRpcServer(rpyc.Service): self.min_new_token_ratio, ) - if not self.no_regex_fast_forward: - # check for fast forward - fast_forward_reqs = batch.check_for_fast_forward() + if not self.no_regex_jump_forward: + # check for jump-forward + jump_forward_reqs = batch.check_for_jump_forward() - # check for image fast forward - for req in fast_forward_reqs: + # check for image jump-forward + for req in jump_forward_reqs: if req.pixel_values is not None: ( req.input_ids, @@ -454,7 +454,7 @@ class ModelRpcServer(rpyc.Service): req.image_size, ) - self.forward_queue.extend(fast_forward_reqs) + self.forward_queue.extend(jump_forward_reqs) if batch.is_empty(): return @@ -478,7 +478,7 @@ class ModelRpcServer(rpyc.Service): def handle_finished_requests(self, batch: Batch): output_rids = [] output_tokens = [] - output_and_fast_forward_strs = [] + output_and_jump_forward_strs = [] output_hit_stop_str = [] output_skip_special_tokens = [] output_meta_info = [] @@ -502,7 +502,7 @@ class ModelRpcServer(rpyc.Service): ): output_rids.append(req.rid) output_tokens.append(req.output_ids) - output_and_fast_forward_strs.append(req.output_and_fast_forward_str) + output_and_jump_forward_strs.append(req.output_and_jump_forward_str) output_hit_stop_str.append(req.hit_stop_str) output_skip_special_tokens.append( req.sampling_params.skip_special_tokens @@ -523,7 +523,7 @@ class ModelRpcServer(rpyc.Service): BatchTokenIDOut( output_rids, output_tokens, - output_and_fast_forward_strs, + output_and_jump_forward_strs, output_hit_stop_str, output_skip_special_tokens, output_meta_info, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 866a93ac2..f33f86eb6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -25,7 +25,7 @@ class ServerArgs: disable_log_stats: bool = False log_stats_interval: int = 10 log_level: str = "info" - no_regex_fast_forward: bool = False + no_regex_jump_forward: bool = False def __post_init__(self): if self.tokenizer_path is None: @@ -172,9 +172,9 @@ class ServerArgs: help="Log stats interval in second.", ) parser.add_argument( - "--no-regex-fast-forward", + "--no-regex-jump-forward", action="store_true", - help="Disable regex fast forward", + help="Disable regex jump-forward", ) @classmethod diff --git a/test/srt/test_fast_forward.py b/test/srt/test_jump_forward.py similarity index 96% rename from test/srt/test_fast_forward.py rename to test/srt/test_jump_forward.py index 505ce2928..f171aae28 100644 --- a/test/srt/test_fast_forward.py +++ b/test/srt/test_jump_forward.py @@ -12,7 +12,7 @@ import sglang as sgl IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" -ip_fast_forward = ( +ip_jump_forward = ( r"The google's DNS sever address is " + IP_REGEX + r" and " @@ -32,11 +32,11 @@ def regex_gen(s): "answer", max_tokens=128, temperature=0, - regex=ip_fast_forward, + regex=ip_jump_forward, ) # fmt: on -json_fast_forward = ( +json_jump_forward = ( r"""The information about Hogwarts is in the following JSON format\.\n""" + r"""\n\{\n""" + r""" "name": "[\w\d\s]*",\n""" @@ -54,7 +54,7 @@ def json_gen(s): "json", max_tokens=128, temperature=0, - regex=json_fast_forward, + regex=json_jump_forward, ) # fmt: on