Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
|
||||
return logprobs
|
||||
|
||||
|
||||
def get_token_ids_logprobs(logits, token_ids):
|
||||
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
||||
del logits
|
||||
logprobs = logprobs[..., token_ids]
|
||||
return logprobs
|
||||
|
||||
|
||||
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers.util import is_sentence_transformer_model
|
||||
@@ -84,8 +91,13 @@ class ModelOutput:
|
||||
output_ids: List[int] = None
|
||||
top_input_logprobs: List[torch.Tensor] = None
|
||||
top_output_logprobs: List[torch.Tensor] = None
|
||||
top_output_logprob_idx: List[List[int]] = None
|
||||
embed_logits: List[torch.Tensor] = None
|
||||
scores: List[float] = None
|
||||
input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
|
||||
output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
|
||||
token_ids_input_logprobs: List[torch.Tensor] = None
|
||||
token_ids_output_logprobs: List[torch.Tensor] = None
|
||||
|
||||
|
||||
class HFRunner:
|
||||
@@ -157,7 +169,7 @@ class HFRunner:
|
||||
|
||||
# Run forward
|
||||
while True:
|
||||
prompts, max_new_tokens, lora_paths = in_queue.get()
|
||||
prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
|
||||
if lora_paths is not None:
|
||||
assert len(prompts) == len(lora_paths)
|
||||
|
||||
@@ -165,16 +177,16 @@ class HFRunner:
|
||||
if self.model_type == "generation":
|
||||
out_queue.put(
|
||||
self.forward_generation_raw(
|
||||
base_model=self.base_model,
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
base_model=self.base_model,
|
||||
tokenizer=self.tokenizer,
|
||||
lora_paths=lora_paths,
|
||||
torch_dtype=torch_dtype,
|
||||
output_str_only=self.output_str_only,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
)
|
||||
)
|
||||
|
||||
elif self.model_type == "embedding":
|
||||
assert not self.output_str_only
|
||||
logits = self.model.encode(prompts).tolist()
|
||||
@@ -199,10 +211,11 @@ class HFRunner:
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=8,
|
||||
lora_paths=None,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
token_ids_logprob: Optional[int] = None,
|
||||
):
|
||||
self.in_queue.put((prompts, max_new_tokens, lora_paths))
|
||||
self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
|
||||
return self.out_queue.get()
|
||||
|
||||
def terminate(self):
|
||||
@@ -218,17 +231,24 @@ class HFRunner:
|
||||
|
||||
@staticmethod
|
||||
def forward_generation_raw(
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens,
|
||||
base_model,
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens: int,
|
||||
tokenizer,
|
||||
lora_paths,
|
||||
torch_dtype: torch.dtype,
|
||||
output_str_only: bool,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
output_str_only: bool = False,
|
||||
token_ids_logprob: Optional[int] = None,
|
||||
) -> ModelOutput:
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs = []
|
||||
token_ids_output_logprobs = []
|
||||
else:
|
||||
token_ids_input_logprobs = token_ids_output_logprobs = None
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
|
||||
@@ -275,18 +295,33 @@ class HFRunner:
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_output_logprobs.append(
|
||||
[
|
||||
get_token_ids_logprobs(
|
||||
logits[0], token_ids_logprob
|
||||
).tolist()
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
del outputs
|
||||
|
||||
input_logits = model.forward(input_ids).logits[0]
|
||||
top_input_logprobs.append(
|
||||
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
||||
)
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs.append(
|
||||
get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
|
||||
)
|
||||
del input_logits
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
token_ids_input_logprobs=token_ids_input_logprobs,
|
||||
token_ids_output_logprobs=token_ids_output_logprobs,
|
||||
)
|
||||
|
||||
|
||||
@@ -303,11 +338,31 @@ class SRTRunner:
|
||||
lora_backend: str = "triton",
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
chunked_prefill_size: Optional[int] = None,
|
||||
dp_size: int = 1,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
enable_ep_moe: bool = False,
|
||||
mem_fraction_static: float = 0.65,
|
||||
trust_remote_code: bool = False,
|
||||
speculative_draft_model_path: Optional[str] = None,
|
||||
speculative_algorithm: Optional[str] = None,
|
||||
speculative_num_steps: Optional[int] = None,
|
||||
speculative_eagle_topk: Optional[int] = None,
|
||||
speculative_num_draft_tokens: Optional[int] = None,
|
||||
disable_overlap_schedule: bool = False,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
enable_dp_attention = dp_size > 1
|
||||
|
||||
spec_kwargs = {}
|
||||
if speculative_draft_model_path:
|
||||
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
|
||||
spec_kwargs["speculative_algorithm"] = speculative_algorithm
|
||||
spec_kwargs["speculative_num_steps"] = speculative_num_steps
|
||||
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
|
||||
spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
|
||||
|
||||
self.engine = Engine(
|
||||
model_path=model_path,
|
||||
tp_size=tp_size,
|
||||
@@ -321,21 +376,41 @@ class SRTRunner:
|
||||
lora_backend=lora_backend,
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
chunked_prefill_size=chunked_prefill_size,
|
||||
enable_dp_attention=enable_dp_attention,
|
||||
dp_size=dp_size,
|
||||
tokenizer_path=tokenizer_path,
|
||||
enable_ep_moe=enable_ep_moe,
|
||||
disable_overlap_schedule=disable_overlap_schedule,
|
||||
cuda_graph_max_bs=4,
|
||||
**spec_kwargs,
|
||||
)
|
||||
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
|
||||
|
||||
if tokenizer_path is None:
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_path, trust_remote_code=trust_remote_code
|
||||
)
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=8,
|
||||
lora_paths=None,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
logprob_start_len: int = 0,
|
||||
top_k: Optional[int] = None,
|
||||
token_ids_logprob: Optional[List[int]] = None,
|
||||
):
|
||||
if self.is_generation:
|
||||
return self.forward_generation_raw(
|
||||
engine=self.engine,
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_k=top_k,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -358,10 +433,10 @@ class SRTRunner:
|
||||
"""
|
||||
if self.is_generation:
|
||||
return self.batch_forward_generation_raw(
|
||||
engine=self.engine,
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -381,24 +456,43 @@ class SRTRunner:
|
||||
|
||||
@staticmethod
|
||||
def forward_generation_raw(
|
||||
engine: Engine,
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens,
|
||||
lora_paths,
|
||||
engine,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
logprob_start_len: int = 0,
|
||||
top_k: Optional[int] = None,
|
||||
token_ids_logprob: Optional[List[int]] = None,
|
||||
):
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
output_ids = []
|
||||
# Input logprobs. Note that the last item in input logprob is equivalent to
|
||||
# the first item in the output logprob.
|
||||
top_input_logprobs = []
|
||||
input_token_logprobs_lst = []
|
||||
top_output_logprobs = []
|
||||
output_token_logprobs_lst = []
|
||||
top_output_logprob_idx = []
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs = []
|
||||
token_ids_output_logprobs = []
|
||||
else:
|
||||
token_ids_input_logprobs = token_ids_output_logprobs = None
|
||||
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
if top_k:
|
||||
sampling_params["top_k"] = top_k
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
response = engine.generate(
|
||||
prompt,
|
||||
lora_path=lora_paths[i] if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
logprob_start_len=0,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
)
|
||||
text = response["text"]
|
||||
|
||||
@@ -408,12 +502,36 @@ class SRTRunner:
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
# output_ids.append(response["output_ids"])
|
||||
|
||||
input_token_logprobs = response["meta_info"]["input_token_logprobs"]
|
||||
output_token_logprobs = response["meta_info"]["output_token_logprobs"]
|
||||
# print(i, input_token_logprobs)
|
||||
# print(i, output_token_logprobs)
|
||||
logprobs = response["meta_info"]["input_top_logprobs"]
|
||||
if token_ids_logprob is not None:
|
||||
input_token_ids_logprobs = response["meta_info"][
|
||||
"input_token_ids_logprobs"
|
||||
][1:]
|
||||
else:
|
||||
input_token_ids_logprobs = None
|
||||
|
||||
num_prompt_tokens = response["meta_info"]["prompt_tokens"]
|
||||
assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
|
||||
assert len(logprobs) == num_prompt_tokens - logprob_start_len
|
||||
|
||||
# The first token logprob has no meaning in sglang.
|
||||
input_token_logprobs = input_token_logprobs[1:]
|
||||
logprobs = logprobs[1:]
|
||||
assert len(input_token_logprobs) == len(logprobs)
|
||||
|
||||
input_token_logprobs_lst.append(
|
||||
input_token_logprobs + [output_token_logprobs[0]]
|
||||
)
|
||||
output_token_logprobs_lst.append(output_token_logprobs)
|
||||
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||
]
|
||||
[[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
@@ -429,11 +547,41 @@ class SRTRunner:
|
||||
for x in response["meta_info"]["output_top_logprobs"]
|
||||
]
|
||||
)
|
||||
top_output_logprob_idx.append(
|
||||
[
|
||||
[tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["output_top_logprobs"]
|
||||
]
|
||||
)
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs.append(
|
||||
[[tup[0] for tup in x] for x in input_token_ids_logprobs]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
for tup in response["meta_info"][
|
||||
"output_token_ids_logprobs"
|
||||
][0]
|
||||
]
|
||||
]
|
||||
)
|
||||
token_ids_output_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x]
|
||||
for x in response["meta_info"]["output_token_ids_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
output_ids=output_ids,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
input_token_logprobs_lst=input_token_logprobs_lst,
|
||||
output_token_logprobs_lst=output_token_logprobs_lst,
|
||||
top_output_logprob_idx=top_output_logprob_idx,
|
||||
token_ids_input_logprobs=token_ids_input_logprobs,
|
||||
token_ids_output_logprobs=token_ids_output_logprobs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
88
python/sglang/test/send_one.py
Normal file
88
python/sglang/test/send_one.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Run one test prompt.
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.test.send_one
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def send_one_prompt(args):
|
||||
if args.image:
|
||||
args.prompt = (
|
||||
"Human: Describe this image in a very short sentence.\n\nAssistant:"
|
||||
)
|
||||
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
|
||||
else:
|
||||
image_data = None
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:30000/generate",
|
||||
json={
|
||||
"text": args.prompt,
|
||||
"image_data": image_data,
|
||||
"sampling_params": {
|
||||
"temperature": args.temperature,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"frequency_penalty": args.frequency_penalty,
|
||||
"presence_penalty": args.presence_penalty,
|
||||
},
|
||||
"return_logprob": args.return_logprob,
|
||||
"stream": args.stream,
|
||||
},
|
||||
stream=args.stream,
|
||||
)
|
||||
|
||||
if args.stream:
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
ret = json.loads(chunk[5:].strip("\n"))
|
||||
else:
|
||||
ret = response.json()
|
||||
|
||||
latency = ret["meta_info"]["e2e_latency"]
|
||||
|
||||
if "spec_verify_ct" in ret["meta_info"]:
|
||||
acc_length = (
|
||||
ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
|
||||
)
|
||||
else:
|
||||
acc_length = 1.0
|
||||
|
||||
speed = ret["meta_info"]["completion_tokens"] / latency
|
||||
|
||||
print(ret["text"])
|
||||
print()
|
||||
print(f"{acc_length=:.2f}")
|
||||
print(f"{speed=:.2f} token/s")
|
||||
|
||||
return acc_length, speed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=512)
|
||||
parser.add_argument("--frequency-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--presence-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--return-logprob", action="store_true")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("--stream", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
send_one_prompt(args)
|
||||
@@ -8,10 +8,11 @@ import random
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -408,26 +409,49 @@ def popen_launch_server(
|
||||
other_args: list[str] = (),
|
||||
env: Optional[dict] = None,
|
||||
return_stdout_stderr: Optional[tuple] = None,
|
||||
pd_seperated: bool = False,
|
||||
):
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
|
||||
if pd_seperated:
|
||||
command = "sglang.launch_pd_server"
|
||||
else:
|
||||
command = "sglang.launch_server"
|
||||
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
command,
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
port,
|
||||
*other_args,
|
||||
*[str(x) for x in other_args],
|
||||
]
|
||||
|
||||
if pd_seperated:
|
||||
command.extend(
|
||||
[
|
||||
"--lb-host",
|
||||
host,
|
||||
"--lb-port",
|
||||
port,
|
||||
]
|
||||
)
|
||||
else:
|
||||
command.extend(
|
||||
[
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
port,
|
||||
]
|
||||
)
|
||||
|
||||
if api_key:
|
||||
command += ["--api-key", api_key]
|
||||
|
||||
print(f"command={' '.join(command)}")
|
||||
|
||||
if return_stdout_stderr:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
@@ -456,6 +480,8 @@ def popen_launch_server(
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(10)
|
||||
|
||||
kill_process_tree(process.pid)
|
||||
raise TimeoutError("Server failed to start within the timeout period.")
|
||||
|
||||
|
||||
@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||
success = True
|
||||
|
||||
for filename in files:
|
||||
global process
|
||||
process = None
|
||||
|
||||
def run_one_file(filename):
|
||||
nonlocal process
|
||||
|
||||
filename = os.path.join(os.getcwd(), filename)
|
||||
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
|
||||
process = subprocess.Popen(
|
||||
@@ -534,11 +562,14 @@ def get_benchmark_args(
|
||||
dataset_path="",
|
||||
tokenizer="",
|
||||
num_prompts=500,
|
||||
sharegpt_output_len=None,
|
||||
random_input_len=4096,
|
||||
random_output_len=2048,
|
||||
sharegpt_context_len=None,
|
||||
request_rate=float("inf"),
|
||||
disable_stream=False,
|
||||
disable_ignore_eos=False,
|
||||
pd_seperated: bool = False,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
backend="sglang",
|
||||
@@ -550,8 +581,8 @@ def get_benchmark_args(
|
||||
model=None,
|
||||
tokenizer=tokenizer,
|
||||
num_prompts=num_prompts,
|
||||
sharegpt_output_len=None,
|
||||
sharegpt_context_len=None,
|
||||
sharegpt_output_len=sharegpt_output_len,
|
||||
sharegpt_context_len=sharegpt_context_len,
|
||||
random_input_len=random_input_len,
|
||||
random_output_len=random_output_len,
|
||||
random_range_ratio=0.0,
|
||||
@@ -567,6 +598,8 @@ def get_benchmark_args(
|
||||
apply_chat_template=False,
|
||||
profile=None,
|
||||
lora_name=None,
|
||||
prompt_suffix="",
|
||||
pd_seperated=pd_seperated,
|
||||
)
|
||||
|
||||
|
||||
@@ -580,6 +613,7 @@ def run_bench_serving(
|
||||
tokenizer=None,
|
||||
random_input_len=4096,
|
||||
random_output_len=2048,
|
||||
sharegpt_context_len=None,
|
||||
disable_stream=False,
|
||||
disable_ignore_eos=False,
|
||||
need_warmup=False,
|
||||
@@ -602,6 +636,7 @@ def run_bench_serving(
|
||||
num_prompts=num_prompts,
|
||||
random_input_len=random_input_len,
|
||||
random_output_len=random_output_len,
|
||||
sharegpt_context_len=sharegpt_context_len,
|
||||
request_rate=request_rate,
|
||||
disable_stream=disable_stream,
|
||||
disable_ignore_eos=disable_ignore_eos,
|
||||
@@ -626,6 +661,7 @@ def run_bench_serving_multi(
|
||||
other_server_args,
|
||||
benchmark_args,
|
||||
need_warmup=False,
|
||||
pd_seperated=False,
|
||||
):
|
||||
# Launch the server
|
||||
process = popen_launch_server(
|
||||
@@ -633,6 +669,7 @@ def run_bench_serving_multi(
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_server_args,
|
||||
pd_seperated=pd_seperated,
|
||||
)
|
||||
|
||||
# run benchmark for all
|
||||
@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
|
||||
"128",
|
||||
"--output",
|
||||
"8",
|
||||
*other_args,
|
||||
*[str(x) for x in other_args],
|
||||
]
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
|
||||
stdout = open(STDOUT_FILENAME, "w")
|
||||
stderr = open(STDERR_FILENAME, "w")
|
||||
process = subprocess.Popen(
|
||||
command, stdout=stdout, stderr=stderr, env=env, text=True
|
||||
command, stdout=stdout, stderr=stdout, env=env, text=True
|
||||
)
|
||||
|
||||
# Launch a thread to stream the output
|
||||
@@ -914,3 +951,78 @@ def run_mulit_request_test(
|
||||
def write_github_step_summary(content):
|
||||
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def run_logprob_check(self: unittest.TestCase, arg: Tuple):
|
||||
(
|
||||
input_len,
|
||||
output_len,
|
||||
temperature,
|
||||
logprob_start_len,
|
||||
return_logprob,
|
||||
top_logprobs_num,
|
||||
) = arg
|
||||
input_ids = list(range(input_len))
|
||||
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
"return_logprob": return_logprob,
|
||||
"logprob_start_len": logprob_start_len,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
},
|
||||
)
|
||||
response_json = response.json()
|
||||
|
||||
res = response_json
|
||||
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
|
||||
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
|
||||
|
||||
# Test the number of tokens are correct
|
||||
if return_logprob:
|
||||
self.assertEqual(
|
||||
len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
|
||||
res["meta_info"]["prompt_tokens"],
|
||||
)
|
||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
|
||||
|
||||
if top_logprobs_num:
|
||||
self.assertEqual(
|
||||
len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
|
||||
res["meta_info"]["prompt_tokens"],
|
||||
)
|
||||
self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)
|
||||
|
||||
for i in range(output_len):
|
||||
self.assertEqual(
|
||||
len(res["meta_info"]["output_top_logprobs"][i]),
|
||||
top_logprobs_num,
|
||||
)
|
||||
|
||||
# Test the top-1 tokens are the same as output tokens if temperature == 0
|
||||
if temperature == 0:
|
||||
rank = 0
|
||||
while rank < len(res["meta_info"]["output_top_logprobs"][i]):
|
||||
try:
|
||||
self.assertListEqual(
|
||||
res["meta_info"]["output_token_logprobs"][i],
|
||||
res["meta_info"]["output_top_logprobs"][i][rank],
|
||||
)
|
||||
break
|
||||
except AssertionError:
|
||||
# There's a tie. Allow the second item in this case.
|
||||
if (
|
||||
res["meta_info"]["output_top_logprobs"][i][rank][0]
|
||||
== res["meta_info"]["output_top_logprobs"][i][rank + 1][
|
||||
0
|
||||
]
|
||||
):
|
||||
rank += 1
|
||||
else:
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user