[Minor] Improve the style and fix flaky tests (#1584)
This commit is contained in:
@@ -747,7 +747,9 @@ class ScheduleBatch:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
||||||
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
new_indices = torch.tensor(
|
||||||
|
unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
|
||||||
|
)
|
||||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||||
self.seq_lens = self.seq_lens[new_indices]
|
self.seq_lens = self.seq_lens[new_indices]
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
|
|||||||
@@ -218,6 +218,7 @@ class PrefillAdder:
|
|||||||
if not insert_sort:
|
if not insert_sort:
|
||||||
self.req_states.append((tokens_left, tokens_occupied))
|
self.req_states.append((tokens_left, tokens_occupied))
|
||||||
else:
|
else:
|
||||||
|
i = 0
|
||||||
for i in range(len(self.req_states)):
|
for i in range(len(self.req_states)):
|
||||||
if tokens_left <= self.req_states[i][0]:
|
if tokens_left <= self.req_states[i][0]:
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||||
|
|
||||||
# Get token and memory info from the tp worker
|
# Get token and memory info from the model worker
|
||||||
(
|
(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
@@ -976,7 +976,7 @@ def run_scheduler_process(
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
pipe_writer: multiprocessing.connection.Connection,
|
pipe_writer,
|
||||||
):
|
):
|
||||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
|
|||||||
@@ -31,10 +31,13 @@ class ReqToTokenPool:
|
|||||||
self.size = size
|
self.size = size
|
||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
self.device = device
|
self.device = device
|
||||||
self.free_slots = list(range(size))
|
|
||||||
self.req_to_token = torch.empty(
|
self.req_to_token = torch.empty(
|
||||||
(size, max_context_len), dtype=torch.int32, device=device
|
(size, max_context_len), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
self.free_slots = list(range(size))
|
||||||
|
|
||||||
|
def available_size(self):
|
||||||
|
return len(self.free_slots)
|
||||||
|
|
||||||
def alloc(self, need_size: int) -> List[int]:
|
def alloc(self, need_size: int) -> List[int]:
|
||||||
if need_size > len(self.free_slots):
|
if need_size > len(self.free_slots):
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class SamplingBatchInfo:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
with torch.device("cuda"):
|
with batch.input_ids.device:
|
||||||
temperatures = torch.tensor(
|
temperatures = torch.tensor(
|
||||||
[r.sampling_params.temperature for r in reqs],
|
[r.sampling_params.temperature for r in reqs],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
|
|||||||
@@ -594,7 +594,9 @@ def set_weight_attrs(
|
|||||||
|
|
||||||
|
|
||||||
def broadcast_pyobj(
|
def broadcast_pyobj(
|
||||||
data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
|
data: List[Any],
|
||||||
|
rank: int,
|
||||||
|
dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||||
):
|
):
|
||||||
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class TestTritonAttnBackend(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
assert output_throughput > 154, f"{output_throughput=}"
|
assert output_throughput > 153, f"{output_throughput=}"
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
model = DEFAULT_MODEL_NAME_FOR_TEST
|
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
|||||||
Reference in New Issue
Block a user