Skip unnecessary penalizer (#1707)

This commit is contained in:
Lianmin Zheng
2024-10-18 17:54:03 -07:00
committed by GitHub
parent bc12d4033f
commit 2bcfba1b08
7 changed files with 104 additions and 75 deletions

View File

@@ -135,25 +135,22 @@ class ForwardBatch:
# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.tensor(
np.concatenate(
[
np.arange(prefix_len, prefix_len + extend_len)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
),
dtype=torch.int64,
device=device,
ret.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)
ret.image_inputs = batch.image_inputs
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, device=device
)
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
ret.extend_seq_lens_cpu = batch.extend_seq_lens