Make req_pool_indices on CPU (#960)
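For orientation, the core of this change: ReqToTokenPool stops tracking free request slots with a CUDA bool mask and instead keeps a plain Python free list, so allocations return host-side integers and each Req caches its slot in req.req_pool_idx; callers no longer need req_pool_indices.cpu().numpy() round-trips. A minimal, self-contained sketch of that allocation pattern (hypothetical SlotPool class, not the project's code):

class SlotPool:
    """Toy host-side slot allocator mirroring the new free-list approach."""

    def __init__(self, size):
        self.size = size
        self.free_slots = list(range(size))  # plain Python list, lives on the CPU

    def alloc(self, need_size):
        if need_size > len(self.free_slots):
            return None  # caller decides how to report out-of-memory
        selected = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return selected  # Python ints: no GPU tensor, no .cpu() sync

    def free(self, idx):
        self.free_slots.append(idx)

pool = SlotPool(8)
req_pool_indices = pool.alloc(2)  # e.g. [0, 1]; each request keeps its own index
pool.free(req_pool_indices[0])    # release a single slot when the request finishes

Slicing a Python list per batch is cheap and avoids the device synchronization that the old mask-based allocator paid on every allocation.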
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
-        self.new_token_ratio_recovery = 0.05

         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
@@ -100,6 +100,9 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

+        # Memory info
+        self.req_pool_idx = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids |
@@ -321,6 +324,9 @@ class ScheduleBatch:
             return_logprob=return_logprob,
         )

+    def batch_size(self):
+        return len(self.reqs) if self.reqs is not None else 0
+
     def is_empty(self):
         return len(self.reqs) == 0

@@ -328,52 +334,22 @@ class ScheduleBatch:
         # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)

-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
-        bs = len(self.reqs)
-        reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        flatten_input_ids = []
-        extend_lens = []
-        prefix_lens = []
-        seq_lens = []
-
-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-
+    def alloc_req_slots(self, num_reqs):
+        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
                 "Out of memory. "
                 "Please set a smaller number for `--max-running-requests`."
             )
+        return req_pool_indices

-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
-            extend_lens.append(len(input_ids[i]))
+    def alloc_token_slots(self, num_tokens: int):
+        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

             if out_cache_loc is None:
                 logger.error("Prefill out of memory. Try to lower your batch size.")
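The new alloc_token_slots above centralizes the allocate, evict-from-cache, retry-once, then fail-loudly sequence that prepare_for_extend and prepare_for_decode previously each spelled out inline. A toy sketch of that control flow with stand-in pool and cache objects (hypothetical names, not sglang's classes):

class ToyKVPool:
    """Stand-in for token_to_kv_pool: just counts free slots."""

    def __init__(self, size):
        self.available = size

    def alloc(self, n):
        if n > self.available:
            return None
        self.available -= n
        return list(range(n))  # placeholder slot ids

class ToyTreeCache:
    """Stand-in for the radix cache: evicting returns capacity to the pool."""

    def __init__(self, pool, evictable):
        self.pool, self.evictable = pool, evictable

    def evict(self, n):
        freed = min(n, self.evictable)
        self.evictable -= freed
        self.pool.available += freed

def alloc_token_slots(pool, tree_cache, num_tokens):
    out = pool.alloc(num_tokens)
    if out is None and tree_cache is not None:
        tree_cache.evict(num_tokens)  # free cached tokens, then retry once
        out = pool.alloc(num_tokens)
    if out is None:
        raise RuntimeError("Prefill out of memory. Try to lower your batch size.")
    return out

pool = ToyKVPool(size=4)
cache = ToyTreeCache(pool, evictable=8)
print(len(alloc_token_slots(pool, cache, 6)))  # evicts, retries, prints 6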
@@ -381,40 +357,11 @@ class ScheduleBatch:
                 self.tree_cache.pretty_print()
                 exit(1)

-        pt = 0
-        for i in range(bs):
-            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
-
-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
-        # Set fields
-        self.input_ids = torch.tensor(
-            flatten_input_ids, dtype=torch.int32, device=device
-        )
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-        ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
-        self.extend_num_tokens = extend_num_tokens
-        self.out_cache_loc = out_cache_loc
-        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        return out_cache_loc

+    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+        device = "cuda"
+        bs, reqs = self.batch_size(), self.reqs
         self.temperatures = torch.tensor(
             [r.sampling_params.temperature for r in reqs],
             dtype=torch.float,
@@ -436,10 +383,78 @@ class ScheduleBatch:
             dtype=torch.float,
             device=device,
         )
-        self.logit_bias = logit_bias
+
+        # Handle logit bias but only allocate when needed
+        self.logit_bias = None
+        for i in range(bs):
+            if reqs[i].sampling_params.dtype == "int":
+                if self.logit_bias is None:
+                    self.logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
+                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
+    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+        device = "cuda"
+        bs = self.batch_size()
+        reqs = self.reqs
+        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
+        prefix_indices = [r.prefix_indices for r in reqs]
+
+        # Handle prefix
+        extend_lens = []
+        prefix_lens = []
+        seq_lens = []
+
+        req_pool_indices_cpu = self.alloc_req_slots(bs)
+
+        for i, req in enumerate(reqs):
+            req.req_pool_idx = req_pool_indices_cpu[i]
+            extend_lens.append(len(input_ids[i]))
+
+            if len(prefix_indices[i]) == 0:
+                prefix_lens.append(0)
+            else:
+                prefix_lens.append(len(prefix_indices[i]))
+                self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : len(prefix_indices[i])
+                ] = prefix_indices[i]
+
+            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
+
+        # Allocate memory
+        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
+        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+
+        pt = 0
+        for i, req in enumerate(reqs):
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
+            ] = out_cache_loc[pt : pt + extend_lens[i]]
+            pt += extend_lens[i]
+
+        # Set fields
+        with torch.device("cuda"):
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+        ]
+        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+        self.extend_num_tokens = extend_num_tokens
+        self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+
+        self.batch_sampling_params(vocab_size, int_token_logit_bias)

     def check_decode_mem(self):
-        bs = len(self.reqs)
+        bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
             return True
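One detail in the rewritten prepare_for_extend above: the device tensors are built inside a with torch.device("cuda") block, so the torch.tensor and torch.zeros factory calls default to the GPU without repeating device=... on every line (PyTorch 2.0+ behavior). A small illustrative snippet, guarded so it also runs on a machine without a GPU:

import torch

seq_lens = [5, 3, 7]
device = "cuda" if torch.cuda.is_available() else "cpu"

# Factory calls inside the context default to `device`,
# equivalent to passing device=device to every call.
with torch.device(device):
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32)
    position_ids_offsets = torch.zeros((len(seq_lens),), dtype=torch.int32)

assert seq_lens_t.device.type == device
assert position_ids_offsets.device.type == device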
@@ -464,7 +479,6 @@ class ScheduleBatch:

         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
@@ -482,20 +496,20 @@ class ScheduleBatch:

             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][: seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][last_uncached_pos : seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -533,8 +547,6 @@ class ScheduleBatch:
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]

-        req_pool_indices_cpu = None
-
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -584,13 +596,11 @@ class ScheduleBatch:
                     req.vid += 1

                     # insert the old request into tree_cache
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
                         rid=req.rid,
                         token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req_pool_indices_cpu[i],
+                        req_pool_idx=req.req_pool_idx,
                     )

                     # unlock the last node
@@ -626,14 +636,8 @@ class ScheduleBatch:
         self.prefix_lens = None

         # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-        if self.out_cache_loc is None:
-            logger.error("Decode out of memory. Try to lower your batch size.")
-            if self.tree_cache is not None:
-                self.tree_cache.pretty_print()
-            exit(1)
+        bs = self.batch_size()
+        self.out_cache_loc = self.alloc_token_slots(bs)

         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
@@ -200,7 +200,6 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
         try:
@@ -625,13 +624,12 @@ class ModelTpServer:
                 req.output_top_logprobs.append(output.output_top_logprobs[i])

     def cache_filled_batch(self, batch: ScheduleBatch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 rid=req.rid,
                 token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req_pool_indices_cpu[i],
+                req_pool_idx=req.req_pool_idx,
                 del_in_memory_pool=False,
                 old_last_node=req.last_node,
             )
@@ -639,7 +637,7 @@ class ModelTpServer:

             if req is self.current_inflight_req:
                 # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+                self.req_to_token_pool.free(req.req_pool_idx)

     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
@@ -782,14 +780,13 @@ class ModelTpServer:
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
                     rid=req.rid,
                     token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                 )

                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -16,6 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
+from typing import List

 import torch

@@ -27,34 +28,29 @@ class ReqToTokenPool:

     def __init__(self, size: int, max_context_len: int):
         self.size = size
-        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
-        self.can_use_mem_size = size

-    def alloc(self, need_size: int):
-        if need_size > self.can_use_mem_size:
+    def alloc(self, need_size: int) -> List[int]:
+        if need_size > len(self.free_slots):
             return None

-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
-        )
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= need_size
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]

         return select_index

     def free(self, free_index):
-        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
-            self.can_use_mem_size += 1
+            self.free_slots.append(free_index)
         else:
-            self.can_use_mem_size += free_index.shape[0]
+            self.free_slots.extend(free_index)

     def clear(self):
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = len(self.mem_state)
+        self.free_slots = list(range(self.size))


 class BaseTokenToKVPool: