Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -135,25 +135,22 @@ class ForwardBatch:
|
||||
|
||||
# Init position information
|
||||
if not ret.forward_mode.is_decode():
|
||||
ret.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(prefix_len, prefix_len + extend_len)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
ret.positions = torch.concat(
|
||||
[
|
||||
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
ret.image_inputs = batch.image_inputs
|
||||
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, device=device
|
||||
)
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
||||
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
|
||||
Reference in New Issue
Block a user