Decode Incrementally (#517)
@@ -21,7 +21,13 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
+from sglang.srt.managers.controller.infer_batch import (
+    BaseFinishReason,
+    Batch,
+    FINISH_ABORT,
+    ForwardMode,
+    Req,
+)
 from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
@@ -98,8 +104,11 @@ class ModelTpServer:
                 else server_args.max_prefill_tokens
             ),
         )
-        self.max_running_requests = (self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None else server_args.max_running_requests)
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
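A quick worked example of the default this hunk reformats (values made up): when server_args.max_running_requests is unset, concurrency is capped at half the total KV-token budget.

max_total_num_tokens = 80_000   # hypothetical KV-cache capacity, in tokens
max_running_requests = None     # i.e. not set on the command line

effective_limit = (
    max_total_num_tokens // 2
    if max_running_requests is None
    else max_running_requests
)
print(effective_limit)  # 40000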
@@ -314,10 +323,7 @@ class ModelTpServer:

         # Compute matched prefix length
         for req in self.forward_queue:
-            assert (
-                len(req.output_ids) == 0
-            ), "The output ids should be empty when prefilling"
-            req.input_ids = req.origin_input_ids + req.prev_output_ids
+            req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
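Since req.output_ids now keeps every generated token, a request that re-enters the prefill queue (for example after jump-forward decoding returns it there) rebuilds its full sequence as origin_input_ids + output_ids and asks the radix cache for the longest cached prefix, so only the uncached suffix is recomputed. A minimal sketch of the prefix-matching idea; TinyRadixCache below is illustrative, not sglang's RadixCache:

class _Node:
    def __init__(self):
        self.children = {}   # token id -> _Node
        self.kv_index = None # KV-cache slot of the token ending at this node

class TinyRadixCache:
    def __init__(self):
        self.root = _Node()

    def insert(self, token_ids, kv_indices):
        node = self.root
        for t, kv in zip(token_ids, kv_indices):
            node = node.children.setdefault(t, _Node())
            node.kv_index = kv

    def match_prefix(self, token_ids):
        # Walk the tree and collect KV indices for the longest cached prefix.
        node, indices = self.root, []
        for t in token_ids:
            if t not in node.children:
                break
            node = node.children[t]
            indices.append(node.kv_index)
        return indices, node

cache = TinyRadixCache()
cache.insert([1, 2, 3, 7], [100, 101, 102, 103])
prefix, _ = cache.match_prefix([1, 2, 3, 7, 8, 9])
print(len(prefix))  # 4 tokens already cached; only 2 need prefill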
@@ -464,7 +470,7 @@ class ModelTpServer:
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids = [next_token_ids[i]]
+            req.output_ids.append(next_token_ids[i])
             req.check_finished()

             if req.return_logprob:
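Appending instead of overwriting is the load-bearing change: output_ids is now the complete generation history rather than only the latest step. For context, a hedged sketch of what a check_finished-style stop test usually covers; sglang's real method also handles stop strings and aborts (note FINISH_ABORT in the imports above):

EOS_TOKEN_ID = 2  # assumption: the actual id is model-dependent

def is_finished(output_ids, max_new_tokens):
    # End-of-sequence token sampled?
    if output_ids and output_ids[-1] == EOS_TOKEN_ID:
        return True
    # Generation budget from the sampling parameters exhausted?
    return len(output_ids) >= max_new_tokens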
@@ -524,7 +530,7 @@ class ModelTpServer:
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
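The cache key changes for the same reason: after a re-prefill, req.input_ids can already contain generated tokens, so keying on input_ids + output_ids would count them twice now that output_ids holds the full history. The trailing [:-1] drops the newest token, whose KV entry has not been written yet. A made-up illustration (token values arbitrary):

origin_input_ids = [1, 2, 3]        # the original prompt
output_ids = [7, 8, 9]              # everything generated so far
input_ids = origin_input_ids + [7]  # as rebuilt at the last re-prefill

print(tuple(input_ids + output_ids)[:-1])         # (1, 2, 3, 7, 7, 8) -- 7 twice
print(tuple(origin_input_ids + output_ids)[:-1])  # (1, 2, 3, 7, 8)    -- true prefix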
@@ -596,8 +602,9 @@ class ModelTpServer:

     def handle_finished_requests(self, batch: Batch):
         output_rids = []
-        prev_output_strs = []
-        output_tokens = []
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -620,8 +627,10 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                prev_output_strs.append(req.prev_output_str)
-                output_tokens.append(req.output_ids)
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
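init_detokenize_incrementally hands the detokenizer two token-id windows: surr_ids, a short surrogate context whose text has already been emitted, and read_ids, assumed here to be that context plus the tokens still to be surfaced. Decoding both and diffing the strings yields only the new text, and a trailing U+FFFD marks an incomplete multi-byte character that must be held back. A minimal sketch of the idea, not sglang's exact method:

from transformers import AutoTokenizer

def incremental_text(tokenizer, surr_ids, read_ids):
    # Decode the surrogate window alone and together with the new tokens;
    # the string difference is exactly the text the new tokens contribute.
    surr_text = tokenizer.decode(surr_ids)
    read_text = tokenizer.decode(read_ids)
    if read_text.endswith("\ufffd"):
        return ""  # incomplete UTF-8 sequence: wait for more tokens
    return read_text[len(surr_text):]

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("Hello, world!")
print(incremental_text(tok, ids[:2], ids))  # text contributed by ids[2:]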
@@ -631,7 +640,7 @@ class ModelTpServer:

                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
+                    "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": str(req.finished_reason),
                 }
@@ -657,8 +666,9 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
-                    prev_output_strs,
-                    output_tokens,
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -673,7 +683,7 @@ class ModelTpServer:
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
-                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
                     req_pool_idx=req_pool_indices_cpu[i],
                 )
@@ -790,4 +800,4 @@ class ModelTpClient:

             return _func

-        self.step = async_wrap("step")
+        self.step = async_wrap("step")