Support decode token logprobs (#130)

Author: Cody Yu
Date: 2024-02-06 12:24:55 -08:00 (committed by GitHub)
Parent: ee1df26a77
Commit: a7334aeea1
10 changed files with 233 additions and 96 deletions


@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    LogProbs,
     UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process
@@ -97,6 +98,23 @@ async def stream_generator(obj):
         yield out


+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
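
For orientation (not part of the commit): the OpenAI completions API reports logprobs as four parallel lists, one entry per token, and the helper above fills exactly those fields after detokenizing the (token_id, logprob) pairs coming from the engine. Below is a minimal standalone sketch with a LogProbs-like dataclass and a plain callable standing in for the async detokenization round-trip; all names here are illustrative, not the server's.

from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple


@dataclass
class LogProbsSketch:
    # Mirrors the four fields the helper above fills on the real LogProbs model.
    text_offset: List[int] = field(default_factory=list)
    token_logprobs: List[Optional[float]] = field(default_factory=list)
    tokens: List[str] = field(default_factory=list)
    top_logprobs: List[Dict[str, float]] = field(default_factory=list)


def build_logprobs(pairs: List[Tuple[int, float]],
                   id_to_text: Callable[[int], str]) -> LogProbsSketch:
    """Turn (token_id, logprob) pairs into OpenAI-style parallel lists.

    id_to_text stands in for the detokenization round-trip the server
    performs through tokenizer_manager.detokenize().
    """
    out = LogProbsSketch()
    for token_id, logprob in pairs:
        out.tokens.append(id_to_text(token_id))
        out.token_logprobs.append(logprob)
        out.top_logprobs.append({})  # top-k alternatives: not supported yet
        out.text_offset.append(-1)   # character offsets: not tracked yet
    return out


# Example with a toy vocabulary.
vocab = {1: "Hello", 2: ",", 3: " world"}
lp = build_logprobs([(1, -0.05), (2, -0.70), (3, -0.12)], vocab.__getitem__)
assert lp.tokens == ["Hello", ",", " world"]

top_logprobs and text_offset stay as placeholders ({} and -1) because top-k alternatives and character offsets are not produced yet, matching the "Not supported yet" comment in the diff.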
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
return_logprob=request.logprobs is not None,
stream=request.stream,
)
adapted_request.post_init()
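
Aside (not part of the diff): the only change to the adapted request is the return_logprob flag, derived from the OpenAI logprobs field. A tiny illustration of that gating, using a simplified stand-in for the request model rather than the server's class; note that logprobs=0 still enables logprob return, since the check is "is not None".

from dataclasses import dataclass
from typing import Optional


@dataclass
class CompletionRequestStub:
    prompt: str
    logprobs: Optional[int] = None  # OpenAI: number of top alternatives, or None


def wants_logprobs(req: CompletionRequestStub) -> bool:
    # Mirrors return_logprob=request.logprobs is not None above:
    # even logprobs=0 asks for per-token logprobs of the sampled tokens.
    return req.logprobs is not None


assert wants_logprobs(CompletionRequestStub("hi", logprobs=0)) is True
assert wants_logprobs(CompletionRequestStub("hi")) is False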
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
         async def gnerate_stream_resp():
             stream_buffer = ""
+            n_prev_token = 0
             async for content in stream_generator(adapted_request):
                 text = content["text"]
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
+
+                if not stream_buffer: # The first chunk
+                    if request.echo:
+                        # Prepend prompt in response text.
+                        text = request.prompt + text
+                    else:
+                        # Skip prompt tokens if echo is disabled.
+                        n_prev_token = prompt_tokens
+
+                if request.logprobs is not None:
+                    logprobs = await make_openai_style_logprobs(
+                        content["meta_info"]["token_logprob"][n_prev_token:]
+                    )
+                    n_prev_token = len(content["meta_info"]["token_logprob"])
+                else:
+                    logprobs = None
+
                 delta = text[len(stream_buffer) :]
-                stream_buffer = text
+                stream_buffer = content["text"]
                 choice_data = CompletionResponseStreamChoice(
                     index=0,
                     text=delta,
-                    logprobs=None,
+                    logprobs=logprobs,
                     finish_reason=None,
                 )
                 chunk = CompletionStreamResponse(
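
To make the streaming bookkeeping concrete (a standalone sketch, not the server code): each chunk carries the cumulative token_logprob list, so n_prev_token remembers how many entries were already emitted; on the first chunk it jumps past the prompt tokens unless echo is requested. The names and chunk shape below are assumptions for illustration only.

from typing import Iterable, Iterator, List, Tuple

Pair = Tuple[int, float]  # (token_id, logprob)


def iter_new_logprobs(
    chunks: Iterable[Tuple[List[Pair], int]], echo: bool
) -> Iterator[List[Pair]]:
    """Per streamed chunk, yield only the logprob entries not sent before.

    Each chunk is assumed to carry the *cumulative* token_logprob list plus
    the prompt length, mirroring content["meta_info"] in the handler above.
    """
    n_prev_token = 0
    first_chunk = True
    for cumulative, prompt_tokens in chunks:
        if first_chunk and not echo:
            # Without echo, prompt-token logprobs are never reported.
            n_prev_token = prompt_tokens
        first_chunk = False
        yield cumulative[n_prev_token:]
        n_prev_token = len(cumulative)


# Two decode steps after a 3-token prompt.
chunks = [
    ([(11, -0.9), (12, -0.4), (13, -0.2), (40, -1.0)], 3),
    ([(11, -0.9), (12, -0.4), (13, -0.2), (40, -1.0), (41, -0.5)], 3),
]
assert [len(x) for x in iter_new_logprobs(chunks, echo=False)] == [1, 1]
assert [len(x) for x in iter_new_logprobs(chunks, echo=True)] == [4, 1]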
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)
+
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    token_logprob_pos = prompt_tokens
+    if request.echo:
+        token_logprob_pos = 0
+        text = request.prompt + text
+    else:
+        token_logprob_pos = prompt_tokens
+    logprobs = (
+        await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+        if request.logprobs is not None
+        else None
+    )
     choice_data = CompletionResponseChoice(
         index=0,
-        text=ret["text"],
-        logprobs=None,
+        text=text,
+        logprobs=logprobs,
         finish_reason=None, # TODO(comaniac): Add finish reason.
     )
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
         model=request.model,
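
From the client side, the feature is driven by the standard OpenAI completions parameter. A hedged usage example against a locally running server follows; the URL, port, and model name are placeholders, and the response layout follows the LogProbs fields shown above.

import requests

# Assumes an sglang server with the OpenAI-compatible API running locally;
# the URL, port, and model name below are placeholders.
resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logprobs": 1,   # any non-None value enables return_logprob above
        "echo": False,   # set True to also receive prompt-token logprobs
    },
)
choice = resp.json()["choices"][0]
print(choice["text"])
# Parallel lists, one entry per returned token.
print(choice["logprobs"]["tokens"])
print(choice["logprobs"]["token_logprobs"])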
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
             if not isinstance(m.content, str):
                 raise HTTPException(
                     status_code=503,
-                    detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                    detail="Structured content requests not supported with "
+                    "HuggingFace Chat Templates. "
+                    "Make sure the server specifies a sglang chat template.",
                 )
         prompt = tokenizer_manager.tokenizer.apply_chat_template(
             request.messages, tokenize=False, add_generation_prompt=True
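
The last hunk only rewraps the error message, but it points at the constraint enforced here: when no sglang chat template is configured, messages are rendered through the HuggingFace tokenizer's apply_chat_template, which needs plain string content. A small example of that call outside the server (the model name is a placeholder; the rendered prompt depends on the tokenizer's template):

from transformers import AutoTokenizer

# Model name is a placeholder; any tokenizer that ships a chat template works.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
    {"role": "user", "content": "What is the capital of France?"},
]
# Note: content must be plain strings; structured content triggers the
# HTTP 503 raised in the handler above.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # a single prompt string, ready to be sent to the engine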