Minor Optimizations in Schedule Batch (#8724)

Co-authored-by: Suruchi Shah <surshah@linkedin.com>
This commit is contained in:
Lianmin Zheng
2025-08-08 16:10:16 -07:00
committed by GitHub
parent 6642e3a295
commit f352b793be

View File

@@ -37,6 +37,7 @@ import logging
import threading
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
import numpy as np
@@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
self.device, non_blocking=True
)
input_ids_tensor = torch.tensor(
list(chain.from_iterable(input_ids)), dtype=torch.int64
).to(self.device, non_blocking=True)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)