jump-forward rename (#144)
This commit is contained in:
@@ -6,10 +6,10 @@ from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
|
||||
class FastForwardMap:
|
||||
class JumpForwardMap:
|
||||
def __init__(self, regex_string):
|
||||
@disk_cache()
|
||||
def _init_state_to_fast_forward(regex_string):
|
||||
def _init_state_to_jump_forward(regex_string):
|
||||
regex_pattern = interegular.parse_pattern(regex_string)
|
||||
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
|
||||
|
||||
@@ -22,54 +22,54 @@ class FastForwardMap:
|
||||
|
||||
transitions = fsm_info.transitions
|
||||
dirty_states = set()
|
||||
state_to_fast_forward = {}
|
||||
state_to_jump_forward = {}
|
||||
|
||||
for (state, id_), next_state in transitions.items():
|
||||
if state in dirty_states:
|
||||
continue
|
||||
if state in state_to_fast_forward:
|
||||
if state in state_to_jump_forward:
|
||||
dirty_states.add(state)
|
||||
del state_to_fast_forward[state]
|
||||
del state_to_jump_forward[state]
|
||||
continue
|
||||
if len(id_to_symbol[id_]) > 1:
|
||||
dirty_states.add(state)
|
||||
continue
|
||||
|
||||
state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state)
|
||||
state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
|
||||
|
||||
return state_to_fast_forward
|
||||
return state_to_jump_forward
|
||||
|
||||
self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)
|
||||
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
|
||||
|
||||
def valid_states(self):
|
||||
return self.state_to_fast_forward.keys()
|
||||
return self.state_to_jump_forward.keys()
|
||||
|
||||
def fast_forward(self, state):
|
||||
if state not in self.state_to_fast_forward:
|
||||
def jump_forward(self, state):
|
||||
if state not in self.state_to_jump_forward:
|
||||
return None
|
||||
|
||||
fast_forward_str = ""
|
||||
jump_forward_str = ""
|
||||
next_state = None
|
||||
while state in self.state_to_fast_forward:
|
||||
symbol, next_state = self.state_to_fast_forward[state]
|
||||
fast_forward_str += symbol
|
||||
while state in self.state_to_jump_forward:
|
||||
symbol, next_state = self.state_to_jump_forward[state]
|
||||
jump_forward_str += symbol
|
||||
state = next_state
|
||||
return fast_forward_str, next_state
|
||||
return jump_forward_str, next_state
|
||||
|
||||
|
||||
class FastForwardCache(BaseCache):
|
||||
class JumpForwardCache(BaseCache):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def init_value(self, regex):
|
||||
return FastForwardMap(regex)
|
||||
return JumpForwardMap(regex)
|
||||
|
||||
|
||||
def test_main():
|
||||
regex_string = r"The google's DNS sever address is " + IP_REGEX
|
||||
fast_forward_map = FastForwardMap(regex_string)
|
||||
for state in fast_forward_map.valid_states():
|
||||
print(state, f'"{fast_forward_map.fast_forward(state)}"')
|
||||
jump_forward_map = JumpForwardMap(regex_string)
|
||||
for state in jump_forward_map.valid_states():
|
||||
print(state, f'"{jump_forward_map.jump_forward(state)}"')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -61,7 +61,7 @@ class DetokenizerManager:
|
||||
output_strs[i] = " " + output_strs[i]
|
||||
|
||||
output_strs[i] = (
|
||||
recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
|
||||
recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
|
||||
)
|
||||
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
|
||||
@@ -81,7 +81,7 @@ class TokenizedGenerateReqInput:
|
||||
class BatchTokenIDOut:
|
||||
rids: List[str]
|
||||
output_tokens: List[List[int]]
|
||||
output_and_fast_forward_strs: List[str]
|
||||
output_and_jump_forward_strs: List[str]
|
||||
hit_stop_str: List[Optional[str]]
|
||||
skip_special_tokens: List[bool]
|
||||
meta_info: List[Dict]
|
||||
|
||||
@@ -53,13 +53,13 @@ class Req:
|
||||
# For constrained decoding
|
||||
self.regex_fsm = None
|
||||
self.regex_fsm_state = 0
|
||||
self.fast_forward_map = None
|
||||
self.output_and_fast_forward_str = ""
|
||||
self.jump_forward_map = None
|
||||
self.output_and_jump_forward_str = ""
|
||||
|
||||
def max_new_tokens(self):
|
||||
return self.sampling_params.max_new_tokens
|
||||
|
||||
def fast_forward_and_retokenize(self, fast_forward_str, next_state):
|
||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
||||
# FIXME: This logic does not really solve the problem of determining whether
|
||||
# there should be a leading space.
|
||||
@@ -71,35 +71,35 @@ class Req:
|
||||
old_output_str = " " + old_output_str
|
||||
new_input_string = (
|
||||
self.input_text
|
||||
+ self.output_and_fast_forward_str
|
||||
+ self.output_and_jump_forward_str
|
||||
+ old_output_str
|
||||
+ fast_forward_str
|
||||
+ jump_forward_str
|
||||
)
|
||||
new_input_ids = self.tokenizer.encode(new_input_string)
|
||||
if self.pixel_values is not None:
|
||||
# NOTE: This is a hack because the old input_ids contains the image padding
|
||||
fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str))
|
||||
jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
|
||||
else:
|
||||
fast_forward_tokens_len = (
|
||||
jump_forward_tokens_len = (
|
||||
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
|
||||
)
|
||||
|
||||
# print("=" * 100)
|
||||
# print(f"Catch fast forward:\n{fast_forward_str}")
|
||||
# print(f"Catch jump forward:\n{jump_forward_str}")
|
||||
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
|
||||
# print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
|
||||
|
||||
self.input_ids = new_input_ids
|
||||
self.output_ids = []
|
||||
self.sampling_params.max_new_tokens = max(
|
||||
self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0
|
||||
self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
|
||||
)
|
||||
self.regex_fsm_state = next_state
|
||||
self.output_and_fast_forward_str = (
|
||||
self.output_and_fast_forward_str + old_output_str + fast_forward_str
|
||||
self.output_and_jump_forward_str = (
|
||||
self.output_and_jump_forward_str + old_output_str + jump_forward_str
|
||||
)
|
||||
|
||||
# print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}")
|
||||
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
|
||||
# print("*" * 100)
|
||||
|
||||
def check_finished(self):
|
||||
@@ -327,18 +327,18 @@ class Batch:
|
||||
|
||||
return retracted_reqs
|
||||
|
||||
def check_for_fast_forward(self):
|
||||
fast_forward_reqs = []
|
||||
def check_for_jump_forward(self):
|
||||
jump_forward_reqs = []
|
||||
filter_indices = [i for i in range(len(self.reqs))]
|
||||
|
||||
req_pool_indices_cpu = None
|
||||
|
||||
for i, req in enumerate(self.reqs):
|
||||
if req.fast_forward_map is not None:
|
||||
res = req.fast_forward_map.fast_forward(req.regex_fsm_state)
|
||||
if req.jump_forward_map is not None:
|
||||
res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
|
||||
if res is not None:
|
||||
fast_forward_str, next_state = res
|
||||
if len(fast_forward_str) <= 1:
|
||||
jump_forward_str, next_state = res
|
||||
if len(jump_forward_str) <= 1:
|
||||
continue
|
||||
|
||||
# insert the old request into tree_cache
|
||||
@@ -356,16 +356,16 @@ class Batch:
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
self.tree_cache.dec_ref_counter(req.last_node)
|
||||
|
||||
# fast forward
|
||||
req.fast_forward_and_retokenize(fast_forward_str, next_state)
|
||||
# jump-forward
|
||||
req.jump_forward_and_retokenize(jump_forward_str, next_state)
|
||||
|
||||
fast_forward_reqs.append(req)
|
||||
jump_forward_reqs.append(req)
|
||||
filter_indices.remove(i)
|
||||
|
||||
if len(filter_indices) < len(self.reqs):
|
||||
self.filter_batch(filter_indices)
|
||||
|
||||
return fast_forward_reqs
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_for_decode(self, input_ids=None):
|
||||
if input_ids is None:
|
||||
|
||||
@@ -11,7 +11,7 @@ import rpyc
|
||||
import torch
|
||||
from rpyc.utils.classic import obtain
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
from sglang.srt.constrained.fast_forward import FastForwardCache
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
@@ -49,7 +49,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = server_args.tp_size
|
||||
self.schedule_heuristic = server_args.schedule_heuristic
|
||||
self.no_regex_fast_forward = server_args.no_regex_fast_forward
|
||||
self.no_regex_jump_forward = server_args.no_regex_jump_forward
|
||||
|
||||
# Init model and tokenizer
|
||||
self.model_config = ModelConfig(
|
||||
@@ -127,7 +127,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
"trust_remote_code": server_args.trust_remote_code,
|
||||
},
|
||||
)
|
||||
self.fast_forward_cache = FastForwardCache()
|
||||
self.jump_forward_cache = JumpForwardCache()
|
||||
|
||||
# Init new token estimation
|
||||
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
|
||||
@@ -254,8 +254,8 @@ class ModelRpcServer(rpyc.Service):
|
||||
# Init regex fsm
|
||||
if req.sampling_params.regex is not None:
|
||||
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||
if not self.no_regex_fast_forward:
|
||||
req.fast_forward_map = self.fast_forward_cache.query(
|
||||
if not self.no_regex_jump_forward:
|
||||
req.jump_forward_map = self.jump_forward_cache.query(
|
||||
req.sampling_params.regex
|
||||
)
|
||||
|
||||
@@ -369,8 +369,8 @@ class ModelRpcServer(rpyc.Service):
|
||||
logger.debug(
|
||||
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
||||
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
||||
f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||
f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
|
||||
f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||
f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
||||
)
|
||||
|
||||
new_batch = Batch.init_new(
|
||||
@@ -437,12 +437,12 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.min_new_token_ratio,
|
||||
)
|
||||
|
||||
if not self.no_regex_fast_forward:
|
||||
# check for fast forward
|
||||
fast_forward_reqs = batch.check_for_fast_forward()
|
||||
if not self.no_regex_jump_forward:
|
||||
# check for jump-forward
|
||||
jump_forward_reqs = batch.check_for_jump_forward()
|
||||
|
||||
# check for image fast forward
|
||||
for req in fast_forward_reqs:
|
||||
# check for image jump-forward
|
||||
for req in jump_forward_reqs:
|
||||
if req.pixel_values is not None:
|
||||
(
|
||||
req.input_ids,
|
||||
@@ -454,7 +454,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
req.image_size,
|
||||
)
|
||||
|
||||
self.forward_queue.extend(fast_forward_reqs)
|
||||
self.forward_queue.extend(jump_forward_reqs)
|
||||
if batch.is_empty():
|
||||
return
|
||||
|
||||
@@ -478,7 +478,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
def handle_finished_requests(self, batch: Batch):
|
||||
output_rids = []
|
||||
output_tokens = []
|
||||
output_and_fast_forward_strs = []
|
||||
output_and_jump_forward_strs = []
|
||||
output_hit_stop_str = []
|
||||
output_skip_special_tokens = []
|
||||
output_meta_info = []
|
||||
@@ -502,7 +502,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
output_tokens.append(req.output_ids)
|
||||
output_and_fast_forward_strs.append(req.output_and_fast_forward_str)
|
||||
output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
|
||||
output_hit_stop_str.append(req.hit_stop_str)
|
||||
output_skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens
|
||||
@@ -523,7 +523,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
BatchTokenIDOut(
|
||||
output_rids,
|
||||
output_tokens,
|
||||
output_and_fast_forward_strs,
|
||||
output_and_jump_forward_strs,
|
||||
output_hit_stop_str,
|
||||
output_skip_special_tokens,
|
||||
output_meta_info,
|
||||
|
||||
@@ -25,7 +25,7 @@ class ServerArgs:
|
||||
disable_log_stats: bool = False
|
||||
log_stats_interval: int = 10
|
||||
log_level: str = "info"
|
||||
no_regex_fast_forward: bool = False
|
||||
no_regex_jump_forward: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
@@ -172,9 +172,9 @@ class ServerArgs:
|
||||
help="Log stats interval in second.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-regex-fast-forward",
|
||||
"--no-regex-jump-forward",
|
||||
action="store_true",
|
||||
help="Disable regex fast forward",
|
||||
help="Disable regex jump-forward",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user