Simplify mem state (#623)
@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            8192
+            16384
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
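
Note: this doubles the built-in prefill-token budget from 8192 to 16384; an explicitly configured server_args.max_prefill_tokens still wins. A minimal sketch of the fallback (the helper and constant names here are illustrative, not sglang's):

    DEFAULT_MAX_PREFILL_TOKENS = 16384  # new default; previously 8192

    def resolve_max_prefill_tokens(cli_value):
        # None means no value was configured, so use the built-in default.
        return DEFAULT_MAX_PREFILL_TOKENS if cli_value is None else cli_value

    assert resolve_max_prefill_tokens(None) == 16384
    assert resolve_max_prefill_tokens(4096) == 4096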
@@ -222,30 +222,29 @@ class ModelTpServer:
         # Run decode batch
         if self.running_batch is not None:
             # Run a few decode batches continuously for reducing overhead
-            for _ in range(10):
+            for _ in range(global_config.num_continue_decode_steps):
                 self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)
 
                 # Print stats
-                if self.tp_rank == 0:
-                    if self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_tokens - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        throughput = self.num_generated_tokens / (
-                            time.time() - self.last_stats_tic
-                        )
-                        self.num_generated_tokens = 0
-                        self.last_stats_tic = time.time()
-                        logger.info(
-                            f"[gpu_id={self.gpu_id}] Decode batch. "
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                            f"gen throughput (token/s): {throughput:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
+                if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                    num_used = self.max_total_num_tokens - (
+                        self.token_to_kv_pool.available_size()
+                        + self.tree_cache.evictable_size()
+                    )
+                    throughput = self.num_generated_tokens / (
+                        time.time() - self.last_stats_tic
+                    )
+                    self.num_generated_tokens = 0
+                    self.last_stats_tic = time.time()
+                    logger.info(
+                        f"[gpu_id={self.gpu_id}] Decode batch. "
+                        f"#running-req: {len(self.running_batch.reqs)}, "
+                        f"#token: {num_used}, "
+                        f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                        f"gen throughput (token/s): {throughput:.2f}, "
+                        f"#queue-req: {len(self.forward_queue)}"
+                    )
 
             if self.running_batch.is_empty():
                 self.running_batch = None
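
Two things happen in this hunk. Hoisting the literal 10 into global_config.num_continue_decode_steps names the "run several decode steps per scheduler tick to amortize overhead" knob in one place, and collapsing the nested ifs is behavior-preserving: both versions log only when tp_rank == 0 and only every 40th decode step. The "mem state" in the log folds two pools into one number: tokens still free in the KV pool plus tokens held by cached prefixes that the radix tree could evict on demand. A self-contained sketch of that accounting and of the throughput window (class and parameter names are illustrative, not sglang's API):

    import time

    class DecodeStats:
        """Token-usage and throughput bookkeeping, mirroring the diff's math."""

        def __init__(self, max_total_num_tokens: int):
            self.max_total_num_tokens = max_total_num_tokens
            self.num_generated_tokens = 0
            self.last_stats_tic = time.time()

        def snapshot(self, kv_available: int, cache_evictable: int):
            # In use = capacity minus (free slots + slots reclaimable on demand).
            num_used = self.max_total_num_tokens - (kv_available + cache_evictable)
            throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
            # Reset the window so the next snapshot measures a fresh interval.
            self.num_generated_tokens = 0
            self.last_stats_tic = time.time()
            return num_used, num_used / self.max_total_num_tokens, throughput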
@@ -344,7 +343,7 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
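
Here the scheduler pre-reserves, for each running request, the tokens it may still generate, discounted by new_token_ratio because most requests stop well before exhausting max_new_tokens. A standalone sketch of the reservation (RunningReq is a stand-in, not sglang's Req class):

    from dataclasses import dataclass, field

    @dataclass
    class RunningReq:
        max_new_tokens: int                              # decode budget
        output_ids: list = field(default_factory=list)   # tokens generated so far

    def reserve_decode_budget(available_size: float, running: list,
                              new_token_ratio: float) -> float:
        # Subtract the discounted worst-case remaining decode tokens, as in the diff.
        for r in running:
            available_size -= (r.max_new_tokens - len(r.output_ids)) * new_token_ratio
        return available_size

    # 1000 free slots; one request with 128 of its 256 tokens still to come,
    # discounted at 0.5 -> 1000 - 64 = 936 slots left for new prefills.
    print(reserve_decode_budget(1000, [RunningReq(256, [0] * 128)], 0.5))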
@@ -358,7 +357,7 @@ class ModelTpServer:
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
-            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                 # Need at least one token to compute logits
                 req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
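
The guarded case is a prompt fully covered by the prefix cache: extend_input_len would be 0 and there would be no input position from which to compute next-token logits, so the last cached token is released and recomputed instead. A minimal sketch of the adjustment (free-standing function, hypothetical names):

    def ensure_one_extend_token(prefix_indices: list, input_len: int):
        # If the cache covers the whole prompt, give back the last cached
        # token so the model has one position to produce logits from.
        extend_input_len = input_len - len(prefix_indices)
        if extend_input_len == 0:
            extend_input_len = 1
            prefix_indices = prefix_indices[:-1]
        return prefix_indices, extend_input_len

    # A 4-token prompt fully matched in the cache: keep 3, recompute 1.
    print(ensure_one_extend_token([10, 11, 12, 13], 4))  # ([10, 11, 12], 1)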
@@ -366,7 +365,7 @@ class ModelTpServer:
                     req.image_offset += 1
 
             if (
-                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -378,7 +377,7 @@ class ModelTpServer:
                 available_size += delta
 
                 if not (
-                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                    req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
@@ -389,7 +388,7 @@ class ModelTpServer:
                     # Add this request to the running batch
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.extend_input_len + req.max_new_tokens()
+                        req.extend_input_len + req.sampling_params.max_new_tokens
                     )
                     new_batch_input_tokens += req.extend_input_len
                 else:
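
Across these three hunks the admission test is the same inequality: the request's prefill length plus its full decode budget, on top of tokens already committed to the new batch, must fit below available_size. It is re-checked after inc_lock_ref because locking the matched prefix nodes changes what is evictable, and the lock is undone if the re-check fails. A condensed sketch of the core test (not the scheduler itself):

    def fits_token_budget(extend_len: int, max_new_tokens: int,
                          batch_total_tokens: int, available_size: float) -> bool:
        # Prefill tokens + worst-case decode tokens must fit in what remains.
        return extend_len + max_new_tokens + batch_total_tokens < available_size

    # 512 prompt tokens plus up to 256 generated, on top of 3000 already
    # committed, need more than 3768 free slots.
    print(fits_token_budget(512, 256, 3000, 4000))  # True
    print(fits_token_budget(512, 256, 3000, 3700))  # False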
@@ -403,9 +402,6 @@ class ModelTpServer:
 
         # Print stats
         if self.tp_rank == 0:
-            running_req = (
-                0 if self.running_batch is None else len(self.running_batch.reqs)
-            )
             hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
             self.tree_cache_metrics["total"] += (
                 hit_tokens + new_batch_input_tokens
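
The local running_req computation is dropped; an equivalent count, logged as running_bs in the next hunk, is presumably computed elsewhere in this commit. tree_cache_metrics["total"] accumulates every prompt token scheduled (cached hits plus freshly computed input), which is the denominator of the hit rate logged below; the matching "hit" accumulator is not visible in this hunk, so it is an assumption in this sketch:

    # "hit" key is an assumption; only "total" appears in the hunk above.
    tree_cache_metrics = {"total": 0.0, "hit": 0.0}

    def record_prefill(hit_tokens: int, new_input_tokens: int) -> float:
        # Cached tokens count toward both numerator and denominator.
        tree_cache_metrics["total"] += hit_tokens + new_input_tokens
        tree_cache_metrics["hit"] += hit_tokens
        return tree_cache_metrics["hit"] / tree_cache_metrics["total"]

    print(f"cache hit rate: {100.0 * record_prefill(300, 700):.2f}%")  # 30.00%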
@@ -420,7 +416,7 @@ class ModelTpServer:
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                f"#running-req: {running_req}, "
+                f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
             # logger.debug(