Support decode token logprobs (#130)
This commit is contained in:
@@ -14,28 +14,11 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
||||||
if not input_metadata.return_logprob:
|
last_index = None
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
|
||||||
last_hidden = hidden_states
|
|
||||||
else:
|
|
||||||
last_index = (
|
|
||||||
torch.cumsum(
|
|
||||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
|
||||||
dim=0,
|
|
||||||
dtype=torch.long,
|
|
||||||
)
|
|
||||||
- 1
|
|
||||||
)
|
|
||||||
last_hidden = hidden_states[last_index]
|
|
||||||
hidden_states = None
|
|
||||||
|
|
||||||
last_logits = torch.matmul(last_hidden, weight.T)
|
# Compute the last index (the first decode token) of each request
|
||||||
if self.tp_size > 1:
|
# if we are in prefill or extend mode.
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
if input_metadata.forward_mode != ForwardMode.DECODE:
|
||||||
last_logits = last_logits[:, : self.config.vocab_size]
|
|
||||||
return last_logits, (None, None)
|
|
||||||
else:
|
|
||||||
assert input_metadata.forward_mode != ForwardMode.DECODE
|
|
||||||
last_index = (
|
last_index = (
|
||||||
torch.cumsum(
|
torch.cumsum(
|
||||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||||
@@ -45,29 +28,54 @@ class LogitsProcessor(nn.Module):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not input_metadata.return_logprob:
|
||||||
|
# When logprob is not requested, only compute the last logits.
|
||||||
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
|
last_hidden = hidden_states
|
||||||
|
else:
|
||||||
|
last_hidden = hidden_states[last_index]
|
||||||
|
hidden_states = None
|
||||||
|
|
||||||
|
last_logits = torch.matmul(last_hidden, weight.T)
|
||||||
|
if self.tp_size > 1:
|
||||||
|
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||||
|
last_logits = last_logits[:, : self.config.vocab_size]
|
||||||
|
return last_logits, (None, None, None)
|
||||||
|
else:
|
||||||
|
# When logprob is requested, compute the logits for all tokens.
|
||||||
logits = torch.matmul(hidden_states, weight.T)
|
logits = torch.matmul(hidden_states, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
logits = tensor_model_parallel_all_gather(logits)
|
logits = tensor_model_parallel_all_gather(logits)
|
||||||
logits = logits[:, : self.config.vocab_size]
|
logits = logits[:, : self.config.vocab_size]
|
||||||
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
||||||
|
|
||||||
logprobs = all_logprobs[
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
last_logits = logits
|
||||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
last_logprobs = all_logprobs
|
||||||
]
|
prefill_logprobs = normalized_logprobs = None
|
||||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
else:
|
||||||
|
# Compute the logprobs for the last token of each request.
|
||||||
|
last_logits = logits[last_index]
|
||||||
|
last_logprobs = all_logprobs[last_index]
|
||||||
|
|
||||||
start = input_metadata.extend_start_loc.clone()
|
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||||
end = start + input_metadata.extend_seq_lens - 2
|
# Note that we pad a zero at the end of each sequence for easy computation.
|
||||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
prefill_logprobs = all_logprobs[
|
||||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||||
normalized_logprobs = sum_logp / (
|
]
|
||||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
|
||||||
)
|
|
||||||
|
|
||||||
last_logits = logits[last_index]
|
start = input_metadata.extend_start_loc.clone()
|
||||||
return last_logits, (logprobs, normalized_logprobs)
|
end = start + input_metadata.extend_seq_lens - 2
|
||||||
|
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||||
|
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||||
|
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
|
||||||
|
normalized_logprobs = sum_logp / (
|
||||||
|
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -99,3 +99,7 @@ class BatchStrOut:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class FlushCacheReq:
|
class FlushCacheReq:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DetokenizeReqInput:
|
||||||
|
input_ids: List[int]
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class Req:
|
|||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
|
||||||
self.logprob = None
|
self.logprob = None
|
||||||
|
self.token_logprob = None
|
||||||
self.normalized_logprob = None
|
self.normalized_logprob = None
|
||||||
|
|
||||||
# For constrained decoding
|
# For constrained decoding
|
||||||
|
|||||||
@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logprobs = None
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
# Forward
|
# Forward
|
||||||
logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
|
logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
|
||||||
batch, ForwardMode.EXTEND, batch.return_logprob
|
self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
|
||||||
)
|
)
|
||||||
# print("extend logits", logits)
|
if prefill_logprobs is not None:
|
||||||
if logprobs is not None:
|
logprobs = prefill_logprobs.cpu().tolist()
|
||||||
logprobs = logprobs.cpu().tolist()
|
|
||||||
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
||||||
|
|
||||||
next_token_ids, next_token_probs = batch.sample(logits)
|
next_token_ids, _ = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
else:
|
else:
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
logprobs = normalized_logprobs = None
|
logits = logprobs = normalized_logprobs = last_logprobs = None
|
||||||
|
|
||||||
|
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
|
reqs = batch.reqs
|
||||||
|
if last_logprobs is not None:
|
||||||
|
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
reqs = batch.reqs
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
req.output_ids = [next_token_ids[i]]
|
req.output_ids = [next_token_ids[i]]
|
||||||
@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
if logprobs is not None:
|
if logprobs is not None:
|
||||||
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
|
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
|
||||||
req.normalized_logprob = normalized_logprobs[i]
|
req.normalized_logprob = normalized_logprobs[i]
|
||||||
|
|
||||||
|
token_ids = req.input_ids + [next_token_ids[i]]
|
||||||
|
token_logprobs = [None] + req.logprob + [last_logprobs[i]]
|
||||||
|
req.token_logprob = list(zip(token_ids, token_logprobs))
|
||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
|
logits, (_, _, last_logprobs) = self.model_runner.forward(
|
||||||
next_token_ids, next_token_probs = batch.sample(logits)
|
batch,
|
||||||
|
ForwardMode.DECODE,
|
||||||
|
batch.return_logprob,
|
||||||
|
)
|
||||||
|
next_token_ids, _ = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
|
|
||||||
# Check finish condition
|
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
for i in range(len(reqs)):
|
if last_logprobs is not None:
|
||||||
reqs[i].output_ids.append(next_token_ids[i])
|
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
|
||||||
reqs[i].check_finished()
|
|
||||||
|
# Check finish condition
|
||||||
|
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
|
||||||
|
req.output_ids.append(next_tok_id)
|
||||||
|
req.check_finished()
|
||||||
|
|
||||||
|
if last_logprobs is not None:
|
||||||
|
req.token_logprob.append((next_tok_id, last_logprobs[i]))
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
}
|
}
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
meta_info["prompt_logprob"] = req.logprob
|
meta_info["prompt_logprob"] = req.logprob
|
||||||
|
meta_info["token_logprob"] = req.token_logprob
|
||||||
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
|
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
|
||||||
output_meta_info.append(meta_info)
|
output_meta_info.append(meta_info)
|
||||||
output_finished.append(req.finished)
|
output_finished.append(req.finished)
|
||||||
|
|||||||
@@ -397,6 +397,7 @@ class ModelRunner:
|
|||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
out_cache_cont_start,
|
out_cache_cont_start,
|
||||||
out_cache_cont_end,
|
out_cache_cont_end,
|
||||||
|
return_logprob,
|
||||||
):
|
):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
@@ -409,10 +410,9 @@ class ModelRunner:
|
|||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
out_cache_cont_start=out_cache_cont_start,
|
out_cache_cont_start=out_cache_cont_start,
|
||||||
out_cache_cont_end=out_cache_cont_end,
|
out_cache_cont_end=out_cache_cont_end,
|
||||||
|
return_logprob=return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
|
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
||||||
0
|
|
||||||
]
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend_multi_modal(
|
def forward_extend_multi_modal(
|
||||||
@@ -460,8 +460,8 @@ class ModelRunner:
|
|||||||
"prefix_lens": batch.prefix_lens,
|
"prefix_lens": batch.prefix_lens,
|
||||||
"position_ids_offsets": batch.position_ids_offsets,
|
"position_ids_offsets": batch.position_ids_offsets,
|
||||||
"out_cache_loc": batch.out_cache_loc,
|
"out_cache_loc": batch.out_cache_loc,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
}
|
}
|
||||||
kwargs["return_logprob"] = return_logprob
|
|
||||||
return self.forward_extend_multi_modal(**kwargs)
|
return self.forward_extend_multi_modal(**kwargs)
|
||||||
else:
|
else:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@@ -471,6 +471,7 @@ class ModelRunner:
|
|||||||
"prefix_lens": batch.prefix_lens,
|
"prefix_lens": batch.prefix_lens,
|
||||||
"position_ids_offsets": batch.position_ids_offsets,
|
"position_ids_offsets": batch.position_ids_offsets,
|
||||||
"out_cache_loc": batch.out_cache_loc,
|
"out_cache_loc": batch.out_cache_loc,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
}
|
}
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode == ForwardMode.DECODE:
|
||||||
@@ -478,10 +479,8 @@ class ModelRunner:
|
|||||||
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
|
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
|
||||||
return self.forward_decode(**kwargs)
|
return self.forward_decode(**kwargs)
|
||||||
elif forward_mode == ForwardMode.EXTEND:
|
elif forward_mode == ForwardMode.EXTEND:
|
||||||
kwargs["return_logprob"] = return_logprob
|
|
||||||
return self.forward_extend(**kwargs)
|
return self.forward_extend(**kwargs)
|
||||||
elif forward_mode == ForwardMode.PREFILL:
|
elif forward_mode == ForwardMode.PREFILL:
|
||||||
kwargs["return_logprob"] = return_logprob
|
|
||||||
return self.forward_prefill(**kwargs)
|
return self.forward_prefill(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
|
DetokenizeReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -234,6 +235,10 @@ class TokenizerManager:
|
|||||||
|
|
||||||
yield output_list
|
yield output_list
|
||||||
|
|
||||||
|
async def detokenize(self, obj: DetokenizeReqInput):
|
||||||
|
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
|
||||||
|
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
|
||||||
|
|
||||||
async def flush_cache(self):
|
async def flush_cache(self):
|
||||||
flush_cache_req = FlushCacheReq()
|
flush_cache_req = FlushCacheReq()
|
||||||
self.send_to_router.send_pyobj(flush_cache_req)
|
self.send_to_router.send_pyobj(flush_cache_req)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
|
||||||
from sglang.srt.managers.openai_protocol import (
|
from sglang.srt.managers.openai_protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
|
|||||||
CompletionResponseStreamChoice,
|
CompletionResponseStreamChoice,
|
||||||
CompletionStreamResponse,
|
CompletionStreamResponse,
|
||||||
DeltaMessage,
|
DeltaMessage,
|
||||||
|
LogProbs,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.router.manager import start_router_process
|
from sglang.srt.managers.router.manager import start_router_process
|
||||||
@@ -97,6 +98,23 @@ async def stream_generator(obj):
|
|||||||
yield out
|
yield out
|
||||||
|
|
||||||
|
|
||||||
|
async def make_openai_style_logprobs(token_logprobs):
|
||||||
|
ret_logprobs = LogProbs()
|
||||||
|
|
||||||
|
# Detokenize
|
||||||
|
token_ids = [tid for tid, _ in token_logprobs]
|
||||||
|
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
|
||||||
|
|
||||||
|
for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
|
||||||
|
ret_logprobs.tokens.append(token_text)
|
||||||
|
ret_logprobs.token_logprobs.append(token_logprob)
|
||||||
|
|
||||||
|
# Not supported yet.
|
||||||
|
ret_logprobs.top_logprobs.append({})
|
||||||
|
ret_logprobs.text_offset.append(-1)
|
||||||
|
return ret_logprobs
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
async def generate_request(obj: GenerateReqInput):
|
async def generate_request(obj: GenerateReqInput):
|
||||||
obj.post_init()
|
obj.post_init()
|
||||||
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
|
|||||||
"presence_penalty": request.presence_penalty,
|
"presence_penalty": request.presence_penalty,
|
||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
},
|
},
|
||||||
|
return_logprob=request.logprobs is not None,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
adapted_request.post_init()
|
adapted_request.post_init()
|
||||||
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
|
|||||||
|
|
||||||
async def gnerate_stream_resp():
|
async def gnerate_stream_resp():
|
||||||
stream_buffer = ""
|
stream_buffer = ""
|
||||||
|
n_prev_token = 0
|
||||||
async for content in stream_generator(adapted_request):
|
async for content in stream_generator(adapted_request):
|
||||||
text = content["text"]
|
text = content["text"]
|
||||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||||
|
|
||||||
|
if not stream_buffer: # The first chunk
|
||||||
|
if request.echo:
|
||||||
|
# Prepend prompt in response text.
|
||||||
|
text = request.prompt + text
|
||||||
|
else:
|
||||||
|
# Skip prompt tokens if echo is disabled.
|
||||||
|
n_prev_token = prompt_tokens
|
||||||
|
|
||||||
|
if request.logprobs is not None:
|
||||||
|
logprobs = await make_openai_style_logprobs(
|
||||||
|
content["meta_info"]["token_logprob"][n_prev_token:]
|
||||||
|
)
|
||||||
|
n_prev_token = len(content["meta_info"]["token_logprob"])
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
delta = text[len(stream_buffer) :]
|
delta = text[len(stream_buffer) :]
|
||||||
stream_buffer = text
|
stream_buffer = content["text"]
|
||||||
choice_data = CompletionResponseStreamChoice(
|
choice_data = CompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0,
|
||||||
text=delta,
|
text=delta,
|
||||||
logprobs=None,
|
logprobs=logprobs,
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
chunk = CompletionStreamResponse(
|
chunk = CompletionStreamResponse(
|
||||||
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
|
|||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
ret = await generate_request(adapted_request)
|
ret = await generate_request(adapted_request)
|
||||||
|
|
||||||
|
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||||
|
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||||
|
text = ret["text"]
|
||||||
|
token_logprob_pos = prompt_tokens
|
||||||
|
if request.echo:
|
||||||
|
token_logprob_pos = 0
|
||||||
|
text = request.prompt + text
|
||||||
|
else:
|
||||||
|
token_logprob_pos = prompt_tokens
|
||||||
|
|
||||||
|
logprobs = (
|
||||||
|
await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
|
||||||
|
if request.logprobs is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
choice_data = CompletionResponseChoice(
|
choice_data = CompletionResponseChoice(
|
||||||
index=0,
|
index=0,
|
||||||
text=ret["text"],
|
text=text,
|
||||||
logprobs=None,
|
logprobs=logprobs,
|
||||||
finish_reason=None, # TODO(comaniac): Add finish reason.
|
finish_reason=None, # TODO(comaniac): Add finish reason.
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
|
||||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
|
||||||
response = CompletionResponse(
|
response = CompletionResponse(
|
||||||
id=ret["meta_info"]["id"],
|
id=ret["meta_info"]["id"],
|
||||||
model=request.model,
|
model=request.model,
|
||||||
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
|
|||||||
if not isinstance(m.content, str):
|
if not isinstance(m.content, str):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503,
|
status_code=503,
|
||||||
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
|
detail="Structured content requests not supported with "
|
||||||
|
"HuggingFace Chat Templates. "
|
||||||
|
"Make sure the server specifies a sglang chat template.",
|
||||||
)
|
)
|
||||||
prompt = tokenizer_manager.tokenizer.apply_chat_template(
|
prompt = tokenizer_manager.tokenizer.apply_chat_template(
|
||||||
request.messages, tokenize=False, add_generation_prompt=True
|
request.messages, tokenize=False, add_generation_prompt=True
|
||||||
|
|||||||
@@ -9,10 +9,24 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
def test_decode(url, return_logprob):
|
||||||
|
response = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
},
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"logprob_start_len": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
@@ -21,16 +35,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
url = f"{args.host}:{args.port}"
|
url = f"{args.host}:{args.port}"
|
||||||
|
|
||||||
response = requests.post(
|
test_decode(url, False)
|
||||||
url + "/generate",
|
test_decode(url, True)
|
||||||
json={
|
|
||||||
"text": "The capital of France is",
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 32,
|
|
||||||
},
|
|
||||||
# "return_logprob": True,
|
|
||||||
# "logprob_start_len": 0,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
print(response.json())
|
|
||||||
|
|||||||
@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test_decode_stream(url, return_logprob):
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
|
||||||
parser.add_argument("--port", type=int, default=30000)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
url = f"{args.host}:{args.port}"
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
"text": "The capital of France is",
|
"text": "The capital of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 512,
|
"max_new_tokens": 128,
|
||||||
},
|
},
|
||||||
"stream": True,
|
"stream": True,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
},
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
@@ -41,7 +34,29 @@ if __name__ == "__main__":
|
|||||||
if chunk == "data: [DONE]":
|
if chunk == "data: [DONE]":
|
||||||
break
|
break
|
||||||
data = json.loads(chunk[5:].strip("\n"))
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
output = data["text"].strip()
|
|
||||||
print(output[prev:], end="", flush=True)
|
if return_logprob:
|
||||||
prev = len(output)
|
assert data["meta_info"]["prompt_logprob"] is not None
|
||||||
|
assert data["meta_info"]["token_logprob"] is not None
|
||||||
|
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
||||||
|
if prev == 0: # Skip prompt logprobs
|
||||||
|
prev = data["meta_info"]["prompt_tokens"]
|
||||||
|
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
|
||||||
|
print(f"{token_txt}\t{logprob}", flush=True)
|
||||||
|
prev = len(data["meta_info"]["token_logprob"])
|
||||||
|
else:
|
||||||
|
output = data["text"].strip()
|
||||||
|
print(output[prev:], end="", flush=True)
|
||||||
|
prev = len(output)
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
|
parser.add_argument("--port", type=int, default=30000)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
url = f"{args.host}:{args.port}"
|
||||||
|
|
||||||
|
test_decode_stream(url, False)
|
||||||
|
test_decode_stream(url, True)
|
||||||
|
|||||||
@@ -18,15 +18,26 @@ import argparse
|
|||||||
import openai
|
import openai
|
||||||
|
|
||||||
|
|
||||||
def test_completion(args):
|
def test_completion(args, echo, logprobs):
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||||
response = client.completions.create(
|
response = client.completions.create(
|
||||||
model="default",
|
model="default",
|
||||||
prompt="The capital of France is",
|
prompt="The capital of France is",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=32,
|
max_tokens=32,
|
||||||
|
echo=echo,
|
||||||
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
text = response.choices[0].text
|
||||||
print(response.choices[0].text)
|
print(response.choices[0].text)
|
||||||
|
if echo:
|
||||||
|
assert text.startswith("The capital of France is")
|
||||||
|
if logprobs:
|
||||||
|
assert response.choices[0].logprobs
|
||||||
|
if echo:
|
||||||
|
assert response.choices[0].logprobs.token_logprobs[0] == None
|
||||||
|
else:
|
||||||
|
assert response.choices[0].logprobs.token_logprobs[0] != None
|
||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
@@ -34,7 +45,7 @@ def test_completion(args):
|
|||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
def test_completion_stream(args):
|
def test_completion_stream(args, echo, logprobs):
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||||
response = client.completions.create(
|
response = client.completions.create(
|
||||||
model="default",
|
model="default",
|
||||||
@@ -42,9 +53,23 @@ def test_completion_stream(args):
|
|||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=32,
|
max_tokens=32,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
echo=echo,
|
||||||
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
first = True
|
||||||
for r in response:
|
for r in response:
|
||||||
print(r.choices[0].text, end="", flush=True)
|
if first:
|
||||||
|
if echo:
|
||||||
|
assert r.choices[0].text.startswith("The capital of France is")
|
||||||
|
first = False
|
||||||
|
if logprobs:
|
||||||
|
print(
|
||||||
|
f"{r.choices[0].text:12s}\t"
|
||||||
|
f"{r.choices[0].logprobs.token_logprobs}",
|
||||||
|
flush=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(r.choices[0].text, end="", flush=True)
|
||||||
assert r.id
|
assert r.id
|
||||||
assert r.usage.prompt_tokens > 0
|
assert r.usage.prompt_tokens > 0
|
||||||
assert r.usage.completion_tokens > 0
|
assert r.usage.completion_tokens > 0
|
||||||
@@ -135,8 +160,14 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
test_completion(args)
|
test_completion(args, echo=False, logprobs=False)
|
||||||
test_completion_stream(args)
|
test_completion(args, echo=True, logprobs=False)
|
||||||
|
test_completion(args, echo=False, logprobs=True)
|
||||||
|
test_completion(args, echo=True, logprobs=True)
|
||||||
|
test_completion_stream(args, echo=False, logprobs=False)
|
||||||
|
test_completion_stream(args, echo=True, logprobs=False)
|
||||||
|
test_completion_stream(args, echo=False, logprobs=True)
|
||||||
|
test_completion_stream(args, echo=True, logprobs=True)
|
||||||
test_chat_completion(args)
|
test_chat_completion(args)
|
||||||
test_chat_completion_stream(args)
|
test_chat_completion_stream(args)
|
||||||
if args.test_image:
|
if args.test_image:
|
||||||
|
|||||||
Reference in New Issue
Block a user