fix(srt): check if sample_indices is not None before usage. (#5633)
This commit is contained in:
@@ -335,13 +335,13 @@ class LogitsProcessor(nn.Module):
|
|||||||
aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
|
aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
|
||||||
hidden_states_to_store = (
|
hidden_states_to_store = (
|
||||||
aux_pruned_states[sample_indices]
|
aux_pruned_states[sample_indices]
|
||||||
if sample_indices
|
if sample_indices is not None
|
||||||
else aux_pruned_states
|
else aux_pruned_states
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states_to_store = (
|
hidden_states_to_store = (
|
||||||
pruned_states[sample_indices]
|
pruned_states[sample_indices]
|
||||||
if sample_indices
|
if sample_indices is not None
|
||||||
else pruned_states
|
else pruned_states
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user