From 9244f27f0af24deb199921c32e24f2491380e016 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 6 Oct 2024 00:10:48 -0700 Subject: [PATCH] [Minor] Improve the style and fix flaky tests (#1584) --- python/sglang/srt/managers/schedule_batch.py | 4 +++- python/sglang/srt/managers/schedule_policy.py | 1 + python/sglang/srt/managers/scheduler.py | 4 ++-- python/sglang/srt/mem_cache/memory_pool.py | 5 ++++- python/sglang/srt/sampling/sampling_batch_info.py | 2 +- python/sglang/srt/utils.py | 4 +++- test/srt/test_triton_attn_backend.py | 2 +- 7 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 64cacd4c2..98fc5581c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -747,7 +747,9 @@ class ScheduleBatch: return self.reqs = [self.reqs[i] for i in unfinished_indices] - new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda") + new_indices = torch.tensor( + unfinished_indices, dtype=torch.int32, device=self.seq_lens.device + ) self.req_pool_indices = self.req_pool_indices[new_indices] self.seq_lens = self.seq_lens[new_indices] self.out_cache_loc = None diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 22c18f2e4..f8ec2a778 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -218,6 +218,7 @@ class PrefillAdder: if not insert_sort: self.req_states.append((tokens_left, tokens_occupied)) else: + i = 0 for i in range(len(self.req_states)): if tokens_left <= self.req_states[i][0]: break diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f4dcbb650..7f764260c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -144,7 +144,7 @@ class Scheduler: ) self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group - # Get token and memory info from the tp worker + # Get token and memory info from the model worker ( self.max_total_num_tokens, self.max_prefill_tokens, @@ -976,7 +976,7 @@ def run_scheduler_process( port_args: PortArgs, gpu_id: int, tp_rank: int, - pipe_writer: multiprocessing.connection.Connection, + pipe_writer, ): configure_logger(server_args, prefix=f" TP{tp_rank}") suppress_other_loggers() diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index a4c90be1e..152868e0d 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -31,10 +31,13 @@ class ReqToTokenPool: self.size = size self.max_context_len = max_context_len self.device = device - self.free_slots = list(range(size)) self.req_to_token = torch.empty( (size, max_context_len), dtype=torch.int32, device=device ) + self.free_slots = list(range(size)) + + def available_size(self): + return len(self.free_slots) def alloc(self, need_size: int) -> List[int]: if need_size > len(self.free_slots): diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 247c15d8e..de781acb3 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -40,7 +40,7 @@ class SamplingBatchInfo: @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): reqs = batch.reqs - with torch.device("cuda"): + with batch.input_ids.device: temperatures = torch.tensor( [r.sampling_params.temperature for r in reqs], dtype=torch.float, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index dedcb9dfc..bc1366b10 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -594,7 +594,9 @@ def set_weight_attrs( def broadcast_pyobj( - data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup + data: List[Any], + rank: int, + dist_group: Optional[torch.distributed.ProcessGroup] = None, ): """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py index 646754478..55df1951f 100644 --- a/test/srt/test_triton_attn_backend.py +++ b/test/srt/test_triton_attn_backend.py @@ -26,7 +26,7 @@ class TestTritonAttnBackend(unittest.TestCase): ) if is_in_ci(): - assert output_throughput > 154, f"{output_throughput=}" + assert output_throughput > 153, f"{output_throughput=}" def test_mmlu(self): model = DEFAULT_MODEL_NAME_FOR_TEST