Minor Optimizations in Schedule Batch (#8724)

Co-authored-by: Suruchi Shah <surshah@linkedin.com>
This commit is contained in:
Lianmin Zheng
2025-08-08 16:10:16 -07:00
committed by GitHub
parent 6642e3a295
commit f352b793be

View File

@@ -37,6 +37,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
+from itertools import chain
 from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 import numpy as np
@@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        input_ids_tensor = torch.tensor(
+            list(chain.from_iterable(input_ids)), dtype=torch.int64
+        ).to(self.device, non_blocking=True)
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )