Refactor decoding logprob and add completion_tokens_wo_jump_forward (#189)
This commit is contained in:
@@ -15,10 +15,12 @@ class GenerateReqInput:
|
|||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
# The request id
|
# The request id
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Whether return logprobs of the prompts
|
# Whether to return logprobs
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
# The start location of the prompt for return_logprob
|
# The start location of the prompt for return_logprob
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
|
# Whether to detokenize tokens in logprobs
|
||||||
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output
|
# Whether to stream output
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
|
|
||||||
|
|||||||
@@ -27,8 +27,12 @@ class Req:
|
|||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.output_ids = []
|
self.output_ids = []
|
||||||
|
|
||||||
# for accumulated prompt tokens from jump forward
|
# Since jump forward may retokenize the prompt with partial outputs,
|
||||||
self.orig_prompt_tokens = len(input_ids)
|
# we maintain the original prompt length to report the correct usage.
|
||||||
|
self.prompt_tokens = len(input_ids)
|
||||||
|
# The number of decoded tokens for token usage report. Note that
|
||||||
|
# this does not include the jump forward tokens.
|
||||||
|
self.completion_tokens_wo_jump_forward = 0
|
||||||
|
|
||||||
# For vision input
|
# For vision input
|
||||||
self.pixel_values = None
|
self.pixel_values = None
|
||||||
|
|||||||
@@ -424,6 +424,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
# Check finish condition
|
# Check finish condition
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids = [next_token_ids[i]]
|
req.output_ids = [next_token_ids[i]]
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
@@ -500,6 +501,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
|
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
|
||||||
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_tok_id)
|
req.output_ids.append(next_tok_id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
@@ -541,15 +543,14 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
req.sampling_params.skip_special_tokens
|
req.sampling_params.skip_special_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
# For the length of input_ids, which will be accumulated during jump-forward.
|
|
||||||
# Use the original length of input_ids to calculate the token usage info.
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"prompt_tokens": req.orig_prompt_tokens,
|
"prompt_tokens": req.prompt_tokens,
|
||||||
"completion_tokens": len(req.input_ids)
|
"completion_tokens": len(req.input_ids)
|
||||||
+ len(req.output_ids)
|
+ len(req.output_ids)
|
||||||
- req.orig_prompt_tokens,
|
- req.prompt_tokens,
|
||||||
|
"completion_tokens_wo_jump_forward":
|
||||||
|
req.completion_tokens_wo_jump_forward
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
meta_info["prompt_logprob"] = req.logprob
|
meta_info["prompt_logprob"] = req.logprob
|
||||||
meta_info["token_logprob"] = req.token_logprob
|
meta_info["token_logprob"] = req.token_logprob
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from sglang.srt.managers.openai_protocol import (
|
|||||||
from sglang.srt.managers.router.manager import start_router_process
|
from sglang.srt.managers.router.manager import start_router_process
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import alloc_usable_network_port, handle_port_init
|
from sglang.srt.utils import handle_port_init
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -96,19 +96,25 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stream_generator(obj):
|
async def detokenize_logprob_tokens(token_logprobs):
|
||||||
|
token_ids = [tid for tid, _ in token_logprobs]
|
||||||
|
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
|
||||||
|
return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_generator(obj: GenerateReqInput):
|
||||||
async for out in tokenizer_manager.generate_request(obj):
|
async for out in tokenizer_manager.generate_request(obj):
|
||||||
|
if obj.return_logprob and obj.return_text_in_logprobs:
|
||||||
|
out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
|
||||||
|
out["meta_info"]["token_logprob"]
|
||||||
|
)
|
||||||
yield out
|
yield out
|
||||||
|
|
||||||
|
|
||||||
async def make_openai_style_logprobs(token_logprobs):
|
async def make_openai_style_logprobs(token_logprobs):
|
||||||
ret_logprobs = LogProbs()
|
ret_logprobs = LogProbs()
|
||||||
|
|
||||||
# Detokenize
|
for token_text, token_logprob in token_logprobs:
|
||||||
token_ids = [tid for tid, _ in token_logprobs]
|
|
||||||
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
|
|
||||||
|
|
||||||
for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
|
|
||||||
ret_logprobs.tokens.append(token_text)
|
ret_logprobs.tokens.append(token_text)
|
||||||
ret_logprobs.token_logprobs.append(token_logprob)
|
ret_logprobs.token_logprobs.append(token_logprob)
|
||||||
|
|
||||||
@@ -132,6 +138,11 @@ async def generate_request(obj: GenerateReqInput):
|
|||||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||||
|
|
||||||
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
||||||
|
if obj.return_logprob and obj.return_text_in_logprobs:
|
||||||
|
ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
|
||||||
|
ret["meta_info"]["token_logprob"]
|
||||||
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@@ -155,6 +166,7 @@ async def v1_completions(raw_request: Request):
|
|||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
},
|
},
|
||||||
return_logprob=request.logprobs is not None,
|
return_logprob=request.logprobs is not None,
|
||||||
|
return_text_in_logprobs=True,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
adapted_request.post_init()
|
adapted_request.post_init()
|
||||||
@@ -211,6 +223,7 @@ async def v1_completions(raw_request: Request):
|
|||||||
|
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
ret = await generate_request(adapted_request)
|
ret = await generate_request(adapted_request)
|
||||||
|
ret = ret[0] if isinstance(ret, list) else ret
|
||||||
|
|
||||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||||
|
|||||||
Reference in New Issue
Block a user