Simplify mem state (#623)
This commit is contained in:
@@ -297,7 +297,8 @@ def main(args: argparse.Namespace):
|
|||||||
benchmark_time = benchmark_end_time - benchmark_start_time
|
benchmark_time = benchmark_end_time - benchmark_start_time
|
||||||
|
|
||||||
# Compute the statistics.
|
# Compute the statistics.
|
||||||
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
latencies = [latency for _, _, latency in REQUEST_LATENCY]
|
||||||
|
avg_latency = np.mean(latencies)
|
||||||
avg_per_token_latency = np.mean(
|
avg_per_token_latency = np.mean(
|
||||||
[
|
[
|
||||||
latency / (prompt_len + output_len)
|
latency / (prompt_len + output_len)
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ class GlobalConfig:
|
|||||||
# This can improve the speed for large batch sizes during prefill.
|
# This can improve the speed for large batch sizes during prefill.
|
||||||
self.layer_sync_threshold = 8192
|
self.layer_sync_threshold = 8192
|
||||||
|
|
||||||
# Runtime constants: Flashinfer
|
# Runtime constants: others
|
||||||
|
self.num_continue_decode_steps = 10
|
||||||
self.flashinfer_workspace_size = 192 * 1024 * 1024
|
self.flashinfer_workspace_size = 192 * 1024 * 1024
|
||||||
|
|
||||||
# Output tokenization configs
|
# Output tokenization configs
|
||||||
|
|||||||
@@ -174,9 +174,6 @@ class Req:
|
|||||||
|
|
||||||
return False, ""
|
return False, ""
|
||||||
|
|
||||||
def max_new_tokens(self):
|
|
||||||
return self.sampling_params.max_new_tokens
|
|
||||||
|
|
||||||
def check_finished(self):
|
def check_finished(self):
|
||||||
if self.finished():
|
if self.finished():
|
||||||
return
|
return
|
||||||
@@ -352,7 +349,7 @@ class Batch:
|
|||||||
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
|
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||||
|
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
@@ -422,7 +419,7 @@ class Batch:
|
|||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
|
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
||||||
|
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
@@ -453,7 +450,7 @@ class Batch:
|
|||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
req_pool_indices_cpu[idx]
|
req_pool_indices_cpu[idx]
|
||||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||||
self.token_to_kv_pool.dec_refs(token_indices)
|
self.token_to_kv_pool.free(token_indices)
|
||||||
|
|
||||||
# release the last node
|
# release the last node
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
self.tree_cache.dec_lock_ref(req.last_node)
|
||||||
@@ -596,8 +593,7 @@ class Batch:
|
|||||||
"logit_bias",
|
"logit_bias",
|
||||||
]:
|
]:
|
||||||
self_val = getattr(self, item, None)
|
self_val = getattr(self, item, None)
|
||||||
# logit_bias can be None
|
if self_val is not None: # logit_bias can be None
|
||||||
if self_val is not None:
|
|
||||||
setattr(self, item, self_val[new_indices])
|
setattr(self, item, self_val[new_indices])
|
||||||
|
|
||||||
def merge(self, other: "Batch"):
|
def merge(self, other: "Batch"):
|
||||||
|
|||||||
@@ -82,12 +82,12 @@ class RadixCache:
|
|||||||
|
|
||||||
if self.disable:
|
if self.disable:
|
||||||
if del_in_memory_pool:
|
if del_in_memory_pool:
|
||||||
self.token_to_kv_pool.dec_refs(indices)
|
self.token_to_kv_pool.free(indices)
|
||||||
else:
|
else:
|
||||||
return torch.tensor([], dtype=torch.int64), self.root_node
|
return torch.tensor([], dtype=torch.int64), self.root_node
|
||||||
|
|
||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
|
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
|
||||||
|
|
||||||
if del_in_memory_pool:
|
if del_in_memory_pool:
|
||||||
self.req_to_token_pool.free(req_pool_idx)
|
self.req_to_token_pool.free(req_pool_idx)
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ class ScheduleHeuristic:
|
|||||||
max_total_num_tokens,
|
max_total_num_tokens,
|
||||||
tree_cache,
|
tree_cache,
|
||||||
):
|
):
|
||||||
|
if tree_cache.disable and schedule_heuristic == "lpm":
|
||||||
|
# LMP is not meaningless when tree cache is disabled.
|
||||||
|
schedule_heuristic = "fcfs"
|
||||||
|
|
||||||
self.schedule_heuristic = schedule_heuristic
|
self.schedule_heuristic = schedule_heuristic
|
||||||
self.max_running_seqs = max_running_seqs
|
self.max_running_seqs = max_running_seqs
|
||||||
self.max_prefill_num_tokens = max_prefill_num_tokens
|
self.max_prefill_num_tokens = max_prefill_num_tokens
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
||||||
self.max_prefill_tokens = (
|
self.max_prefill_tokens = (
|
||||||
8192
|
16384
|
||||||
if server_args.max_prefill_tokens is None
|
if server_args.max_prefill_tokens is None
|
||||||
else server_args.max_prefill_tokens
|
else server_args.max_prefill_tokens
|
||||||
)
|
)
|
||||||
@@ -222,30 +222,29 @@ class ModelTpServer:
|
|||||||
# Run decode batch
|
# Run decode batch
|
||||||
if self.running_batch is not None:
|
if self.running_batch is not None:
|
||||||
# Run a few decode batches continuously for reducing overhead
|
# Run a few decode batches continuously for reducing overhead
|
||||||
for _ in range(10):
|
for _ in range(global_config.num_continue_decode_steps):
|
||||||
self.num_generated_tokens += len(self.running_batch.reqs)
|
self.num_generated_tokens += len(self.running_batch.reqs)
|
||||||
self.forward_decode_batch(self.running_batch)
|
self.forward_decode_batch(self.running_batch)
|
||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||||
if self.decode_forward_ct % 40 == 0:
|
num_used = self.max_total_num_tokens - (
|
||||||
num_used = self.max_total_num_tokens - (
|
self.token_to_kv_pool.available_size()
|
||||||
self.token_to_kv_pool.available_size()
|
+ self.tree_cache.evictable_size()
|
||||||
+ self.tree_cache.evictable_size()
|
)
|
||||||
)
|
throughput = self.num_generated_tokens / (
|
||||||
throughput = self.num_generated_tokens / (
|
time.time() - self.last_stats_tic
|
||||||
time.time() - self.last_stats_tic
|
)
|
||||||
)
|
self.num_generated_tokens = 0
|
||||||
self.num_generated_tokens = 0
|
self.last_stats_tic = time.time()
|
||||||
self.last_stats_tic = time.time()
|
logger.info(
|
||||||
logger.info(
|
f"[gpu_id={self.gpu_id}] Decode batch. "
|
||||||
f"[gpu_id={self.gpu_id}] Decode batch. "
|
f"#running-req: {len(self.running_batch.reqs)}, "
|
||||||
f"#running-req: {len(self.running_batch.reqs)}, "
|
f"#token: {num_used}, "
|
||||||
f"#token: {num_used}, "
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
f"gen throughput (token/s): {throughput:.2f}, "
|
||||||
f"gen throughput (token/s): {throughput:.2f}, "
|
f"#queue-req: {len(self.forward_queue)}"
|
||||||
f"#queue-req: {len(self.forward_queue)}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if self.running_batch.is_empty():
|
if self.running_batch.is_empty():
|
||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
@@ -344,7 +343,7 @@ class ModelTpServer:
|
|||||||
if self.running_batch:
|
if self.running_batch:
|
||||||
available_size -= sum(
|
available_size -= sum(
|
||||||
[
|
[
|
||||||
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
|
(r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
|
||||||
for r in self.running_batch.reqs
|
for r in self.running_batch.reqs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -358,7 +357,7 @@ class ModelTpServer:
|
|||||||
req.prefix_indices = req.prefix_indices[:-delta]
|
req.prefix_indices = req.prefix_indices[:-delta]
|
||||||
if req.image_offset is not None:
|
if req.image_offset is not None:
|
||||||
req.image_offset += delta
|
req.image_offset += delta
|
||||||
if req.extend_input_len == 0 and req.max_new_tokens() > 0:
|
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
|
||||||
# Need at least one token to compute logits
|
# Need at least one token to compute logits
|
||||||
req.extend_input_len = 1
|
req.extend_input_len = 1
|
||||||
req.prefix_indices = req.prefix_indices[:-1]
|
req.prefix_indices = req.prefix_indices[:-1]
|
||||||
@@ -366,7 +365,7 @@ class ModelTpServer:
|
|||||||
req.image_offset += 1
|
req.image_offset += 1
|
||||||
|
|
||||||
if (
|
if (
|
||||||
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
|
||||||
< available_size
|
< available_size
|
||||||
and (
|
and (
|
||||||
req.extend_input_len + new_batch_input_tokens
|
req.extend_input_len + new_batch_input_tokens
|
||||||
@@ -378,7 +377,7 @@ class ModelTpServer:
|
|||||||
available_size += delta
|
available_size += delta
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
|
||||||
< available_size
|
< available_size
|
||||||
):
|
):
|
||||||
# Undo locking
|
# Undo locking
|
||||||
@@ -389,7 +388,7 @@ class ModelTpServer:
|
|||||||
# Add this request to the running batch
|
# Add this request to the running batch
|
||||||
can_run_list.append(req)
|
can_run_list.append(req)
|
||||||
new_batch_total_tokens += (
|
new_batch_total_tokens += (
|
||||||
req.extend_input_len + req.max_new_tokens()
|
req.extend_input_len + req.sampling_params.max_new_tokens
|
||||||
)
|
)
|
||||||
new_batch_input_tokens += req.extend_input_len
|
new_batch_input_tokens += req.extend_input_len
|
||||||
else:
|
else:
|
||||||
@@ -403,9 +402,6 @@ class ModelTpServer:
|
|||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
running_req = (
|
|
||||||
0 if self.running_batch is None else len(self.running_batch.reqs)
|
|
||||||
)
|
|
||||||
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
|
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
|
||||||
self.tree_cache_metrics["total"] += (
|
self.tree_cache_metrics["total"] += (
|
||||||
hit_tokens + new_batch_input_tokens
|
hit_tokens + new_batch_input_tokens
|
||||||
@@ -420,7 +416,7 @@ class ModelTpServer:
|
|||||||
f"#new-token: {new_batch_input_tokens}, "
|
f"#new-token: {new_batch_input_tokens}, "
|
||||||
f"#cached-token: {hit_tokens}, "
|
f"#cached-token: {hit_tokens}, "
|
||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
||||||
f"#running-req: {running_req}, "
|
f"#running-req: {running_bs}, "
|
||||||
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
|
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
|
||||||
)
|
)
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
|
|||||||
@@ -8,45 +8,45 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ReqToTokenPool:
|
class ReqToTokenPool:
|
||||||
def __init__(self, size, max_context_len):
|
"""A memory pool that maps a request to its token locations."""
|
||||||
|
|
||||||
|
def __init__(self, size: int, max_context_len: int):
|
||||||
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
|
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
|
||||||
self.can_use_mem_size = size
|
|
||||||
self.req_to_token = torch.empty(
|
self.req_to_token = torch.empty(
|
||||||
(size, max_context_len), dtype=torch.int32, device="cuda"
|
(size, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
self.can_use_mem_size = size
|
||||||
|
|
||||||
def alloc(self, need_size):
|
def alloc(self, need_size: int):
|
||||||
if need_size > self.can_use_mem_size:
|
if need_size > self.can_use_mem_size:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
|
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
|
||||||
self.mem_state[select_index] = False
|
self.mem_state[select_index] = False
|
||||||
self.can_use_mem_size -= need_size
|
self.can_use_mem_size -= need_size
|
||||||
|
|
||||||
return select_index.to(torch.int32)
|
return select_index
|
||||||
|
|
||||||
def free(self, free_index):
|
def free(self, free_index: int):
|
||||||
|
self.mem_state[free_index] = True
|
||||||
if isinstance(free_index, (int,)):
|
if isinstance(free_index, (int,)):
|
||||||
self.can_use_mem_size += 1
|
self.can_use_mem_size += 1
|
||||||
else:
|
else:
|
||||||
self.can_use_mem_size += free_index.shape[0]
|
self.can_use_mem_size += free_index.shape[0]
|
||||||
|
|
||||||
self.mem_state[free_index] = True
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.mem_state.fill_(True)
|
self.mem_state.fill_(True)
|
||||||
self.can_use_mem_size = len(self.mem_state)
|
self.can_use_mem_size = len(self.mem_state)
|
||||||
|
|
||||||
|
|
||||||
class TokenToKVPool:
|
class TokenToKVPool:
|
||||||
|
"""A memory pool that maps a token to its kv cache locations"""
|
||||||
|
|
||||||
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
|
||||||
# This can be promised:
|
|
||||||
# assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
|
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||||
self.can_use_mem_size = self.size
|
|
||||||
|
|
||||||
# [size, key/value, head_num, head_dim] for each layer
|
# [size, key/value, head_num, head_dim] for each layer
|
||||||
self.kv_data = [
|
self.kv_data = [
|
||||||
@@ -58,6 +58,7 @@ class TokenToKVPool:
|
|||||||
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
||||||
self.prefetch_chunk_size = 512
|
self.prefetch_chunk_size = 512
|
||||||
|
|
||||||
|
self.can_use_mem_size = self.size
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id):
|
def get_key_buffer(self, layer_id):
|
||||||
@@ -66,6 +67,9 @@ class TokenToKVPool:
|
|||||||
def get_value_buffer(self, layer_id):
|
def get_value_buffer(self, layer_id):
|
||||||
return self.kv_data[layer_id][:, 1]
|
return self.kv_data[layer_id][:, 1]
|
||||||
|
|
||||||
|
def available_size(self):
|
||||||
|
return self.can_use_mem_size + len(self.prefetch_buffer)
|
||||||
|
|
||||||
def alloc(self, need_size):
|
def alloc(self, need_size):
|
||||||
buffer_len = len(self.prefetch_buffer)
|
buffer_len = len(self.prefetch_buffer)
|
||||||
if need_size <= buffer_len:
|
if need_size <= buffer_len:
|
||||||
@@ -75,13 +79,13 @@ class TokenToKVPool:
|
|||||||
|
|
||||||
addition_size = need_size - buffer_len
|
addition_size = need_size - buffer_len
|
||||||
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
||||||
select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
|
select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
|
||||||
select_index = select_index.to(torch.int32)
|
|
||||||
|
|
||||||
if select_index.shape[0] < addition_size:
|
if select_index.shape[0] < addition_size:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self.add_refs(select_index)
|
self.mem_state[select_index] = False
|
||||||
|
self.can_use_mem_size -= len(select_index)
|
||||||
|
|
||||||
self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
|
self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
|
||||||
ret_index = self.prefetch_buffer[:need_size]
|
ret_index = self.prefetch_buffer[:need_size]
|
||||||
@@ -89,16 +93,9 @@ class TokenToKVPool:
|
|||||||
|
|
||||||
return ret_index
|
return ret_index
|
||||||
|
|
||||||
def available_size(self):
|
def free(self, free_index: torch.Tensor):
|
||||||
return self.can_use_mem_size + len(self.prefetch_buffer)
|
self.mem_state[free_index] = True
|
||||||
|
self.can_use_mem_size += len(free_index)
|
||||||
def add_refs(self, token_index: torch.Tensor):
|
|
||||||
self.can_use_mem_size -= len(token_index)
|
|
||||||
self.mem_state[token_index] = False
|
|
||||||
|
|
||||||
def dec_refs(self, token_index: torch.Tensor):
|
|
||||||
self.can_use_mem_size += len(token_index)
|
|
||||||
self.mem_state[token_index] = True
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.mem_state.fill_(True)
|
self.mem_state.fill_(True)
|
||||||
|
|||||||
Reference in New Issue
Block a user