Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -515,11 +515,11 @@ class ScheduleBatch:
|
||||
assert seq_len - pre_len == req.extend_input_len
|
||||
|
||||
if pre_len > 0:
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||
:pre_len
|
||||
] = req.prefix_indices
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
|
||||
req.prefix_indices
|
||||
)
|
||||
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
||||
out_cache_loc[pt : pt + req.extend_input_len]
|
||||
)
|
||||
|
||||
@@ -535,10 +535,15 @@ class ScheduleBatch:
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Set fields
|
||||
with out_cache_loc.device:
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.out_cache_loc = out_cache_loc
|
||||
@@ -782,8 +787,8 @@ class ScheduleBatch:
|
||||
return
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(
|
||||
keep_indices, dtype=torch.int32, device=self.seq_lens.device
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
|
||||
Reference in New Issue
Block a user