RadixCache method adjust (#977)
@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset = 0
+        self.image_offset = None
         self.pad_value = None
 
         # Prefix info
@@ -162,6 +162,13 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def adjust_max_prefix_ids(self):
+        max_prefix_ids = self.input_ids
+        if self.return_logprob:
+            max_prefix_ids = self.input_ids[: self.logprob_start_len]
+
+        return max_prefix_ids
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
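The new helper bounds which input ids may be matched against the radix cache: when logprobs are requested, tokens from logprob_start_len onward have to go through the forward pass again, since a cache hit would skip computing their logprobs. A minimal usage sketch, assuming a match_prefix-style lookup on the cache; the caller name and return shape here are assumptions, not part of this commit:

# Hypothetical caller (not part of this commit): compute_prefix and
# tree_cache.match_prefix are assumed names for illustration.
def compute_prefix(req, tree_cache):
    # Tokens past logprob_start_len are excluded from matching so their
    # logprobs are actually computed instead of served from the cache.
    max_prefix_ids = req.adjust_max_prefix_ids()
    prefix_indices, last_node = tree_cache.match_prefix(max_prefix_ids)
    return prefix_indices, last_node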
@@ -444,7 +451,8 @@ class ScheduleBatch:
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+            (r.image_offset - p_len) if r.image_offset is not None else 0
+            for r, p_len in zip(reqs, prefix_lens)
         ]
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
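This guard pairs with the image_offset = None default above: text-only requests no longer carry a fake offset of 0, so the subtraction must be skipped for them. A self-contained illustration with made-up values:

# Stand-alone illustration of the guard (values are made up).
prefix_lens = [5, 3]
image_offsets = [None, 10]  # request 0 is text-only, request 1 has an image
adjusted = [
    (off - p_len) if off is not None else 0
    for off, p_len in zip(image_offsets, prefix_lens)
]
assert adjusted == [0, 7]  # without the guard, `None - 5` raises TypeError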
@@ -596,15 +604,7 @@ class ScheduleBatch:
                     req.vid += 1
 
-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_req(
-                        rid=req.rid,
-                        token_ids=cur_all_ids,
-                        last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req.req_pool_idx,
-                    )
 
-                    # unlock the last node
-                    self.tree_cache.dec_lock_ref(req.last_node)
+                    self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                     # re-applying image padding
                     if req.pixel_values is not None:
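The two-step insert-then-unlock sequence is folded into a single cache_finished_req call on the cache. A sketch of what such a helper could look like inside RadixCache, inferred purely from the removed lines; the actual implementation may differ:

# Hypothetical consolidation, reconstructed from the removed lines above.
def cache_finished_req(self, req, token_ids):
    # Insert the finished request into the radix tree; positions before
    # len(req.prefix_indices) were already cached and are not re-inserted.
    self.cache_req(
        rid=req.rid,
        token_ids=token_ids,
        last_uncached_pos=len(req.prefix_indices),
        req_pool_idx=req.req_pool_idx,
    )
    # Drop the reference that pinned the request's last matched node.
    self.dec_lock_ref(req.last_node)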
@@ -621,8 +621,7 @@ class ScheduleBatch:
                     jump_forward_reqs.append(req)
                     filter_indices.remove(i)
 
-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)
 
         return jump_forward_reqs
 
@@ -644,6 +643,15 @@ class ScheduleBatch:
         ] = self.out_cache_loc
 
     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
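With these guards, callers such as check_for_jump_forward above can invoke filter_batch unconditionally: empty and full index lists are now handled in place. A runnable toy model of the three paths, using a stub class rather than the real ScheduleBatch:

# Toy stand-in for ScheduleBatch, modelling only the new early exits.
class _FakeBatch:
    def __init__(self, n):
        self.reqs = list(range(n))

    def filter_batch(self, unfinished_indices):
        if unfinished_indices is None or len(unfinished_indices) == 0:
            self.reqs = []  # filter out all requests
            return
        if len(unfinished_indices) == len(self.reqs):
            return          # no need to filter
        self.reqs = [self.reqs[i] for i in unfinished_indices]

batch = _FakeBatch(3)
batch.filter_batch([0, 1, 2])  # full list: no-op
assert batch.reqs == [0, 1, 2]
batch.filter_batch([2])        # keep only request 2
assert batch.reqs == [2]
batch.filter_batch([])         # empty list: clears the batch
assert batch.reqs == []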
@@ -711,6 +719,7 @@ class ScheduleBatch:
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
+        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
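The post-processing the TODO refers to is per-request temperature scaling via an in-place broadcasted divide. In isolation, under assumed shapes:

# Isolated demo of the divide above; the shapes are assumptions.
import torch

logits = torch.randn(2, 8)                   # [batch_size, vocab_size]
temperatures = torch.tensor([[0.7], [1.0]])  # [batch_size, 1], broadcasts
logits = logits.contiguous()
logits.div_(temperatures)                    # in place: logits /= temperatures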