Skip unnecessary penalizer (#1707)

Lianmin Zheng
2024-10-18 17:54:03 -07:00
committed by GitHub
parent bc12d4033f
commit 2bcfba1b08
7 changed files with 104 additions and 75 deletions


@@ -515,11 +515,11 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
+                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
+                    req.prefix_indices
+                )

-            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
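For context, a minimal sketch (tensor names hypothetical) of the indexing change above: `t[i][:n] = x` chains two indexing operations, first materializing a row view and then slice-assigning into it, while `t[i, :n] = x` performs the same write as a single indexed assignment.

```python
import torch

# Hypothetical stand-in for req_to_token_pool.req_to_token: one row per request slot.
req_to_token = torch.zeros(4, 8, dtype=torch.int32)
prefix_indices = torch.arange(3, dtype=torch.int32)

# Before: chained indexing -- build the row view, then slice-assign into it.
req_to_token[1][:3] = prefix_indices

# After: one combined indexed assignment; same result, a single dispatch.
req_to_token[1, :3] = prefix_indices
```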
@@ -535,10 +535,15 @@ class ScheduleBatch:
             pt += req.extend_input_len

         # Set fields
-        with out_cache_loc.device:
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
-            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )

         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
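The hunk above replaces the `with out_cache_loc.device:` block with explicit host-to-device copies. A sketch of the pattern, assuming a CUDA target: the tensor is built on the CPU, then copied over with `non_blocking=True`, which makes the copy asynchronous from the host's point of view (it only truly overlaps with GPU work when the source is in pinned memory).

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_lens = [5, 9, 2]  # hypothetical per-request sequence lengths

# Build on the CPU, then issue the copy to the target device.
# non_blocking=True makes the host-to-device copy asynchronous with
# respect to the host; it overlaps with compute only for pinned memory.
seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32).to(device, non_blocking=True)
```

The same replacement appears again in the `filter_batch` hunk below, where `new_indices` was previously constructed directly on `self.seq_lens.device`.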
@@ -782,8 +787,8 @@ class ScheduleBatch:
             return

         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(
-            keep_indices, dtype=torch.int32, device=self.seq_lens.device
+        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]


@@ -150,6 +150,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.device = self.tp_worker.device

         # Get token and memory info from the model worker
         (
@@ -758,9 +759,7 @@ class Scheduler:
             if logits_output.next_token_logprobs is not None:
                 logits_output.next_token_logprobs = (
                     logits_output.next_token_logprobs[
-                        torch.arange(
-                            len(next_token_ids), device=next_token_ids.device
-                        ),
+                        torch.arange(len(next_token_ids), device=self.device),
                         next_token_ids,
                     ].tolist()
                 )
@@ -828,7 +827,7 @@ class Scheduler:

         # Move logprobs to cpu
         if batch.return_logprob:
             next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                torch.arange(len(next_token_ids), device=self.device),
                 next_token_ids,
             ].tolist()
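For reference, a small sketch (shapes hypothetical) of the row-gather both `Scheduler` hunks touch: pairing `torch.arange(n)` with the sampled token ids selects one logprob per request, and the commit now builds that `arange` on the cached `self.device` (set in `__init__` above) instead of re-reading `next_token_ids.device` each time.

```python
import torch

# Hypothetical shapes: one logprob row per request, one sampled token each.
next_token_logprobs = torch.randn(3, 32000)  # [batch_size, vocab_size]
next_token_ids = torch.tensor([11, 7, 5])    # sampled token id per request

# Advanced indexing: row i pairs with next_token_ids[i], yielding one
# scalar logprob per request; .tolist() then moves them to the host.
picked = next_token_logprobs[
    torch.arange(len(next_token_ids)), next_token_ids
].tolist()
```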