Improve benchmark scripts & rename some scripts (#477)

This commit is contained in:
Lianmin Zheng
2024-05-26 12:51:45 -07:00
committed by GitHub
parent 2b605ab1d7
commit 55c1643627
10 changed files with 161 additions and 62 deletions

View File

@@ -183,13 +183,13 @@ class TiktokenTokenizer:
self.eos_token_id = tokenizer.eos_token
self.vocab_size = tokenizer.n_vocab
def encode(self, x):
def encode(self, x, add_special_tokens=False):
return self.tokenizer.encode(x)
def decode(self, x):
return self.tokenizer.decode(x)
def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
return self.tokenizer.decode_batch(batch)
def convert_ids_to_tokens(self, index):

View File

@@ -66,6 +66,7 @@ class Req:
self.finish_reason = None
self.hit_stop_str = None
# Prefix info
self.extend_input_len = 0
self.prefix_indices = []
self.last_node = None
@@ -76,8 +77,8 @@ class Req:
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None
self.decode_token_logprobs = []
self.prefill_top_logprobs = None
self.decode_token_logprobs = []
self.decode_top_logprobs = []
# The tokens is prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs

View File

@@ -91,26 +91,27 @@ class ModelRpcServer:
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.max_total_num_token = self.model_runner.max_total_num_token
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = max(
self.model_config.context_len,
(
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token
self.max_total_num_tokens // 6
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
),
)
self.max_running_requests = (self.max_total_num_tokens // 2
if server_args.max_running_requests is None else server_args.max_running_requests)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
set_random_seed(server_args.random_seed)
# Print info
logger.info(
f"[rank={self.tp_rank}] "
f"max_total_num_token={self.max_total_num_token}, "
f"max_prefill_num_token={self.max_prefill_num_token}, "
logger.info(f"[rank={self.tp_rank}] "
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"context_len={self.model_config.context_len}, "
)
if self.tp_rank == 0:
@@ -125,9 +126,9 @@ class ModelRpcServer:
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = Scheduler(
self.schedule_heuristic,
self.max_num_running_seq,
self.max_prefill_num_token,
self.max_total_num_token,
self.max_running_requests,
self.max_prefill_tokens,
self.max_total_num_tokens,
self.tree_cache,
)
self.req_to_token_pool = self.model_runner.req_to_token_pool
@@ -219,7 +220,7 @@ class ModelRpcServer:
# Print stats
if self.tp_rank == 0:
if self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_token - (
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
@@ -231,7 +232,7 @@ class ModelRpcServer:
logger.info(
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_token:.2f}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throuhgput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
@@ -248,10 +249,10 @@ class ModelRpcServer:
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_token:
if available_size != self.max_total_num_tokens:
warnings.warn(
"Warning: "
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
"KV cache pool leak detected!"
)
@@ -297,14 +298,14 @@ class ModelRpcServer:
req.sampling_params.max_new_tokens = min(
req.sampling_params.max_new_tokens,
self.model_config.context_len - 1 - len(req.origin_input_ids),
self.max_total_num_token - 128 - len(req.origin_input_ids),
self.max_total_num_tokens - 128 - len(req.origin_input_ids),
)
self.forward_queue.append(req)
def get_new_fill_batch(self):
if (
self.running_batch is not None
and len(self.running_batch.reqs) > self.max_num_running_seq
and len(self.running_batch.reqs) > self.max_running_requests
):
return None
@@ -360,7 +361,7 @@ class ModelRpcServer:
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and req.extend_input_len + new_batch_input_tokens
< self.max_prefill_num_token
< self.max_prefill_tokens
):
delta = self.tree_cache.inc_lock_ref(req.last_node)
available_size += delta

View File

@@ -301,19 +301,19 @@ class ModelRunner:
return max_num_token
def init_memory_pool(self, total_gpu_memory):
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if self.max_total_num_token <= 0:
if self.max_total_num_tokens <= 0:
raise RuntimeError(
"Not enought memory. " "Please try to increase --mem-fraction-static."
)
self.req_to_token_pool = ReqToTokenPool(
int(self.max_total_num_token / self.model_config.context_len * 256),
int(self.max_total_num_tokens / self.model_config.context_len * 256),
self.model_config.context_len + 8,
)
self.token_to_kv_pool = TokenToKVPool(
self.max_total_num_token,
self.max_total_num_tokens,
dtype=torch.float16,
head_num=self.model_config.num_key_value_heads // self.tp_size,
head_dim=self.model_config.head_dim,

View File

@@ -6,15 +6,15 @@ class Scheduler:
def __init__(
self,
schedule_heuristic,
max_running_seq,
max_prefill_num_token,
max_total_num_token,
max_running_seqs,
max_prefill_num_tokens,
max_total_num_tokens,
tree_cache,
):
self.schedule_heuristic = schedule_heuristic
self.max_running_seq = max_running_seq
self.max_prefill_num_token = max_prefill_num_token
self.max_total_num_token = max_total_num_token
self.max_running_seqs = max_running_seqs
self.max_prefill_num_tokens = max_prefill_num_tokens
self.max_total_num_tokens = max_total_num_tokens
self.tree_cache = tree_cache
def get_priority_queue(self, forward_queue):

View File

@@ -24,7 +24,8 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_prefill_num_token: Optional[int] = None
max_prefill_tokens: Optional[int] = None
max_running_requests: Optional[int] = None
schedule_heuristic: str = "lpm"
schedule_conservativeness: float = 1.0
@@ -149,11 +150,17 @@ class ServerArgs:
help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
)
parser.add_argument(
"--max-prefill-num-token",
"--max-prefill-tokens",
type=int,
default=ServerArgs.max_prefill_num_token,
default=ServerArgs.max_prefill_tokens,
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
)
parser.add_argument(
"--max-running-requests",
type=int,
default=ServerArgs.max_running_requests,
help="The maximum number of running requests.",
)
parser.add_argument(
"--schedule-heuristic",
type=str,

View File

@@ -88,6 +88,28 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
return pred
def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None):
import grpc
from xlm.proto import sampler_pb2, sampler_pb2_grpc
sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
sample_request = sampler_pb2.SampleTextRequest(
prompt=prompt,
settings=sampler_pb2.SampleSettings(
max_len=max_tokens,
rng_seed=0,
temperature=max(temperature, 1e-7),
nucleus_p=1,
stop_strings=[stop],
),
)
stream = sampler.SampleText(sample_request)
response = "".join([x.text for x in stream])
return response
def call_generate_guidance(
prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
@@ -228,6 +250,7 @@ def add_common_other_args_and_parse(parser):
"vllm",
"outlines",
"lightllm",
"xinfer",
"guidance",
"lmql",
"srt-raw",
@@ -248,6 +271,7 @@ def add_common_other_args_and_parse(parser):
"lightllm": 22000,
"lmql": 23000,
"srt-raw": 30000,
"xinfer": 9988,
}
args.port = default_port.get(args.backend, None)
return args
@@ -283,6 +307,8 @@ def _get_call_generate(args):
return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "srt-raw":
return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
elif args.backend == "xinfer":
return partial(call_generate_xinfer, url=f"{args.host}:{args.port}")
elif args.backend == "outlines":
return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
elif args.backend == "guidance":