diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 16a2230f9..fd7630b3e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -37,6 +37,7 @@ import logging import threading from enum import Enum, auto from http import HTTPStatus +from itertools import chain from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union import numpy as np @@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to( self.device, non_blocking=True ) - input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to( - self.device, non_blocking=True - ) + input_ids_tensor = torch.tensor( + list(chain.from_iterable(input_ids)), dtype=torch.int64 + ).to(self.device, non_blocking=True) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True )