Refactor decoding logprob and add completion_tokens_wo_jump_forward (#189)
@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
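
For context, a rough usage sketch of the new request fields over HTTP (not part of the diff; the endpoint path, port, prompt, and sampling parameters are illustrative assumptions):

import requests  # any HTTP client works; requests is assumed to be installed

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    "return_logprob": True,            # ask the server to return logprobs
    "logprob_start_len": 0,            # where the reported prompt logprobs start
    "return_text_in_logprobs": True,   # new field: detokenize ids into text
    "stream": False,
}

# Hypothetical local server address; adjust to your deployment.
resp = requests.post("http://localhost:30000/generate", json=payload).json()
print(resp["meta_info"]["completion_tokens_wo_jump_forward"])
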
@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []

-        # for accumulated prompt tokens from jump forward
-        self.orig_prompt_tokens = len(input_ids)
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0

         # For vision input
         self.pixel_values = None
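
To illustrate the bookkeeping with a toy stand-in (this is not the real Req class): jump forward can splice partial output back into the prompt and retokenize it, so len(input_ids) may grow, while prompt_tokens stays frozen at the original prompt length for usage reporting:

class ToyReq:
    """Simplified stand-in for Req, just to show the two counters."""

    def __init__(self, input_ids):
        self.input_ids = list(input_ids)
        self.output_ids = []
        # Frozen at construction; jump forward never changes this.
        self.prompt_tokens = len(self.input_ids)
        # Counts only genuinely decoded tokens, not jump-forward tokens.
        self.completion_tokens_wo_jump_forward = 0

req = ToyReq(range(10))
req.input_ids += [101, 102, 103, 104]   # pretend jump forward retokenized the prompt
assert req.prompt_tokens == 10          # usage still reports the original prompt length
assert len(req.input_ids) == 14
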
@@ -424,6 +424,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
@@ -500,6 +501,7 @@ class ModelRpcServer(rpyc.Service):

         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()
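
The two hunks above add the same one-line change to the prefill and decode paths: the counter advances exactly once per sampled token. A toy sketch of the distinction, using hypothetical helper names rather than the real ModelRpcServer methods:

def decode_step(reqs, next_token_ids):
    # One freshly sampled token per running request: count it.
    for req, next_tok_id in zip(reqs, next_token_ids):
        req.completion_tokens_wo_jump_forward += 1
        req.output_ids.append(next_tok_id)

def apply_jump_forward(req, jumped_token_ids):
    # Jump-forward tokens are folded into input_ids by retokenization
    # (simplified here) and never pass through the counter above.
    req.input_ids.extend(jumped_token_ids)
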
@@ -541,15 +543,14 @@ class ModelRpcServer(rpyc.Service):
                     req.sampling_params.skip_special_tokens
                 )

-            # For the length of input_ids, which will be accumulated during jump-forward.
-            # Use the original length of input_ids to calculate the token usage info.
             meta_info = {
-                "prompt_tokens": req.orig_prompt_tokens,
+                "prompt_tokens": req.prompt_tokens,
                 "completion_tokens": len(req.input_ids)
                 + len(req.output_ids)
-                - req.orig_prompt_tokens,
+                - req.prompt_tokens,
+                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
             }

             if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
                 meta_info["token_logprob"] = req.token_logprob
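
A worked example with made-up numbers, matching the formula above: a request starts with a 10-token prompt, jump forward later retokenizes the prompt plus partial output into 14 input ids, and the model samples 6 further tokens:

prompt_tokens = 10                      # frozen original prompt length
len_input_ids, len_output_ids = 14, 6   # after jump forward and decoding
completion_tokens = len_input_ids + len_output_ids - prompt_tokens
assert completion_tokens == 10          # includes the 4 jump-forward tokens
completion_tokens_wo_jump_forward = 6   # only the genuinely sampled tokens
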
@@ -52,7 +52,7 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import alloc_usable_network_port, handle_port_init
+from sglang.srt.utils import handle_port_init

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -96,19 +96,25 @@ async def flush_cache():
     )


-async def stream_generator(obj):
+async def detokenize_logprob_tokens(token_logprobs):
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
+        if obj.return_logprob and obj.return_text_in_logprobs:
+            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                out["meta_info"]["token_logprob"]
+            )
         yield out


 async def make_openai_style_logprobs(token_logprobs):
     ret_logprobs = LogProbs()

-    # Detokenize
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-
-    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+    for token_text, token_logprob in token_logprobs:
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(token_logprob)
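
In effect, detokenize_logprob_tokens turns a list of (token_id, logprob) pairs into (token_text, logprob) pairs so clients can read the tokens directly. A synchronous sketch with a plain Hugging Face tokenizer standing in for tokenizer_manager.detokenize (the tokenizer name is an arbitrary assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

def detokenize_logprob_tokens_sync(token_logprobs):
    # token_logprobs: [(token_id, logprob), ...]
    token_ids = [tid for tid, _ in token_logprobs]
    token_texts = [tokenizer.decode([tid]) for tid in token_ids]
    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]

Because make_openai_style_logprobs now receives pairs that are already detokenized, it only has to copy them into the LogProbs fields.
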
@@ -132,6 +138,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

     ret = await tokenizer_manager.generate_request(obj).__anext__()
+    if obj.return_logprob and obj.return_text_in_logprobs:
+        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+            ret["meta_info"]["token_logprob"]
+        )
+
     return ret
@@ -155,6 +166,7 @@ async def v1_completions(raw_request: Request):
             "regex": request.regex,
         },
         return_logprob=request.logprobs is not None,
+        return_text_in_logprobs=True,
         stream=request.stream,
     )
     adapted_request.post_init()
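
Since the OpenAI-compatible path now forces return_text_in_logprobs=True, asking for logprobs yields token text in the response. A rough client-side sketch (the base URL and model name are assumptions about a local deployment):

import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",  # assumed local endpoint
    json={
        "model": "default",                   # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logprobs": 1,                        # maps to return_logprob on the server
    },
).json()

choice = resp["choices"][0]
print(choice["logprobs"]["tokens"])           # token text, not raw ids
print(choice["logprobs"]["token_logprobs"])
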
@@ -211,6 +223,7 @@ async def v1_completions(raw_request: Request):

     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    ret = ret[0] if isinstance(ret, list) else ret

     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]