fix(srt): check if sample_indices is not None before usage. (#5633)

This commit is contained in:
aoshen524
2025-04-26 22:51:01 -04:00
committed by GitHub
parent d7b1ce65a5
commit 9ad28f639e

View File

@@ -335,13 +335,13 @@ class LogitsProcessor(nn.Module):
aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
hidden_states_to_store = (
aux_pruned_states[sample_indices]
if sample_indices
if sample_indices is not None
else aux_pruned_states
)
else:
hidden_states_to_store = (
pruned_states[sample_indices]
if sample_indices
if sample_indices is not None
else pruned_states
)
else: