Skip unnecessary penalizer (#1707)

This commit is contained in:
Lianmin Zheng
2024-10-18 17:54:03 -07:00
committed by GitHub
parent bc12d4033f
commit 2bcfba1b08
7 changed files with 104 additions and 75 deletions

View File

@@ -150,6 +150,7 @@ class Scheduler:
nccl_port=port_args.nccl_port,
)
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
self.device = self.tp_worker.device
# Get token and memory info from the model worker
(
@@ -758,9 +759,7 @@ class Scheduler:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[
torch.arange(
len(next_token_ids), device=next_token_ids.device
),
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()
)
@@ -828,7 +827,7 @@ class Scheduler:
# Move logprobs to cpu
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()