Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -515,11 +515,11 @@ class ScheduleBatch:
|
||||
assert seq_len - pre_len == req.extend_input_len
|
||||
|
||||
if pre_len > 0:
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||
:pre_len
|
||||
] = req.prefix_indices
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
|
||||
req.prefix_indices
|
||||
)
|
||||
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
||||
out_cache_loc[pt : pt + req.extend_input_len]
|
||||
)
|
||||
|
||||
@@ -535,10 +535,15 @@ class ScheduleBatch:
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Set fields
|
||||
with out_cache_loc.device:
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.out_cache_loc = out_cache_loc
|
||||
@@ -782,8 +787,8 @@ class ScheduleBatch:
|
||||
return
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(
|
||||
keep_indices, dtype=torch.int32, device=self.seq_lens.device
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
|
||||
@@ -150,6 +150,7 @@ class Scheduler:
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||
self.device = self.tp_worker.device
|
||||
|
||||
# Get token and memory info from the model worker
|
||||
(
|
||||
@@ -758,9 +759,7 @@ class Scheduler:
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs[
|
||||
torch.arange(
|
||||
len(next_token_ids), device=next_token_ids.device
|
||||
),
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
)
|
||||
@@ -828,7 +827,7 @@ class Scheduler:
|
||||
# Move logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user