Support decode token logprobs (#130)
@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    LogProbs,
     UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process
@@ -97,6 +98,23 @@ async def stream_generator(obj):
         yield out


+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
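The new make_openai_style_logprobs helper maps the runtime's (token_id, logprob) pairs onto the four parallel lists of the OpenAI LogProbs schema, detokenizing through the server's tokenizer_manager. A minimal offline sketch of the same field mapping, assuming a plain HuggingFace tokenizer and a dict standing in for the LogProbs model; the helper name, model name, and sample values are illustrative only, not part of the commit:

from transformers import AutoTokenizer

def openai_style_logprobs_sketch(token_logprobs, tokenizer):
    # token_logprobs: list of (token_id, logprob) pairs produced by the runtime.
    ret = {"tokens": [], "token_logprobs": [], "top_logprobs": [], "text_offset": []}
    for token_id, logprob in token_logprobs:
        ret["tokens"].append(tokenizer.decode([token_id]))
        ret["token_logprobs"].append(logprob)
        ret["top_logprobs"].append({})   # not populated yet, mirroring the patch
        ret["text_offset"].append(-1)    # not populated yet, mirroring the patch
    return ret

if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    print(openai_style_logprobs_sketch([(15496, -0.12), (995, -1.73)], tok))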
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
         },
+        return_logprob=request.logprobs is not None,
         stream=request.stream,
     )
     adapted_request.post_init()
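With return_logprob now forwarded to GenerateReqInput, a completions request that sets logprobs should get per-token logprobs back. A rough client-side sketch, assuming a locally running server; the URL, port, and model name are assumptions, not part of the commit:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/completions",  # sglang's default port is assumed here
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logprobs": 1,   # any non-None value flips return_logprob on
        "echo": False,
    },
)
print(resp.json()["choices"][0]["logprobs"])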
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):

         async def gnerate_stream_resp():
             stream_buffer = ""
+            n_prev_token = 0
             async for content in stream_generator(adapted_request):
                 text = content["text"]
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]

+                if not stream_buffer:  # The first chunk
+                    if request.echo:
+                        # Prepend prompt in response text.
+                        text = request.prompt + text
+                    else:
+                        # Skip prompt tokens if echo is disabled.
+                        n_prev_token = prompt_tokens
+
+                if request.logprobs is not None:
+                    logprobs = await make_openai_style_logprobs(
+                        content["meta_info"]["token_logprob"][n_prev_token:]
+                    )
+                    n_prev_token = len(content["meta_info"]["token_logprob"])
+                else:
+                    logprobs = None
+
                 delta = text[len(stream_buffer) :]
-                stream_buffer = text
+                stream_buffer = content["text"]
                 choice_data = CompletionResponseStreamChoice(
                     index=0,
                     text=delta,
-                    logprobs=None,
+                    logprobs=logprobs,
                     finish_reason=None,
                 )
                 chunk = CompletionStreamResponse(
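The streaming path treats meta_info["token_logprob"] as a cumulative list, so n_prev_token records how many pairs have already been emitted and each chunk only reports the tail. A self-contained sketch of that bookkeeping, with invented chunk data:

def new_logprob_pairs(chunks):
    # Each chunk carries the cumulative (token_id, logprob) list seen so far.
    n_prev_token = 0
    for chunk in chunks:
        pairs = chunk["token_logprob"]
        yield pairs[n_prev_token:]      # only the pairs added since the last chunk
        n_prev_token = len(pairs)

chunks = [
    {"token_logprob": [(11, -0.1)]},
    {"token_logprob": [(11, -0.1), (22, -0.5)]},
    {"token_logprob": [(11, -0.1), (22, -0.5), (33, -2.0)]},
]
for new in new_logprob_pairs(chunks):
    print(new)  # [(11, -0.1)], then [(22, -0.5)], then [(33, -2.0)]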
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)

+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    token_logprob_pos = prompt_tokens
+    if request.echo:
+        token_logprob_pos = 0
+        text = request.prompt + text
+    else:
+        token_logprob_pos = prompt_tokens
+
+    logprobs = (
+        await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+        if request.logprobs is not None
+        else None
+    )
     choice_data = CompletionResponseChoice(
         index=0,
-        text=ret["text"],
-        logprobs=None,
+        text=text,
+        logprobs=logprobs,
         finish_reason=None,  # TODO(comaniac): Add finish reason.
     )

-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
         model=request.model,
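In the non-streaming path the same idea appears without the loop: token_logprob_pos decides whether prompt-token logprobs are included (echo on) or skipped (echo off). A compact sketch with invented values; the helper name is illustrative only:

def select_logprob_pairs(token_logprob, prompt_tokens, echo):
    # With echo, start at the first prompt token; otherwise skip the prompt.
    start = 0 if echo else prompt_tokens
    return token_logprob[start:]

pairs = [(7592, -0.2), (2088, -0.9), (13, -1.4)]  # 1 prompt token + 2 generated
print(select_logprob_pairs(pairs, prompt_tokens=1, echo=False))  # generated only
print(select_logprob_pairs(pairs, prompt_tokens=1, echo=True))   # prompt + generated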
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
                 if not isinstance(m.content, str):
                     raise HTTPException(
                         status_code=503,
-                        detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                        detail="Structured content requests not supported with "
+                        "HuggingFace Chat Templates. "
+                        "Make sure the server specifies a sglang chat template.",
                     )
             prompt = tokenizer_manager.tokenizer.apply_chat_template(
                 request.messages, tokenize=False, add_generation_prompt=True