Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -150,6 +150,7 @@ class Scheduler:
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||
self.device = self.tp_worker.device
|
||||
|
||||
# Get token and memory info from the model worker
|
||||
(
|
||||
@@ -758,9 +759,7 @@ class Scheduler:
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs[
|
||||
torch.arange(
|
||||
len(next_token_ids), device=next_token_ids.device
|
||||
),
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
)
|
||||
@@ -828,7 +827,7 @@ class Scheduler:
|
||||
# Move logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user