[Minor] Improve the style and fix flaky tests (#1584)
This commit is contained in:
@@ -747,7 +747,9 @@ class ScheduleBatch:
|
||||
return
|
||||
|
||||
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
||||
-new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
||||
+new_indices = torch.tensor(
|
||||
+unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
|
||||
+)
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.out_cache_loc = None
|
||||
|
||||
@@ -218,6 +218,7 @@ class PrefillAdder:
|
||||
if not insert_sort:
|
||||
self.req_states.append((tokens_left, tokens_occupied))
|
||||
else:
|
||||
+i = 0
|
||||
for i in range(len(self.req_states)):
|
||||
if tokens_left <= self.req_states[i][0]:
|
||||
break
|
||||
|
||||
@@ -144,7 +144,7 @@ class Scheduler:
|
||||
)
|
||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||
|
||||
-# Get token and memory info from the tp worker
|
||||
+# Get token and memory info from the model worker
|
||||
(
|
||||
self.max_total_num_tokens,
|
||||
self.max_prefill_tokens,
|
||||
@@ -976,7 +976,7 @@ def run_scheduler_process(
|
||||
port_args: PortArgs,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
-pipe_writer: multiprocessing.connection.Connection,
|
||||
+pipe_writer,
|
||||
):
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
suppress_other_loggers()
|
||||
|
||||
@@ -31,10 +31,13 @@ class ReqToTokenPool:
|
||||
self.size = size
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
-self.free_slots = list(range(size))
|
||||
self.req_to_token = torch.empty(
|
||||
(size, max_context_len), dtype=torch.int32, device=device
|
||||
)
|
||||
+self.free_slots = list(range(size))
|
||||
+
|
||||
+def available_size(self):
|
||||
+return len(self.free_slots)
|
||||
|
||||
def alloc(self, need_size: int) -> List[int]:
|
||||
if need_size > len(self.free_slots):
|
||||
|
||||
@@ -40,7 +40,7 @@ class SamplingBatchInfo:
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
-with torch.device("cuda"):
|
||||
+with batch.input_ids.device:
|
||||
temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
|
||||
@@ -594,7 +594,9 @@ def set_weight_attrs(
|
||||
|
||||
|
||||
def broadcast_pyobj(
|
||||
-data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
|
||||
+data: List[Any],
|
||||
+rank: int,
|
||||
+dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||
):
|
||||
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user