Adjust InputeMetadata and ScheduleBatch (#981)

This commit is contained in:
Liangsheng Yin
2024-08-08 01:11:22 -07:00
committed by GitHub
parent 20a4f927dc
commit 1ac304eeb4
4 changed files with 203 additions and 192 deletions

View File

@@ -307,7 +307,6 @@ class ScheduleBatch:
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
prefix_lens: torch.Tensor = None
position_ids_offsets: torch.Tensor = None
out_cache_loc: torch.Tensor = None
extend_num_tokens: int = None
@@ -316,11 +315,6 @@ class ScheduleBatch:
return_logprob: bool = False
top_logprobs_nums: List[int] = None
# For multimodal
pixel_values: List[torch.Tensor] = None
image_sizes: List[List[int]] = None
image_offsets: List[int] = None
# Batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
@@ -412,59 +406,40 @@ class ScheduleBatch:
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
device = "cuda"
bs = self.batch_size()
reqs = self.reqs
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
prefix_indices = [r.prefix_indices for r in reqs]
# Handle prefix
extend_lens = []
prefix_lens = []
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = []
req_pool_indices_cpu = self.alloc_req_slots(bs)
for i, req in enumerate(reqs):
req.req_pool_idx = req_pool_indices_cpu[i]
extend_lens.append(len(input_ids[i]))
if len(prefix_indices[i]) == 0:
prefix_lens.append(0)
else:
prefix_lens.append(len(prefix_indices[i]))
self.req_to_token_pool.req_to_token[req.req_pool_idx][
: len(prefix_indices[i])
] = prefix_indices[i]
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
# Allocate memory
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
req_pool_indices_cpu = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
pt = 0
for i, req in enumerate(reqs):
self.req_to_token_pool.req_to_token[req.req_pool_idx][
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
] = out_cache_loc[pt : pt + extend_lens[i]]
pt += extend_lens[i]
req.req_pool_idx = req_pool_indices_cpu[i]
pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
ext_len = seq_len - pre_len
seq_lens.append(seq_len)
if pre_len > 0:
self.req_to_token_pool.req_to_token[req.req_pool_idx][
:pre_len
] = req.prefix_indices
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
out_cache_loc[pt : pt + ext_len]
)
pt += ext_len
# Set fields
with torch.device("cuda"):
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
(r.image_offset - p_len) if r.image_offset is not None else 0
for r, p_len in zip(reqs, prefix_lens)
]
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
@@ -642,7 +617,6 @@ class ScheduleBatch:
]
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
self.seq_lens.add_(1)
self.prefix_lens = None
# Alloc mem
bs = self.batch_size()
@@ -667,7 +641,6 @@ class ScheduleBatch:
self.seq_lens = self.seq_lens[new_indices]
self.input_ids = None
self.req_pool_indices = self.req_pool_indices[new_indices]
self.prefix_lens = None
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.out_cache_loc = None
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
@@ -692,7 +665,6 @@ class ScheduleBatch:
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.prefix_lens = None
self.position_ids_offsets = torch.concat(
[self.position_ids_offsets, other.position_ids_offsets]
)