Decode Incrementally (#517)

This commit is contained in:
Liangsheng Yin
2024-06-12 14:39:12 +08:00
committed by GitHub
parent 111991fe23
commit 9c902b1954
8 changed files with 345 additions and 135 deletions

View File

@@ -21,7 +21,13 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
from sglang.srt.managers.controller.infer_batch import (
BaseFinishReason,
Batch,
FINISH_ABORT,
ForwardMode,
Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
@@ -98,8 +104,11 @@ class ModelTpServer:
else server_args.max_prefill_tokens
),
)
self.max_running_requests = (self.max_total_num_tokens // 2
if server_args.max_running_requests is None else server_args.max_running_requests)
self.max_running_requests = (
self.max_total_num_tokens // 2
if server_args.max_running_requests is None
else server_args.max_running_requests
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
@@ -314,10 +323,7 @@ class ModelTpServer:
# Compute matched prefix length
for req in self.forward_queue:
assert (
len(req.output_ids) == 0
), "The output ids should be empty when prefilling"
req.input_ids = req.origin_input_ids + req.prev_output_ids
req.input_ids = req.origin_input_ids + req.output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -464,7 +470,7 @@ class ModelTpServer:
pt = 0
for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1
req.output_ids = [next_token_ids[i]]
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req.return_logprob:
@@ -524,7 +530,7 @@ class ModelTpServer:
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
del_in_memory_pool=False,
@@ -596,8 +602,9 @@ class ModelTpServer:
def handle_finished_requests(self, batch: Batch):
output_rids = []
prev_output_strs = []
output_tokens = []
decoded_texts = []
surr_output_ids = []
read_output_ids = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
@@ -620,8 +627,10 @@ class ModelTpServer:
)
):
output_rids.append(req.rid)
prev_output_strs.append(req.prev_output_str)
output_tokens.append(req.output_ids)
decoded_texts.append(req.decoded_text)
surr_ids, read_ids, _ = req.init_detokenize_incrementally()
surr_output_ids.append(surr_ids)
read_output_ids.append(read_ids)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
@@ -631,7 +640,7 @@ class ModelTpServer:
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finished_reason),
}
@@ -657,8 +666,9 @@ class ModelTpServer:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
prev_output_strs,
output_tokens,
decoded_texts,
surr_output_ids,
read_output_ids,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
@@ -673,7 +683,7 @@ class ModelTpServer:
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
)
@@ -790,4 +800,4 @@ class ModelTpClient:
return _func
self.step = async_wrap("step")
self.step = async_wrap("step")