Adjust InputMetadata and ScheduleBatch (#981)
This commit is contained in:
@@ -307,7 +307,6 @@ class ScheduleBatch:
|
||||
input_ids: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
prefix_lens: torch.Tensor = None
|
||||
position_ids_offsets: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
extend_num_tokens: int = None
|
||||
@@ -316,11 +315,6 @@ class ScheduleBatch:
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
# For multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
image_sizes: List[List[int]] = None
|
||||
image_offsets: List[int] = None
|
||||
|
||||
# Batched sampling params
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
@@ -412,59 +406,40 @@ class ScheduleBatch:
|
||||
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||
device = "cuda"
|
||||
bs = self.batch_size()
|
||||
reqs = self.reqs
|
||||
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
prefix_indices = [r.prefix_indices for r in reqs]
|
||||
|
||||
# Handle prefix
|
||||
extend_lens = []
|
||||
prefix_lens = []
|
||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||
seq_lens = []
|
||||
|
||||
req_pool_indices_cpu = self.alloc_req_slots(bs)
|
||||
|
||||
for i, req in enumerate(reqs):
|
||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||
extend_lens.append(len(input_ids[i]))
|
||||
|
||||
if len(prefix_indices[i]) == 0:
|
||||
prefix_lens.append(0)
|
||||
else:
|
||||
prefix_lens.append(len(prefix_indices[i]))
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||
: len(prefix_indices[i])
|
||||
] = prefix_indices[i]
|
||||
|
||||
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
|
||||
|
||||
# Allocate memory
|
||||
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
|
||||
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
||||
req_pool_indices_cpu = self.alloc_req_slots(bs)
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
|
||||
] = out_cache_loc[pt : pt + extend_lens[i]]
|
||||
pt += extend_lens[i]
|
||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||
pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
|
||||
ext_len = seq_len - pre_len
|
||||
seq_lens.append(seq_len)
|
||||
|
||||
if pre_len > 0:
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||
:pre_len
|
||||
] = req.prefix_indices
|
||||
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||
out_cache_loc[pt : pt + ext_len]
|
||||
)
|
||||
pt += ext_len
|
||||
|
||||
# Set fields
|
||||
with torch.device("cuda"):
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
|
||||
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
|
||||
|
||||
self.pixel_values = [r.pixel_values for r in reqs]
|
||||
self.image_sizes = [r.image_size for r in reqs]
|
||||
self.image_offsets = [
|
||||
(r.image_offset - p_len) if r.image_offset is not None else 0
|
||||
for r, p_len in zip(reqs, prefix_lens)
|
||||
]
|
||||
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
@@ -642,7 +617,6 @@ class ScheduleBatch:
|
||||
]
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
||||
self.seq_lens.add_(1)
|
||||
self.prefix_lens = None
|
||||
|
||||
# Alloc mem
|
||||
bs = self.batch_size()
|
||||
@@ -667,7 +641,6 @@ class ScheduleBatch:
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.input_ids = None
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.prefix_lens = None
|
||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||
self.out_cache_loc = None
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||
@@ -692,7 +665,6 @@ class ScheduleBatch:
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||
self.prefix_lens = None
|
||||
self.position_ids_offsets = torch.concat(
|
||||
[self.position_ids_offsets, other.position_ids_offsets]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user