Simplify mem state (#623)

This commit is contained in:
Mingyi
2024-07-15 02:01:09 -07:00
committed by GitHub
parent bae9541e4c
commit 5ac8b80677
7 changed files with 61 additions and 66 deletions

View File

@@ -98,7 +98,7 @@ class ModelTpServer:
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = (
8192
16384
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
)
@@ -222,30 +222,29 @@ class ModelTpServer:
# Run decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(10):
for _ in range(global_config.num_continue_decode_steps):
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)
# Print stats
if self.tp_rank == 0:
if self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
if self.running_batch.is_empty():
self.running_batch = None
@@ -344,7 +343,7 @@ class ModelTpServer:
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
(r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
for r in self.running_batch.reqs
]
)
@@ -358,7 +357,7 @@ class ModelTpServer:
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.max_new_tokens() > 0:
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
@@ -366,7 +365,7 @@ class ModelTpServer:
req.image_offset += 1
if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
< available_size
and (
req.extend_input_len + new_batch_input_tokens
@@ -378,7 +377,7 @@ class ModelTpServer:
available_size += delta
if not (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
< available_size
):
# Undo locking
@@ -389,7 +388,7 @@ class ModelTpServer:
# Add this request to the running batch
can_run_list.append(req)
new_batch_total_tokens += (
req.extend_input_len + req.max_new_tokens()
req.extend_input_len + req.sampling_params.max_new_tokens
)
new_batch_input_tokens += req.extend_input_len
else:
@@ -403,9 +402,6 @@ class ModelTpServer:
# Print stats
if self.tp_rank == 0:
running_req = (
0 if self.running_batch is None else len(self.running_batch.reqs)
)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (
hit_tokens + new_batch_input_tokens
@@ -420,7 +416,7 @@ class ModelTpServer:
f"#new-token: {new_batch_input_tokens}, "
f"#cached-token: {hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_req}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
)
# logger.debug(