diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 981040d0d..4958c6d04 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -335,13 +335,13 @@ class LogitsProcessor(nn.Module): aux_pruned_states = torch.cat(aux_pruned_states, dim=-1) hidden_states_to_store = ( aux_pruned_states[sample_indices] - if sample_indices + if sample_indices is not None else aux_pruned_states ) else: hidden_states_to_store = ( pruned_states[sample_indices] - if sample_indices + if sample_indices is not None else pruned_states ) else: