RadixCache method adjust (#977)

Author: Liangsheng Yin
Date: 2024-08-07 15:52:24 -07:00
Committed by: GitHub
Parent: f724f1f1e9
Commit: 7623091d97
5 changed files with 140 additions and 118 deletions


@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset = 0
+        self.image_offset = None
         self.pad_value = None
 
         # Prefix info
@@ -162,6 +162,13 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def adjust_max_prefix_ids(self):
+        max_prefix_ids = self.input_ids
+        if self.return_logprob:
+            max_prefix_ids = self.input_ids[: self.logprob_start_len]
+        return max_prefix_ids
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
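The new adjust_max_prefix_ids helper caps how much of the prompt may be reused from the radix cache: when logprobs are requested, tokens from logprob_start_len onward must be recomputed rather than served from cached KV, so they are excluded from prefix matching. Below is a minimal sketch of a call site, assuming match_prefix returns the cached indices and the last matched tree node; the scheduler's real code may differ:

    # Hypothetical call site, for illustration only.
    def compute_cached_prefix(req, tree_cache):
        # Only match against the part of the prompt whose logprobs
        # do not need to be recomputed.
        max_prefix_ids = req.adjust_max_prefix_ids()
        prefix_indices, last_node = tree_cache.match_prefix(max_prefix_ids)
        req.prefix_indices = prefix_indices
        req.last_node = last_node
        return len(prefix_indices)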
@@ -444,7 +451,8 @@ class ScheduleBatch:
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+            (r.image_offset - p_len) if r.image_offset is not None else 0
+            for r, p_len in zip(reqs, prefix_lens)
         ]
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
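With image_offset now defaulting to None (no image) instead of 0, the batch builder must guard the subtraction above; text-only requests simply get an offset of 0. A self-contained illustration of the sentinel behavior, with made-up numbers:

    # offsets stand in for req.image_offset; prefix_lens come from cache matching.
    offsets = [5, None]  # one vision request, one text-only request
    prefix_lens = [3, 3]

    image_offsets = [
        (off - p_len) if off is not None else 0
        for off, p_len in zip(offsets, prefix_lens)
    ]
    assert image_offsets == [2, 0]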
@@ -596,15 +604,7 @@ class ScheduleBatch:
                     req.vid += 1
 
                     # insert the old request into tree_cache
-                    self.tree_cache.cache_req(
-                        rid=req.rid,
-                        token_ids=cur_all_ids,
-                        last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req.req_pool_idx,
-                    )
-
-                    # unlock the last node
-                    self.tree_cache.dec_lock_ref(req.last_node)
+                    self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                     # re-applying image padding
                     if req.pixel_values is not None:
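Judging from the two calls it replaces, cache_finished_req folds "insert the finished request's tokens into the tree" and "unlock its last node" into a single RadixCache method. A sketch of what the method plausibly encapsulates, inferred only from this hunk (the real implementation in radix_cache.py may handle token defaults and KV freeing differently):

    class RadixCache:
        def cache_finished_req(self, req, token_ids=None):
            if token_ids is None:
                # Assumed default: cache the full generated sequence.
                token_ids = req.input_ids + req.output_ids

            # Insert the sequence into the prefix tree, starting after
            # the part that was already cached for this request.
            self.cache_req(
                rid=req.rid,
                token_ids=token_ids,
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req.req_pool_idx,
            )
            # Release the lock that pinned the node while the request ran.
            self.dec_lock_ref(req.last_node)

Centralizing this pairing means call sites can no longer forget the unlock, which is the kind of leak the old two-step sequence invited.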
@@ -621,8 +621,7 @@ class ScheduleBatch:
                     jump_forward_reqs.append(req)
                     filter_indices.remove(i)
 
-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)
 
         return jump_forward_reqs
@@ -644,6 +643,15 @@ class ScheduleBatch:
         ] = self.out_cache_loc
 
     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
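The two new guards let callers invoke filter_batch unconditionally, which is exactly what the jump-forward path above now does: an empty index list clears the batch, a full one is a no-op, and only a strict subset pays for tensor re-indexing. A toy stand-in demonstrating the semantics (the tensor re-indexing is elided):

    def filter_batch_semantics(reqs, unfinished_indices):
        if unfinished_indices is None or len(unfinished_indices) == 0:
            return []  # everything finished: empty batch
        if len(unfinished_indices) == len(reqs):
            return reqs  # nothing finished: no-op
        return [reqs[i] for i in unfinished_indices]

    assert filter_batch_semantics(["a", "b", "c"], []) == []
    assert filter_batch_semantics(["a", "b", "c"], [0, 1, 2]) == ["a", "b", "c"]
    assert filter_batch_semantics(["a", "b", "c"], [0, 2]) == ["a", "c"]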
@@ -711,6 +719,7 @@ class ScheduleBatch:
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
+        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
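For context, the first post-processing step divides each row of logits by that request's temperature before sampling. A minimal, self-contained sketch of this step with toy shapes (not the full sampler, which also applies penalties and top-k/top-p):

    import torch

    logits = torch.randn(2, 8)  # [batch, vocab] toy logits
    temperatures = torch.tensor([[0.7], [1.0]])  # per request, broadcast over vocab

    logits = logits.contiguous()
    logits.div_(temperatures)
    probs = torch.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)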