[Minor] Add more type annotations (#1237)

Lianmin Zheng
2024-08-28 00:54:26 -07:00
committed by GitHub
parent 1ece2cda3d
commit 6cc38b2bf3


@@ -17,6 +17,7 @@ limitations under the License.
 import bisect
 from contextlib import contextmanager
+from typing import Callable, List

 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
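
For context, the annotation style this commit applies is standard typing usage: Callable for function-valued parameters and List[int] for homogeneous lists. A minimal sketch with illustrative names (not taken from the repository):

    from typing import Callable, List

    def run_for_each_size(fn: Callable, batch_size_list: List[int]) -> None:
        # Call the given function once per captured batch size.
        for bs in batch_size_list:
            fn(bs)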
@@ -53,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None
     try:
-        if use_compile:
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -67,7 +68,7 @@ def patch_model(
         else:
             yield model.forward
     finally:
-        if use_compile:
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
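
patch_model above follows the usual generator-based context manager shape: apply a temporary patch, yield control to the caller, and undo the patch in finally so the original state is restored even if the body raises. A minimal sketch of the same shape, assuming only standard PyTorch (not the sglang implementation):

    from contextlib import contextmanager

    import torch

    @contextmanager
    def temporarily_eval(model: torch.nn.Module):
        was_training = model.training
        try:
            model.eval()               # apply the temporary patch
            yield model                # run the caller's code with the patch active
        finally:
            model.train(was_training)  # always restore the original state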
@@ -88,7 +89,7 @@ def set_torch_compile_config():
 class CudaGraphRunner:
     def __init__(
         self,
-        model_runner,
+        model_runner: "ModelRunner",
         max_batch_size_to_capture: int,
         use_torch_compile: bool,
         disable_padding: bool,
@@ -154,13 +155,13 @@ class CudaGraphRunner:
         if use_torch_compile:
             set_torch_compile_config()

-    def can_run(self, batch_size):
+    def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs

-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
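
can_run only reports whether a batch can be served; when padding is enabled, the runner still has to map an incoming size onto one of the captured sizes. A hypothetical sketch of that lookup, using the bisect import from the top of the file (padded_batch_size is an illustrative name, not part of this diff):

    import bisect
    from typing import List

    def padded_batch_size(batch_size_list: List[int], batch_size: int) -> int:
        # Pick the smallest captured batch size that can hold the request.
        sizes = sorted(batch_size_list)
        idx = bisect.bisect_left(sizes, batch_size)
        if idx == len(sizes):
            raise ValueError(f"batch size {batch_size} exceeds the largest captured size {sizes[-1]}")
        return sizes[idx]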
@@ -181,7 +182,7 @@ class CudaGraphRunner:
                self.output_buffers[bs] = output_buffers
                self.flashinfer_handlers[bs] = flashinfer_handler

-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
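
capture_one_batch_size records a forward pass at a fixed batch size into a torch.cuda.CUDAGraph. A self-contained sketch of the underlying PyTorch pattern, warm-up on a side stream, capture, then replay through static buffers (assumed example, not the sglang code):

    import torch

    def capture_and_replay(fn, static_input: torch.Tensor):
        g = torch.cuda.CUDAGraph()

        # Warm up on a side stream so lazy kernel initialization is not captured.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                static_output = fn(static_input)
        torch.cuda.current_stream().wait_stream(s)

        # Record the kernels launched by fn into the graph.
        with torch.cuda.graph(g):
            static_output = fn(static_input)

        def run(new_input: torch.Tensor):
            # Copy new data into the captured input buffer, then replay the graph.
            static_input.copy_(new_input)
            g.replay()
            return static_output

        return run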