Minor Optimizations in Schedule Batch (#8724)
Co-authored-by: Suruchi Shah <surshah@linkedin.com>
This commit is contained in:
@@ -37,6 +37,7 @@ import logging
|
||||
import threading
|
||||
from enum import Enum, auto
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
input_ids_tensor = torch.tensor(
|
||||
list(chain.from_iterable(input_ids)), dtype=torch.int64
|
||||
).to(self.device, non_blocking=True)
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user