From 6cc38b2bf31c141e3ae06ca8c1150e35dbeb5578 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 28 Aug 2024 00:54:26 -0700 Subject: [PATCH] [Minor] Add more type annotations (#1237) --- .../srt/model_executor/cuda_graph_runner.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 96c15849e..40c87af88 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -17,6 +17,7 @@ limitations under the License. import bisect from contextlib import contextmanager +from typing import Callable, List import torch from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -53,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): @contextmanager def patch_model( - model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator" + model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" ): backup_ca_comm = None try: - if use_compile: + if enable_compile: _to_torch(model) monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm @@ -67,7 +68,7 @@ def patch_model( else: yield model.forward finally: - if use_compile: + if enable_compile: _to_torch(model, reverse=True) monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm @@ -88,7 +89,7 @@ def set_torch_compile_config(): class CudaGraphRunner: def __init__( self, - model_runner, + model_runner: "ModelRunner", max_batch_size_to_capture: int, use_torch_compile: bool, disable_padding: bool, @@ -154,13 +155,13 @@ class CudaGraphRunner: if use_torch_compile: set_torch_compile_config() - def can_run(self, batch_size): + def can_run(self, batch_size: int): if self.disable_padding: return batch_size in self.graphs else: return batch_size <= self.max_bs - def capture(self, batch_size_list): + def capture(self, batch_size_list: List[int]): self.batch_size_list = batch_size_list with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream @@ -181,7 +182,7 @@ class CudaGraphRunner: self.output_buffers[bs] = output_buffers self.flashinfer_handlers[bs] = flashinfer_handler - def capture_one_batch_size(self, bs, forward): + def capture_one_batch_size(self, bs: int, forward: Callable): graph = torch.cuda.CUDAGraph() stream = self.stream