Minor Optimizations in Schedule Batch (#8724)
Co-authored-by: Suruchi Shah <surshah@linkedin.com>
This commit is contained in:
@@ -37,6 +37,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
)
|
)
|
||||||
input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
|
input_ids_tensor = torch.tensor(
|
||||||
self.device, non_blocking=True
|
list(chain.from_iterable(input_ids)), dtype=torch.int64
|
||||||
)
|
).to(self.device, non_blocking=True)
|
||||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user