[Minor] Add more type annotations (#1237)
@@ -17,6 +17,7 @@ limitations under the License.
 
 import bisect
 from contextlib import contextmanager
+from typing import Callable, List
 
 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -53,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None
 
     try:
-        if use_compile:
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -67,7 +68,7 @@ def patch_model(
         else:
             yield model.forward
     finally:
-        if use_compile:
+        if enable_compile:
            _to_torch(model, reverse=True)
            monkey_patch_vllm_all_gather(reverse=True)
            tp_group.ca_comm = backup_ca_comm
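For context, the pattern above is a try/finally context manager that yields a forward callable and always undoes the patching, even if CUDA graph capture fails partway through. Below is a minimal, self-contained sketch of that pattern; the helper name _patched_forward and the toy model are illustrative, not the SGLang implementation.

from contextlib import contextmanager

import torch


@contextmanager
def _patched_forward(model: torch.nn.Module, enable_compile: bool):
    # Toy version of the patch/restore pattern: yield a forward callable,
    # then restore the original state in `finally` so a failed capture
    # cannot leave the model half-patched.
    original_forward = model.forward
    try:
        if enable_compile:
            yield torch.compile(model.forward)
        else:
            yield model.forward
    finally:
        model.forward = original_forward


if __name__ == "__main__":
    net = torch.nn.Linear(4, 4)
    with _patched_forward(net, enable_compile=False) as forward:
        print(forward(torch.randn(2, 4)).shape)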
@@ -88,7 +89,7 @@ def set_torch_compile_config():
 class CudaGraphRunner:
     def __init__(
         self,
-        model_runner,
+        model_runner: "ModelRunner",
         max_batch_size_to_capture: int,
         use_torch_compile: bool,
         disable_padding: bool,
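The "ModelRunner" annotation is written as a string (a forward reference) rather than via a direct import, which is the usual way to annotate a type whose runtime import would be circular or unnecessary. A common way to make such an annotation resolvable for type checkers is a TYPE_CHECKING guard, sketched below with an assumed import path that may not match the actual module layout.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only during static type checking, never at runtime.
    # The module path here is an assumption for illustration.
    from sglang.srt.model_executor.model_runner import ModelRunner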
@@ -154,13 +155,13 @@ class CudaGraphRunner:
         if use_torch_compile:
             set_torch_compile_config()
 
-    def can_run(self, batch_size):
+    def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
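When disable_padding is false, can_run only checks that the batch fits under the largest captured size, because a smaller batch can be padded up to the nearest captured batch size. A hedged sketch of that lookup, using the bisect module imported at the top of the file; the helper name and exact call site are assumptions for illustration.

import bisect
from typing import List


def padded_batch_size(batch_size: int, batch_size_list: List[int]) -> int:
    # batch_size_list is assumed sorted ascending (e.g. [1, 2, 4, 8, 16]);
    # bisect_left finds the first captured size >= batch_size.
    index = bisect.bisect_left(batch_size_list, batch_size)
    return batch_size_list[index]


assert padded_batch_size(3, [1, 2, 4, 8, 16]) == 4
assert padded_batch_size(8, [1, 2, 4, 8, 16]) == 8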
@@ -181,7 +182,7 @@ class CudaGraphRunner:
                 self.output_buffers[bs] = output_buffers
                 self.flashinfer_handlers[bs] = flashinfer_handler
 
-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
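capture_one_batch_size builds on PyTorch's CUDA graph API: run the forward once under capture with fixed-shape buffers, then replay the recorded graph after copying fresh inputs into those same buffers. Below is a standalone sketch of that capture/replay cycle; the toy model and shapes are illustrative and a CUDA device is required.

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.zeros(8, 16, device="cuda")

    # Warm up on a side stream before capture, as recommended by PyTorch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Record one fixed-shape forward pass into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Replay: copy new data into the captured input buffer, then launch.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    print(static_output.sum().item())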