Support decode token logprobs (#130)
This commit is contained in:
@@ -14,28 +14,11 @@ class LogitsProcessor(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
||||
if not input_metadata.return_logprob:
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_hidden = hidden_states
|
||||
else:
|
||||
last_index = (
|
||||
torch.cumsum(
|
||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||
dim=0,
|
||||
dtype=torch.long,
|
||||
)
|
||||
- 1
|
||||
)
|
||||
last_hidden = hidden_states[last_index]
|
||||
hidden_states = None
|
||||
last_index = None
|
||||
|
||||
last_logits = torch.matmul(last_hidden, weight.T)
|
||||
if self.tp_size > 1:
|
||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||
last_logits = last_logits[:, : self.config.vocab_size]
|
||||
return last_logits, (None, None)
|
||||
else:
|
||||
assert input_metadata.forward_mode != ForwardMode.DECODE
|
||||
# Compute the last index (the first decode token) of each requeast
|
||||
# if we are in prefill or extend mode.
|
||||
if input_metadata.forward_mode != ForwardMode.DECODE:
|
||||
last_index = (
|
||||
torch.cumsum(
|
||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||
@@ -45,29 +28,54 @@ class LogitsProcessor(nn.Module):
|
||||
- 1
|
||||
)
|
||||
|
||||
if not input_metadata.return_logprob:
|
||||
# When logprob is not requested, only compute the last logits.
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_hidden = hidden_states
|
||||
else:
|
||||
last_hidden = hidden_states[last_index]
|
||||
hidden_states = None
|
||||
|
||||
last_logits = torch.matmul(last_hidden, weight.T)
|
||||
if self.tp_size > 1:
|
||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||
last_logits = last_logits[:, : self.config.vocab_size]
|
||||
return last_logits, (None, None, None)
|
||||
else:
|
||||
# When logprob is requested, compute the logits for all tokens.
|
||||
logits = torch.matmul(hidden_states, weight.T)
|
||||
if self.tp_size > 1:
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
logits = logits[:, : self.config.vocab_size]
|
||||
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
||||
|
||||
logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_logits = logits
|
||||
last_logprobs = all_logprobs
|
||||
prefill_logprobs = normalized_logprobs = None
|
||||
else:
|
||||
# Compute the logprobs for the last token of each request.
|
||||
last_logits = logits[last_index]
|
||||
last_logprobs = all_logprobs[last_index]
|
||||
|
||||
start = input_metadata.extend_start_loc.clone()
|
||||
end = start + input_metadata.extend_seq_lens - 2
|
||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||
normalized_logprobs = sum_logp / (
|
||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||
)
|
||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||
# Note that we pad a zero at the end of each sequence for easy computation.
|
||||
prefill_logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
last_logits = logits[last_index]
|
||||
return last_logits, (logprobs, normalized_logprobs)
|
||||
start = input_metadata.extend_start_loc.clone()
|
||||
end = start + input_metadata.extend_seq_lens - 2
|
||||
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
|
||||
normalized_logprobs = sum_logp / (
|
||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||
)
|
||||
|
||||
return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -99,3 +99,7 @@ class BatchStrOut:
|
||||
@dataclass
|
||||
class FlushCacheReq:
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class DetokenizeReqInput:
|
||||
input_ids: List[int]
|
||||
|
||||
@@ -48,6 +48,7 @@ class Req:
|
||||
self.last_node = None
|
||||
|
||||
self.logprob = None
|
||||
self.token_logprob = None
|
||||
self.normalized_logprob = None
|
||||
|
||||
# For constrained decoding
|
||||
|
||||
@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.model_config.vocab_size, self.int_token_logit_bias
|
||||
)
|
||||
|
||||
logprobs = None
|
||||
if batch.extend_num_tokens != 0:
|
||||
# Forward
|
||||
logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
|
||||
batch, ForwardMode.EXTEND, batch.return_logprob
|
||||
logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
|
||||
self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
|
||||
)
|
||||
# print("extend logits", logits)
|
||||
if logprobs is not None:
|
||||
logprobs = logprobs.cpu().tolist()
|
||||
if prefill_logprobs is not None:
|
||||
logprobs = prefill_logprobs.cpu().tolist()
|
||||
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
||||
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
next_token_ids, _ = batch.sample(logits)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
else:
|
||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||
logprobs = normalized_logprobs = None
|
||||
logits = logprobs = normalized_logprobs = last_logprobs = None
|
||||
|
||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||
reqs = batch.reqs
|
||||
if last_logprobs is not None:
|
||||
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
||||
|
||||
# Check finish condition
|
||||
reqs = batch.reqs
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
req.output_ids = [next_token_ids[i]]
|
||||
@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
|
||||
if logprobs is not None:
|
||||
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
|
||||
req.normalized_logprob = normalized_logprobs[i]
|
||||
|
||||
token_ids = req.input_ids + [next_token_ids[i]]
|
||||
token_logprobs = [None] + req.logprob + [last_logprobs[i]]
|
||||
req.token_logprob = list(zip(token_ids, token_logprobs))
|
||||
pt += req.extend_input_len
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
|
||||
batch.prepare_for_decode()
|
||||
|
||||
# Forward
|
||||
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
logits, (_, _, last_logprobs) = self.model_runner.forward(
|
||||
batch,
|
||||
ForwardMode.DECODE,
|
||||
batch.return_logprob,
|
||||
)
|
||||
next_token_ids, _ = batch.sample(logits)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
|
||||
# Check finish condition
|
||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||
reqs = batch.reqs
|
||||
for i in range(len(reqs)):
|
||||
reqs[i].output_ids.append(next_token_ids[i])
|
||||
reqs[i].check_finished()
|
||||
if last_logprobs is not None:
|
||||
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
|
||||
|
||||
# Check finish condition
|
||||
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
|
||||
req.output_ids.append(next_tok_id)
|
||||
req.check_finished()
|
||||
|
||||
if last_logprobs is not None:
|
||||
req.token_logprob.append((next_tok_id, last_logprobs[i]))
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
}
|
||||
if req.return_logprob:
|
||||
meta_info["prompt_logprob"] = req.logprob
|
||||
meta_info["token_logprob"] = req.token_logprob
|
||||
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
|
||||
output_meta_info.append(meta_info)
|
||||
output_finished.append(req.finished)
|
||||
|
||||
@@ -397,6 +397,7 @@ class ModelRunner:
|
||||
out_cache_loc,
|
||||
out_cache_cont_start,
|
||||
out_cache_cont_end,
|
||||
return_logprob,
|
||||
):
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
@@ -409,10 +410,9 @@ class ModelRunner:
|
||||
out_cache_loc=out_cache_loc,
|
||||
out_cache_cont_start=out_cache_cont_start,
|
||||
out_cache_cont_end=out_cache_cont_end,
|
||||
return_logprob=return_logprob,
|
||||
)
|
||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
|
||||
0
|
||||
]
|
||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend_multi_modal(
|
||||
@@ -460,8 +460,8 @@ class ModelRunner:
|
||||
"prefix_lens": batch.prefix_lens,
|
||||
"position_ids_offsets": batch.position_ids_offsets,
|
||||
"out_cache_loc": batch.out_cache_loc,
|
||||
"return_logprob": return_logprob,
|
||||
}
|
||||
kwargs["return_logprob"] = return_logprob
|
||||
return self.forward_extend_multi_modal(**kwargs)
|
||||
else:
|
||||
kwargs = {
|
||||
@@ -471,6 +471,7 @@ class ModelRunner:
|
||||
"prefix_lens": batch.prefix_lens,
|
||||
"position_ids_offsets": batch.position_ids_offsets,
|
||||
"out_cache_loc": batch.out_cache_loc,
|
||||
"return_logprob": return_logprob,
|
||||
}
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
@@ -478,10 +479,8 @@ class ModelRunner:
|
||||
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
|
||||
return self.forward_decode(**kwargs)
|
||||
elif forward_mode == ForwardMode.EXTEND:
|
||||
kwargs["return_logprob"] = return_logprob
|
||||
return self.forward_extend(**kwargs)
|
||||
elif forward_mode == ForwardMode.PREFILL:
|
||||
kwargs["return_logprob"] = return_logprob
|
||||
return self.forward_prefill(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||
|
||||
@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchStrOut,
|
||||
DetokenizeReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -234,6 +235,10 @@ class TokenizerManager:
|
||||
|
||||
yield output_list
|
||||
|
||||
async def detokenize(self, obj: DetokenizeReqInput):
|
||||
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
|
||||
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
|
||||
|
||||
async def flush_cache(self):
|
||||
flush_cache_req = FlushCacheReq()
|
||||
self.send_to_router.send_pyobj(flush_cache_req)
|
||||
|
||||
@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
|
||||
)
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
|
||||
from sglang.srt.managers.openai_protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
DeltaMessage,
|
||||
LogProbs,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.managers.router.manager import start_router_process
|
||||
@@ -97,6 +98,23 @@ async def stream_generator(obj):
|
||||
yield out
|
||||
|
||||
|
||||
async def make_openai_style_logprobs(token_logprobs):
|
||||
ret_logprobs = LogProbs()
|
||||
|
||||
# Detokenize
|
||||
token_ids = [tid for tid, _ in token_logprobs]
|
||||
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
|
||||
|
||||
for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
|
||||
ret_logprobs.tokens.append(token_text)
|
||||
ret_logprobs.token_logprobs.append(token_logprob)
|
||||
|
||||
# Not supported yet.
|
||||
ret_logprobs.top_logprobs.append({})
|
||||
ret_logprobs.text_offset.append(-1)
|
||||
return ret_logprobs
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate_request(obj: GenerateReqInput):
|
||||
obj.post_init()
|
||||
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
},
|
||||
return_logprob=request.logprobs is not None,
|
||||
stream=request.stream,
|
||||
)
|
||||
adapted_request.post_init()
|
||||
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
|
||||
|
||||
async def gnerate_stream_resp():
|
||||
stream_buffer = ""
|
||||
n_prev_token = 0
|
||||
async for content in stream_generator(adapted_request):
|
||||
text = content["text"]
|
||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||
|
||||
if not stream_buffer: # The first chunk
|
||||
if request.echo:
|
||||
# Prepend prompt in response text.
|
||||
text = request.prompt + text
|
||||
else:
|
||||
# Skip prompt tokens if echo is disabled.
|
||||
n_prev_token = prompt_tokens
|
||||
|
||||
if request.logprobs is not None:
|
||||
logprobs = await make_openai_style_logprobs(
|
||||
content["meta_info"]["token_logprob"][n_prev_token:]
|
||||
)
|
||||
n_prev_token = len(content["meta_info"]["token_logprob"])
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffer = text
|
||||
stream_buffer = content["text"]
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=0,
|
||||
text=delta,
|
||||
logprobs=None,
|
||||
logprobs=logprobs,
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = CompletionStreamResponse(
|
||||
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
|
||||
# Non-streaming response.
|
||||
ret = await generate_request(adapted_request)
|
||||
|
||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||
text = ret["text"]
|
||||
token_logprob_pos = prompt_tokens
|
||||
if request.echo:
|
||||
token_logprob_pos = 0
|
||||
text = request.prompt + text
|
||||
else:
|
||||
token_logprob_pos = prompt_tokens
|
||||
|
||||
logprobs = (
|
||||
await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
|
||||
if request.logprobs is not None
|
||||
else None
|
||||
)
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=0,
|
||||
text=ret["text"],
|
||||
logprobs=None,
|
||||
text=text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=None, # TODO(comaniac): Add finish reason.
|
||||
)
|
||||
|
||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||
response = CompletionResponse(
|
||||
id=ret["meta_info"]["id"],
|
||||
model=request.model,
|
||||
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
|
||||
if not isinstance(m.content, str):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
|
||||
detail="Structured content requests not supported with "
|
||||
"HuggingFace Chat Templates. "
|
||||
"Make sure the server specifies a sglang chat template.",
|
||||
)
|
||||
prompt = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
request.messages, tokenize=False, add_generation_prompt=True
|
||||
|
||||
@@ -9,10 +9,24 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
def test_decode(url, return_logprob):
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
"return_logprob": return_logprob,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
print(response.json())
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
@@ -21,16 +35,5 @@ if __name__ == "__main__":
|
||||
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
# "return_logprob": True,
|
||||
# "logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
print(response.json())
|
||||
test_decode(url, False)
|
||||
test_decode(url, True)
|
||||
|
||||
@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
def test_decode_stream(url, return_logprob):
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 512,
|
||||
"max_new_tokens": 128,
|
||||
},
|
||||
"stream": True,
|
||||
"return_logprob": return_logprob,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
@@ -41,7 +34,29 @@ if __name__ == "__main__":
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
|
||||
if return_logprob:
|
||||
assert data["meta_info"]["prompt_logprob"] is not None
|
||||
assert data["meta_info"]["token_logprob"] is not None
|
||||
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
||||
if prev == 0: # Skip prompt logprobs
|
||||
prev = data["meta_info"]["prompt_tokens"]
|
||||
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
|
||||
print(f"{token_txt}\t{logprob}", flush=True)
|
||||
prev = len(data["meta_info"]["token_logprob"])
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
print("")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
test_decode_stream(url, False)
|
||||
test_decode_stream(url, True)
|
||||
|
||||
@@ -18,15 +18,26 @@ import argparse
|
||||
import openai
|
||||
|
||||
|
||||
def test_completion(args):
|
||||
def test_completion(args, echo, logprobs):
|
||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||
response = client.completions.create(
|
||||
model="default",
|
||||
prompt="The capital of France is",
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
text = response.choices[0].text
|
||||
print(response.choices[0].text)
|
||||
if echo:
|
||||
assert text.startswith("The capital of France is")
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
if echo:
|
||||
assert response.choices[0].logprobs.token_logprobs[0] == None
|
||||
else:
|
||||
assert response.choices[0].logprobs.token_logprobs[0] != None
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
@@ -34,7 +45,7 @@ def test_completion(args):
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_completion_stream(args):
|
||||
def test_completion_stream(args, echo, logprobs):
|
||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||
response = client.completions.create(
|
||||
model="default",
|
||||
@@ -42,9 +53,23 @@ def test_completion_stream(args):
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
stream=True,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
first = True
|
||||
for r in response:
|
||||
print(r.choices[0].text, end="", flush=True)
|
||||
if first:
|
||||
if echo:
|
||||
assert r.choices[0].text.startswith("The capital of France is")
|
||||
first = False
|
||||
if logprobs:
|
||||
print(
|
||||
f"{r.choices[0].text:12s}\t"
|
||||
f"{r.choices[0].logprobs.token_logprobs}",
|
||||
flush=True
|
||||
)
|
||||
else:
|
||||
print(r.choices[0].text, end="", flush=True)
|
||||
assert r.id
|
||||
assert r.usage.prompt_tokens > 0
|
||||
assert r.usage.completion_tokens > 0
|
||||
@@ -135,8 +160,14 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
test_completion(args)
|
||||
test_completion_stream(args)
|
||||
test_completion(args, echo=False, logprobs=False)
|
||||
test_completion(args, echo=True, logprobs=False)
|
||||
test_completion(args, echo=False, logprobs=True)
|
||||
test_completion(args, echo=True, logprobs=True)
|
||||
test_completion_stream(args, echo=False, logprobs=False)
|
||||
test_completion_stream(args, echo=True, logprobs=False)
|
||||
test_completion_stream(args, echo=False, logprobs=True)
|
||||
test_completion_stream(args, echo=True, logprobs=True)
|
||||
test_chat_completion(args)
|
||||
test_chat_completion_stream(args)
|
||||
if args.test_image:
|
||||
|
||||
Reference in New Issue
Block a user