Optimize retract (#440)
This commit is contained in:
104
examples/usage/json_logprobs.py
Normal file
104
examples/usage/json_logprobs.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# NOTE: Currently this can only be run through HTTP requests.
|
||||||
|
import json
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from json_decode import character_regex
|
||||||
|
|
||||||
|
from sglang.utils import http_request
|
||||||
|
|
||||||
|
# Characters to query; each name is concatenated directly in front of `prompt`.
character_names = [
    "Hermione Granger",
    "Ron Weasley",
    "Harry Potter",
]

# Base address that the example sends its HTTP requests to.
base_url = "http://localhost:30000"

# Shared prompt suffix. It starts with a space because callers build the final
# text as `name + prompt`; without it the name and "is" would run together
# ("Hermione Grangeris a character ...").
prompt = (
    " is a character in Harry Potter. "
    "Please fill in the following information about this character.\n"
)
|
||||||
|
|
||||||
|
|
||||||
|
def openai_api_request(name):
    """Send one completion request to the ``/v1/completions`` endpoint.

    The request text is ``name + prompt``; it asks for greedy decoding
    (temperature 0), at most 128 tokens, output constrained by
    ``character_regex``, and 3 top logprobs per position. The returned
    logprob lists are then sanity-checked against each other.

    Args:
        name: Character name placed in front of the shared prompt.

    Returns:
        Whatever ``.json()`` yields for the HTTP response.

    Raises:
        AssertionError: If the lengths of ``token_logprobs``, ``tokens`` and
            ``top_logprobs`` disagree, or do not equal
            ``completion_tokens - 1``.
    """
    payload = dict(
        model="",
        prompt=name + prompt,
        temperature=0,
        max_tokens=128,
        regex=character_regex,
        logprobs=3,
    )
    response = http_request(base_url + "/v1/completions", json=payload)
    res = response.json()

    # Debugging aid: write json.dumps(res, indent=4) to a file to inspect
    # the raw response.

    logprobs = res["choices"][0]["logprobs"]
    usage = res["usage"]

    # All per-token lists must line up with each other and with the
    # reported usage.
    num_token_logprobs = len(logprobs["token_logprobs"])
    assert num_token_logprobs == len(logprobs["tokens"])
    assert num_token_logprobs == len(logprobs["top_logprobs"])
    assert num_token_logprobs == usage["completion_tokens"] - 1

    return res
|
||||||
|
|
||||||
|
|
||||||
|
def srt_api_request(name):
    """Send one generation request to the ``/generate`` endpoint.

    The request text is ``name + prompt``, decoded greedily with at most 128
    new tokens under ``character_regex``. Logprobs are requested from
    position 0 with 3 top candidates per position and token text included.
    The logprob lists in ``meta_info`` are then sanity-checked.

    Args:
        name: Character name placed in front of the shared prompt.

    Returns:
        Whatever ``.json()`` yields for the HTTP response.

    Raises:
        AssertionError: If the prefill/decode logprob list lengths disagree
            with their top-logprob counterparts, with ``prompt_tokens``, or
            with ``completion_tokens - 1``.
    """
    sampling_params = {
        "temperature": 0,
        "max_new_tokens": 128,
        "regex": character_regex,
    }
    payload = {
        "text": name + prompt,
        "sampling_params": sampling_params,
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    }

    res = http_request(base_url + "/generate", json=payload).json()

    # Debugging aid: write json.dumps(res, indent=4) to a file to inspect
    # the raw response.

    meta_info = res["meta_info"]

    # Token-level and top-k lists must have matching lengths per phase.
    num_prefill = len(meta_info["prefill_token_logprobs"])
    assert num_prefill == len(meta_info["prefill_top_logprobs"])
    num_decode = len(meta_info["decode_token_logprobs"])
    assert num_decode == len(meta_info["decode_top_logprobs"])

    # ...and must agree with the reported token usage.
    assert num_prefill == meta_info["prompt_tokens"]
    assert num_decode == meta_info["completion_tokens"] - 1

    return res
|
||||||
|
|
||||||
|
|
||||||
|
def _print_logprob_rows(token_logprobs, top_logprobs):
    """Print one row per token: the token itself, then its top-k candidates.

    Each entry is indexed at ``[2]`` and ``.encode()``-d, so the third element
    is presumably the token text (verify against the server's response
    format). Missing, ``None`` or empty top-k data simply yields a row with no
    candidate columns.
    """
    top_logprobs = top_logprobs or []
    for i, entry in enumerate(token_logprobs):
        candidates = top_logprobs[i] if i < len(top_logprobs) else None
        # The bytes repr makes leading spaces and non-ASCII pieces visible.
        cells = [f"{str(entry[2].encode()): <20}"]
        cells.extend(
            f"{str(candidate[2].encode()): <15}" for candidate in candidates or ()
        )
        print("".join(cells))


def pretty_print(res):
    """Print the prefill and decode logprob tables of a response, then its text.

    Args:
        res: Response dict with a ``"text"`` entry and a ``"meta_info"`` dict
            holding ``prefill_token_logprobs`` and ``decode_token_logprobs``.
            The matching ``*_top_logprobs`` lists are optional, and individual
            entries in them may be ``None`` or empty.

    Raises:
        KeyError: If ``meta_info``, a ``*_token_logprobs`` list or ``text`` is
            missing from ``res``.
    """
    meta_info = res["meta_info"]

    for title, phase in (("Prefill", "prefill"), ("Decode", "decode")):
        print("\n\n", "=" * 30, title, "=" * 30)
        _print_logprob_rows(
            meta_info[f"{phase}_token_logprobs"],
            meta_info.get(f"{phase}_top_logprobs"),
        )

    print(res["text"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Issue the /generate requests for all characters from worker threads;
    # leaving the `with` block waits for the submitted calls to finish.
    with ThreadPoolExecutor() as pool:
        responses = pool.map(srt_api_request, character_names)

    # Results come back in the same order as `character_names`.
    for response in responses:
        pretty_print(response)

    # Finally exercise the OpenAI-compatible endpoint once.
    openai_api_request("Hermione Granger")
|
||||||
@@ -28,5 +28,11 @@ class GlobalConfig:
|
|||||||
# Request dependency time due to network delay
|
# Request dependency time due to network delay
|
||||||
self.request_dependency_time = 0.03
|
self.request_dependency_time = 0.03
|
||||||
|
|
||||||
|
# New generation token ratio estimation
|
||||||
|
self.base_new_token_ratio = 0.4
|
||||||
|
self.base_min_new_token_ratio = 0.2
|
||||||
|
self.new_token_ratio_decay = 0.0001
|
||||||
|
self.new_token_ratio_recovery = 0.05
|
||||||
|
|
||||||
|
|
||||||
global_config = GlobalConfig()
|
global_config = GlobalConfig()
|
||||||
|
|||||||
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
|
|||||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||||
pt = 0
|
pt = 0
|
||||||
# NOTE: the GPU-CPU overhead can be reduced
|
# NOTE: the GPU-CPU overhead can be reduced
|
||||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
|
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
||||||
for i in range(len(extend_seq_lens_cpu)):
|
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||||
if extend_seq_lens_cpu[i] == 0:
|
if extend_seq_len == 0:
|
||||||
prefill_top_logprobs.append([])
|
prefill_top_logprobs.append([])
|
||||||
decode_top_logprobs.append([])
|
decode_top_logprobs.append([])
|
||||||
continue
|
continue
|
||||||
k = input_metadata.top_logprobs_nums[i]
|
k = input_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
||||||
vs_cpu = t.values.tolist()
|
vs_cpu = t.values.tolist()
|
||||||
ps_cpu = t.indices.tolist()
|
ps_cpu = t.indices.tolist()
|
||||||
prefill_top_logprobs.append(
|
prefill_top_logprobs.append(
|
||||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||||
)
|
)
|
||||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||||
pt += extend_seq_lens_cpu[i]
|
pt += extend_seq_len
|
||||||
|
|
||||||
return prefill_top_logprobs, decode_top_logprobs
|
return prefill_top_logprobs, decode_top_logprobs
|
||||||
|
|
||||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||||
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test():
|
||||||
all_logprobs = torch.tensor(
|
all_logprobs = torch.tensor(
|
||||||
# s s s
|
# s s s
|
||||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
||||||
@@ -173,3 +174,7 @@ if __name__ == "__main__":
|
|||||||
print("start", start)
|
print("start", start)
|
||||||
print("end", end)
|
print("end", end)
|
||||||
print("sum_logp", sum_logp)
|
print("sum_logp", sum_logp)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test()
|
||||||
|
|||||||
@@ -51,11 +51,6 @@ class DetokenizerManager:
|
|||||||
# Trim stop str
|
# Trim stop str
|
||||||
# TODO(lmzheng): handle the case where multiple stop strs are hit
|
# TODO(lmzheng): handle the case where multiple stop strs are hit
|
||||||
for i in range(len(output_strs)):
|
for i in range(len(output_strs)):
|
||||||
if recv_obj.hit_stop_str[i] is not None:
|
|
||||||
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
|
|
||||||
if pos != -1:
|
|
||||||
output_strs[i] = output_strs[i][:pos]
|
|
||||||
|
|
||||||
if len(output_tokens[i]) > 0:
|
if len(output_tokens[i]) > 0:
|
||||||
first_token = self.tokenizer.convert_ids_to_tokens(
|
first_token = self.tokenizer.convert_ids_to_tokens(
|
||||||
int(output_tokens[i][0])
|
int(output_tokens[i][0])
|
||||||
@@ -65,9 +60,12 @@ class DetokenizerManager:
|
|||||||
if first_token.startswith("▁"):
|
if first_token.startswith("▁"):
|
||||||
output_strs[i] = " " + output_strs[i]
|
output_strs[i] = " " + output_strs[i]
|
||||||
|
|
||||||
output_strs[i] = (
|
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
|
||||||
recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
|
|
||||||
)
|
if recv_obj.hit_stop_str[i] is not None:
|
||||||
|
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
|
||||||
|
if pos != -1:
|
||||||
|
output_strs[i] = output_strs[i][:pos]
|
||||||
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
BatchStrOut(
|
BatchStrOut(
|
||||||
|
|||||||
@@ -106,8 +106,8 @@ class TokenizedGenerateReqInput:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenIDOut:
|
class BatchTokenIDOut:
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
|
prev_output_strs : List[str]
|
||||||
output_tokens: List[List[int]]
|
output_tokens: List[List[int]]
|
||||||
output_and_jump_forward_strs: List[str]
|
|
||||||
hit_stop_str: List[Optional[str]]
|
hit_stop_str: List[Optional[str]]
|
||||||
skip_special_tokens: List[bool]
|
skip_special_tokens: List[bool]
|
||||||
spaces_between_special_tokens: List[bool]
|
spaces_between_special_tokens: List[bool]
|
||||||
|
|||||||
@@ -36,15 +36,15 @@ class FinishReason(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class Req:
|
class Req:
|
||||||
def __init__(self, rid, input_text, input_ids):
|
def __init__(self, rid, origin_input_text, origin_input_ids):
|
||||||
self.rid = rid
|
self.rid = rid
|
||||||
self.input_text = input_text
|
self.origin_input_text = origin_input_text
|
||||||
self.input_ids = input_ids
|
self.origin_input_ids = origin_input_ids
|
||||||
|
self.origin_input_ids_unpadded = origin_input_ids # before image padding
|
||||||
|
self.prev_output_str = ""
|
||||||
|
self.prev_output_ids = []
|
||||||
self.output_ids = []
|
self.output_ids = []
|
||||||
|
self.input_ids = None # input_ids = origin_input_ids + prev_output_ids
|
||||||
# Since jump forward may retokenize the prompt with partial outputs,
|
|
||||||
# we maintain the original prompt length to report the correct usage.
|
|
||||||
self.prompt_tokens = len(input_ids)
|
|
||||||
|
|
||||||
# The number of decoded tokens for token usage report. Note that
|
# The number of decoded tokens for token usage report. Note that
|
||||||
# this does not include the jump forward tokens.
|
# this does not include the jump forward tokens.
|
||||||
@@ -76,15 +76,24 @@ class Req:
|
|||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
self.normalized_prompt_logprob = None
|
self.normalized_prompt_logprob = None
|
||||||
self.prefill_token_logprobs = None
|
self.prefill_token_logprobs = None
|
||||||
self.decode_token_logprobs = None
|
self.decode_token_logprobs = []
|
||||||
self.prefill_top_logprobs = None
|
self.prefill_top_logprobs = None
|
||||||
self.decode_top_logprobs = None
|
self.decode_top_logprobs = []
|
||||||
|
# The tokens are prefilled but need to be considered as decode tokens
|
||||||
|
# and should be updated for the decode logprobs
|
||||||
|
self.last_update_decode_tokens = 0
|
||||||
|
|
||||||
# Constrained decoding
|
# Constrained decoding
|
||||||
self.regex_fsm = None
|
self.regex_fsm = None
|
||||||
self.regex_fsm_state = 0
|
self.regex_fsm_state = 0
|
||||||
self.jump_forward_map = None
|
self.jump_forward_map = None
|
||||||
self.output_and_jump_forward_str = ""
|
|
||||||
|
def partial_decode(self, ids):
    """Decode ``ids`` to text, re-adding a leading space when appropriate.

    If the first id maps to a token that starts with "▁", a single space is
    prepended to the decoded string (presumably because ``decode`` drops that
    leading space; verify against the tokenizer in use).

    Args:
        ids: Sequence of token ids to decode. May be empty.

    Returns:
        The decoded text, or ``""`` when ``ids`` is empty.
    """
    if len(ids) == 0:
        # Nothing to decode; also avoids an IndexError on ids[0] below.
        return ""

    first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
    # Some tokenizers hand back bytes rather than str for a token.
    if isinstance(first_token, bytes):
        first_token = first_token.decode()

    leading = " " if first_token.startswith("▁") else ""
    return leading + self.tokenizer.decode(ids)
|
||||||
|
|
||||||
def max_new_tokens(self):
|
def max_new_tokens(self):
|
||||||
return self.sampling_params.max_new_tokens
|
return self.sampling_params.max_new_tokens
|
||||||
@@ -93,7 +102,10 @@ class Req:
|
|||||||
if self.finished:
|
if self.finished:
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
if (
|
||||||
|
len(self.prev_output_ids) + len(self.output_ids)
|
||||||
|
>= self.sampling_params.max_new_tokens
|
||||||
|
):
|
||||||
self.finished = True
|
self.finished = True
|
||||||
self.finish_reason = FinishReason.LENGTH
|
self.finish_reason = FinishReason.LENGTH
|
||||||
return
|
return
|
||||||
@@ -112,60 +124,66 @@ class Req:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for stop_str in self.sampling_params.stop_strs:
|
for stop_str in self.sampling_params.stop_strs:
|
||||||
if stop_str in tail_str:
|
# FIXME: (minor) try incremental match in prev_output_str
|
||||||
|
if stop_str in tail_str or stop_str in self.prev_output_str:
|
||||||
self.finished = True
|
self.finished = True
|
||||||
self.finish_reason = FinishReason.STOP_STR
|
self.finish_reason = FinishReason.STOP_STR
|
||||||
self.hit_stop_str = stop_str
|
self.hit_stop_str = stop_str
|
||||||
return
|
return
|
||||||
|
|
||||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
|
||||||
# FIXME: This logic does not really solve the problem of determining whether
|
# FIXME: This logic does not really solve the problem of determining whether
|
||||||
# there should be a leading space.
|
# there should be a leading space.
|
||||||
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
|
cur_output_str = self.partial_decode(self.output_ids)
|
||||||
first_token = (
|
|
||||||
first_token.decode() if isinstance(first_token, bytes) else first_token
|
# TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
|
||||||
)
|
if self.origin_input_text is None:
|
||||||
if first_token.startswith("▁"):
|
# Recovering text can only use unpadded ids
|
||||||
old_output_str = " " + old_output_str
|
self.origin_input_text = self.tokenizer.decode(
|
||||||
if self.input_text is None:
|
self.origin_input_ids_unpadded
|
||||||
# TODO(lmzheng): This can be wrong. Check with Liangsheng.
|
)
|
||||||
self.input_text = self.tokenizer.decode(self.input_ids)
|
|
||||||
new_input_string = (
|
all_text = (
|
||||||
self.input_text
|
self.origin_input_text
|
||||||
+ self.output_and_jump_forward_str
|
+ self.prev_output_str
|
||||||
+ old_output_str
|
+ cur_output_str
|
||||||
+ jump_forward_str
|
+ jump_forward_str
|
||||||
)
|
)
|
||||||
new_input_ids = self.tokenizer.encode(new_input_string)
|
all_ids = self.tokenizer.encode(all_text)
|
||||||
if self.pixel_values is not None:
|
prompt_tokens = len(self.origin_input_ids_unpadded)
|
||||||
# NOTE: This is a hack because the old input_ids contains the image padding
|
self.origin_input_ids = all_ids[:prompt_tokens]
|
||||||
jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
|
self.origin_input_ids_unpadded = self.origin_input_ids
|
||||||
else:
|
# NOTE: the output ids may not strictly correspond to the output text
|
||||||
jump_forward_tokens_len = (
|
old_prev_output_ids = self.prev_output_ids
|
||||||
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
|
self.prev_output_ids = all_ids[prompt_tokens:]
|
||||||
)
|
self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
|
||||||
|
self.output_ids = []
|
||||||
|
|
||||||
|
self.regex_fsm_state = next_state
|
||||||
|
|
||||||
|
if self.return_logprob:
|
||||||
|
# For fast-forward part's logprobs
|
||||||
|
k = 0
|
||||||
|
for i, old_id in enumerate(old_prev_output_ids):
|
||||||
|
if old_id == self.prev_output_ids[i]:
|
||||||
|
k = k + 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
self.decode_token_logprobs = self.decode_token_logprobs[:k]
|
||||||
|
self.decode_top_logprobs = self.decode_top_logprobs[:k]
|
||||||
|
self.logprob_start_len = prompt_tokens + k
|
||||||
|
self.last_update_decode_tokens = len(self.prev_output_ids) - k
|
||||||
|
|
||||||
# print("=" * 100)
|
# print("=" * 100)
|
||||||
# print(f"Catch jump forward:\n{jump_forward_str}")
|
# print(f"Catch jump forward:\n{jump_forward_str}")
|
||||||
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
|
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
|
||||||
# print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
|
# print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
|
||||||
|
|
||||||
self.input_ids = new_input_ids
|
|
||||||
self.output_ids = []
|
|
||||||
self.sampling_params.max_new_tokens = max(
|
|
||||||
self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
|
|
||||||
)
|
|
||||||
self.regex_fsm_state = next_state
|
|
||||||
self.output_and_jump_forward_str = (
|
|
||||||
self.output_and_jump_forward_str + old_output_str + jump_forward_str
|
|
||||||
)
|
|
||||||
|
|
||||||
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
|
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
|
||||||
# print("*" * 100)
|
# print("*" * 100)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
|
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -336,6 +354,7 @@ class Batch:
|
|||||||
|
|
||||||
def retract_decode(self):
|
def retract_decode(self):
|
||||||
sorted_indices = [i for i in range(len(self.reqs))]
|
sorted_indices = [i for i in range(len(self.reqs))]
|
||||||
|
# TODO(lsyin): improve the priority of retraction
|
||||||
sorted_indices.sort(
|
sorted_indices.sort(
|
||||||
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
|
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
@@ -356,18 +375,27 @@ class Batch:
|
|||||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||||
self.token_to_kv_pool.dec_refs(token_indices)
|
self.token_to_kv_pool.dec_refs(token_indices)
|
||||||
|
|
||||||
|
# release the last node
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
self.tree_cache.dec_lock_ref(req.last_node)
|
||||||
|
|
||||||
|
cur_output_str = req.partial_decode(req.output_ids)
|
||||||
|
req.prev_output_str = req.prev_output_str + cur_output_str
|
||||||
|
req.prev_output_ids.extend(req.output_ids)
|
||||||
|
|
||||||
req.prefix_indices = None
|
req.prefix_indices = None
|
||||||
req.last_node = None
|
req.last_node = None
|
||||||
req.extend_input_len = 0
|
req.extend_input_len = 0
|
||||||
req.output_ids = []
|
req.output_ids = []
|
||||||
req.regex_fsm_state = 0
|
|
||||||
|
# For incremental logprobs
|
||||||
|
req.last_update_decode_tokens = 0
|
||||||
|
req.logprob_start_len = 10**9
|
||||||
|
|
||||||
self.filter_batch(sorted_indices)
|
self.filter_batch(sorted_indices)
|
||||||
|
|
||||||
return retracted_reqs
|
return retracted_reqs
|
||||||
|
|
||||||
def check_for_jump_forward(self):
|
def check_for_jump_forward(self, model_runner):
|
||||||
jump_forward_reqs = []
|
jump_forward_reqs = []
|
||||||
filter_indices = [i for i in range(len(self.reqs))]
|
filter_indices = [i for i in range(len(self.reqs))]
|
||||||
|
|
||||||
@@ -397,6 +425,18 @@ class Batch:
|
|||||||
# jump-forward
|
# jump-forward
|
||||||
req.jump_forward_and_retokenize(jump_forward_str, next_state)
|
req.jump_forward_and_retokenize(jump_forward_str, next_state)
|
||||||
|
|
||||||
|
# re-applying image padding
|
||||||
|
if req.pixel_values is not None:
|
||||||
|
(
|
||||||
|
req.origin_input_ids,
|
||||||
|
req.image_offset,
|
||||||
|
) = model_runner.model.pad_input_ids(
|
||||||
|
req.origin_input_ids_unpadded,
|
||||||
|
req.pad_value,
|
||||||
|
req.pixel_values.shape,
|
||||||
|
req.image_size,
|
||||||
|
)
|
||||||
|
|
||||||
jump_forward_reqs.append(req)
|
jump_forward_reqs.append(req)
|
||||||
filter_indices.remove(i)
|
filter_indices.remove(i)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import multiprocessing
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
import rpyc
|
import rpyc
|
||||||
import torch
|
import torch
|
||||||
@@ -16,6 +16,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from vllm.logger import logger as vllm_default_logger
|
from vllm.logger import logger as vllm_default_logger
|
||||||
|
|
||||||
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
@@ -106,7 +107,8 @@ class ModelRpcServer:
|
|||||||
set_random_seed(server_args.random_seed)
|
set_random_seed(server_args.random_seed)
|
||||||
|
|
||||||
# Print info
|
# Print info
|
||||||
logger.info(f"[rank={self.tp_rank}] "
|
logger.info(
|
||||||
|
f"[rank={self.tp_rank}] "
|
||||||
f"max_total_num_token={self.max_total_num_token}, "
|
f"max_total_num_token={self.max_total_num_token}, "
|
||||||
f"max_prefill_num_token={self.max_prefill_num_token}, "
|
f"max_prefill_num_token={self.max_prefill_num_token}, "
|
||||||
f"context_len={self.model_config.context_len}, "
|
f"context_len={self.model_config.context_len}, "
|
||||||
@@ -151,9 +153,20 @@ class ModelRpcServer:
|
|||||||
self.jump_forward_cache = JumpForwardCache()
|
self.jump_forward_cache = JumpForwardCache()
|
||||||
|
|
||||||
# Init new token estimation
|
# Init new token estimation
|
||||||
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
|
assert (
|
||||||
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
|
server_args.schedule_conservativeness >= 0
|
||||||
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
|
), "Invalid schedule_conservativeness"
|
||||||
|
self.new_token_ratio = min(
|
||||||
|
global_config.base_new_token_ratio * server_args.schedule_conservativeness,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
self.min_new_token_ratio = min(
|
||||||
|
global_config.base_min_new_token_ratio
|
||||||
|
* server_args.schedule_conservativeness,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||||
|
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
||||||
|
|
||||||
def exposed_step(self, recv_reqs):
|
def exposed_step(self, recv_reqs):
|
||||||
if self.tp_size != 1:
|
if self.tp_size != 1:
|
||||||
@@ -256,8 +269,13 @@ class ModelRpcServer:
|
|||||||
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
|
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
|
||||||
]
|
]
|
||||||
req.image_size = recv_req.image_size
|
req.image_size = recv_req.image_size
|
||||||
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
|
req.origin_input_ids, req.image_offset = (
|
||||||
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
|
self.model_runner.model.pad_input_ids(
|
||||||
|
req.origin_input_ids_unpadded,
|
||||||
|
req.pad_value,
|
||||||
|
req.pixel_values.shape,
|
||||||
|
req.image_size,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
req.sampling_params = recv_req.sampling_params
|
req.sampling_params = recv_req.sampling_params
|
||||||
req.return_logprob = recv_req.return_logprob
|
req.return_logprob = recv_req.return_logprob
|
||||||
@@ -275,11 +293,11 @@ class ModelRpcServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Truncate prompts that are too long
|
# Truncate prompts that are too long
|
||||||
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
|
req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
|
||||||
req.sampling_params.max_new_tokens = min(
|
req.sampling_params.max_new_tokens = min(
|
||||||
req.sampling_params.max_new_tokens,
|
req.sampling_params.max_new_tokens,
|
||||||
self.model_config.context_len - 1 - len(req.input_ids),
|
self.model_config.context_len - 1 - len(req.origin_input_ids),
|
||||||
self.max_total_num_token - 128 - len(req.input_ids),
|
self.max_total_num_token - 128 - len(req.origin_input_ids),
|
||||||
)
|
)
|
||||||
self.forward_queue.append(req)
|
self.forward_queue.append(req)
|
||||||
|
|
||||||
@@ -292,6 +310,10 @@ class ModelRpcServer:
|
|||||||
|
|
||||||
# Compute matched prefix length
|
# Compute matched prefix length
|
||||||
for req in self.forward_queue:
|
for req in self.forward_queue:
|
||||||
|
assert (
|
||||||
|
len(req.output_ids) == 0
|
||||||
|
), "The output ids should be empty when prefilling"
|
||||||
|
req.input_ids = req.origin_input_ids + req.prev_output_ids
|
||||||
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
prefix_indices = prefix_indices[: req.logprob_start_len]
|
prefix_indices = prefix_indices[: req.logprob_start_len]
|
||||||
@@ -319,7 +341,7 @@ class ModelRpcServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for req in self.forward_queue:
|
for req in self.forward_queue:
|
||||||
if req.return_logprob:
|
if req.return_logprob and req.normalized_prompt_logprob is None:
|
||||||
# Need at least two tokens to compute normalized logprob
|
# Need at least two tokens to compute normalized logprob
|
||||||
if req.extend_input_len < 2:
|
if req.extend_input_len < 2:
|
||||||
delta = 2 - req.extend_input_len
|
delta = 2 - req.extend_input_len
|
||||||
@@ -441,28 +463,53 @@ class ModelRpcServer:
|
|||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
if req.normalized_prompt_logprob is None:
|
||||||
|
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
||||||
|
|
||||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
if req.prefill_token_logprobs is None:
|
||||||
req.prefill_token_logprobs = list(
|
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||||
zip(
|
req.prefill_token_logprobs = list(
|
||||||
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
zip(
|
||||||
req.input_ids[-req.extend_input_len + 1 :],
|
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||||
|
req.input_ids[-req.extend_input_len + 1 :],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if req.logprob_start_len == 0:
|
||||||
if req.logprob_start_len == 0:
|
req.prefill_token_logprobs = [
|
||||||
req.prefill_token_logprobs = [
|
(None, req.input_ids[0])
|
||||||
(None, req.input_ids[0])
|
] + req.prefill_token_logprobs
|
||||||
] + req.prefill_token_logprobs
|
|
||||||
req.decode_token_logprobs = [
|
if req.last_update_decode_tokens != 0:
|
||||||
|
req.decode_token_logprobs.extend(
|
||||||
|
list(
|
||||||
|
zip(
|
||||||
|
prefill_token_logprobs[
|
||||||
|
pt
|
||||||
|
+ req.extend_input_len
|
||||||
|
- req.last_update_decode_tokens : pt
|
||||||
|
+ req.extend_input_len
|
||||||
|
- 1
|
||||||
|
],
|
||||||
|
req.input_ids[-req.last_update_decode_tokens + 1 :],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
req.decode_token_logprobs.append(
|
||||||
(last_token_logprobs[i], next_token_ids[i])
|
(last_token_logprobs[i], next_token_ids[i])
|
||||||
]
|
)
|
||||||
|
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
if req.prefill_top_logprobs is None:
|
||||||
if req.logprob_start_len == 0:
|
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
||||||
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
if req.logprob_start_len == 0:
|
||||||
req.decode_top_logprobs = [decode_top_logprobs[i]]
|
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
||||||
|
|
||||||
|
if req.last_update_decode_tokens != 0:
|
||||||
|
req.decode_top_logprobs.extend(
|
||||||
|
prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
||||||
|
)
|
||||||
|
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
||||||
|
|
||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
@@ -484,7 +531,7 @@ class ModelRpcServer:
|
|||||||
# check if decode out of memory
|
# check if decode out of memory
|
||||||
if not batch.check_decode_mem():
|
if not batch.check_decode_mem():
|
||||||
old_ratio = self.new_token_ratio
|
old_ratio = self.new_token_ratio
|
||||||
self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
|
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
|
||||||
|
|
||||||
retracted_reqs = batch.retract_decode()
|
retracted_reqs = batch.retract_decode()
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -495,26 +542,13 @@ class ModelRpcServer:
|
|||||||
self.forward_queue.extend(retracted_reqs)
|
self.forward_queue.extend(retracted_reqs)
|
||||||
else:
|
else:
|
||||||
self.new_token_ratio = max(
|
self.new_token_ratio = max(
|
||||||
self.new_token_ratio - self.new_token_ratio_step[0],
|
self.new_token_ratio - self.new_token_ratio_decay,
|
||||||
self.min_new_token_ratio,
|
self.min_new_token_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.disable_regex_jump_forward:
|
if not self.disable_regex_jump_forward:
|
||||||
# check for jump-forward
|
# check for jump-forward
|
||||||
jump_forward_reqs = batch.check_for_jump_forward()
|
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
|
||||||
|
|
||||||
# check for image jump-forward
|
|
||||||
for req in jump_forward_reqs:
|
|
||||||
if req.pixel_values is not None:
|
|
||||||
(
|
|
||||||
req.input_ids,
|
|
||||||
req.image_offset,
|
|
||||||
) = self.model_runner.model.pad_input_ids(
|
|
||||||
req.input_ids,
|
|
||||||
req.pad_value,
|
|
||||||
req.pixel_values.shape,
|
|
||||||
req.image_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.forward_queue.extend(jump_forward_reqs)
|
self.forward_queue.extend(jump_forward_reqs)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
@@ -557,8 +591,8 @@ class ModelRpcServer:
|
|||||||
|
|
||||||
def handle_finished_requests(self, batch: Batch):
|
def handle_finished_requests(self, batch: Batch):
|
||||||
output_rids = []
|
output_rids = []
|
||||||
|
prev_output_strs = []
|
||||||
output_tokens = []
|
output_tokens = []
|
||||||
output_and_jump_forward_strs = []
|
|
||||||
output_hit_stop_str = []
|
output_hit_stop_str = []
|
||||||
output_skip_special_tokens = []
|
output_skip_special_tokens = []
|
||||||
output_spaces_between_special_tokens = []
|
output_spaces_between_special_tokens = []
|
||||||
@@ -582,8 +616,8 @@ class ModelRpcServer:
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
output_rids.append(req.rid)
|
output_rids.append(req.rid)
|
||||||
|
prev_output_strs.append(req.prev_output_str)
|
||||||
output_tokens.append(req.output_ids)
|
output_tokens.append(req.output_ids)
|
||||||
output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
|
|
||||||
output_hit_stop_str.append(req.hit_stop_str)
|
output_hit_stop_str.append(req.hit_stop_str)
|
||||||
output_skip_special_tokens.append(
|
output_skip_special_tokens.append(
|
||||||
req.sampling_params.skip_special_tokens
|
req.sampling_params.skip_special_tokens
|
||||||
@@ -593,10 +627,8 @@ class ModelRpcServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"prompt_tokens": req.prompt_tokens,
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
"completion_tokens": len(req.input_ids)
|
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
|
||||||
+ len(req.output_ids)
|
|
||||||
- req.prompt_tokens,
|
|
||||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||||
"finish_reason": FinishReason.to_str(req.finish_reason),
|
"finish_reason": FinishReason.to_str(req.finish_reason),
|
||||||
"hit_stop_str": req.hit_stop_str,
|
"hit_stop_str": req.hit_stop_str,
|
||||||
@@ -623,8 +655,8 @@ class ModelRpcServer:
|
|||||||
self.out_pyobjs.append(
|
self.out_pyobjs.append(
|
||||||
BatchTokenIDOut(
|
BatchTokenIDOut(
|
||||||
output_rids,
|
output_rids,
|
||||||
|
prev_output_strs,
|
||||||
output_tokens,
|
output_tokens,
|
||||||
output_and_jump_forward_strs,
|
|
||||||
output_hit_stop_str,
|
output_hit_stop_str,
|
||||||
output_skip_special_tokens,
|
output_skip_special_tokens,
|
||||||
output_spaces_between_special_tokens,
|
output_spaces_between_special_tokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user