[Minor] Add more type annotations (#1237)
This commit is contained in:
@@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Callable, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||||
@@ -53,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_model(
|
def patch_model(
|
||||||
model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
|
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
|
||||||
):
|
):
|
||||||
backup_ca_comm = None
|
backup_ca_comm = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if use_compile:
|
if enable_compile:
|
||||||
_to_torch(model)
|
_to_torch(model)
|
||||||
monkey_patch_vllm_all_gather()
|
monkey_patch_vllm_all_gather()
|
||||||
backup_ca_comm = tp_group.ca_comm
|
backup_ca_comm = tp_group.ca_comm
|
||||||
@@ -67,7 +68,7 @@ def patch_model(
|
|||||||
else:
|
else:
|
||||||
yield model.forward
|
yield model.forward
|
||||||
finally:
|
finally:
|
||||||
if use_compile:
|
if enable_compile:
|
||||||
_to_torch(model, reverse=True)
|
_to_torch(model, reverse=True)
|
||||||
monkey_patch_vllm_all_gather(reverse=True)
|
monkey_patch_vllm_all_gather(reverse=True)
|
||||||
tp_group.ca_comm = backup_ca_comm
|
tp_group.ca_comm = backup_ca_comm
|
||||||
@@ -88,7 +89,7 @@ def set_torch_compile_config():
|
|||||||
class CudaGraphRunner:
|
class CudaGraphRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_runner,
|
model_runner: "ModelRunner",
|
||||||
max_batch_size_to_capture: int,
|
max_batch_size_to_capture: int,
|
||||||
use_torch_compile: bool,
|
use_torch_compile: bool,
|
||||||
disable_padding: bool,
|
disable_padding: bool,
|
||||||
@@ -154,13 +155,13 @@ class CudaGraphRunner:
|
|||||||
if use_torch_compile:
|
if use_torch_compile:
|
||||||
set_torch_compile_config()
|
set_torch_compile_config()
|
||||||
|
|
||||||
def can_run(self, batch_size):
|
def can_run(self, batch_size: int):
|
||||||
if self.disable_padding:
|
if self.disable_padding:
|
||||||
return batch_size in self.graphs
|
return batch_size in self.graphs
|
||||||
else:
|
else:
|
||||||
return batch_size <= self.max_bs
|
return batch_size <= self.max_bs
|
||||||
|
|
||||||
def capture(self, batch_size_list):
|
def capture(self, batch_size_list: List[int]):
|
||||||
self.batch_size_list = batch_size_list
|
self.batch_size_list = batch_size_list
|
||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
self.stream = graph_capture_context.stream
|
self.stream = graph_capture_context.stream
|
||||||
@@ -181,7 +182,7 @@ class CudaGraphRunner:
|
|||||||
self.output_buffers[bs] = output_buffers
|
self.output_buffers[bs] = output_buffers
|
||||||
self.flashinfer_handlers[bs] = flashinfer_handler
|
self.flashinfer_handlers[bs] = flashinfer_handler
|
||||||
|
|
||||||
def capture_one_batch_size(self, bs, forward):
|
def capture_one_batch_size(self, bs: int, forward: Callable):
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
stream = self.stream
|
stream = self.stream
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user