Fix padding in the cuda graph (#1469)

Author: Lianmin Zheng
Date: 2024-09-19 01:52:15 -07:00
Committed by: GitHub
Parent: 446ea33277
Commit: 2d346a57c2
2 changed files with 17 additions and 16 deletions


@@ -108,6 +108,10 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.capture_bs = [
+            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+        ]
         self.compile_bs = (
             [
                 bs
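Note on the first hunk: the new filter caps the captured batch sizes at the request pool size, so graphs are never captured for batches the pool cannot hold. A minimal standalone sketch of the effect (the pool size of 48 is a made-up value for illustration; the real one comes from model_runner.req_to_token_pool.size):

    # Hypothetical pool size, for illustration only.
    pool_size = 48

    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, ..., 160
    capture_bs = [bs for bs in capture_bs if bs <= pool_size]
    print(capture_bs)  # [1, 2, 4, 8, 16, 24, 32, 40, 48]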
@@ -118,21 +122,8 @@ class CudaGraphRunner:
             else []
         )

-        # Common inputs
-        self.max_bs = max(self.capture_bs)
-        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.out_cache_loc = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-
         # Attention backend
+        self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
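This hunk is the heart of the padding fix: the dummy seq_lens buffer used during capture was previously filled with ones, so padded slots looked like real length-1 requests to the attention backend. The buffer allocation is therefore moved to after init_cuda_graph_state, once the backend's designated fill value is known (only self.max_bs stays here, since init_cuda_graph_state needs it). A minimal sketch of the difference, assuming a fill value of 0 (the real value comes from get_cuda_graph_seq_len_fill_value()):

    import torch

    max_bs = 8
    seq_len_fill_value = 0  # assumed example; backend-specific in practice

    # Before: padded slots carried a sequence length of 1 and could be
    # mistaken for real requests during graph replay.
    seq_lens_old = torch.ones((max_bs,), dtype=torch.int32)

    # After: padded slots carry the backend's designated fill value.
    seq_lens_new = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)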
@@ -141,6 +132,16 @@ class CudaGraphRunner:
         if self.use_torch_compile:
             set_torch_compile_config()

+        # Common inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+
         # Capture
         try:
             self.capture()
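A side note on the rewritten allocation: it uses torch.device as a context manager (available since PyTorch 2.0), which makes every tensor factory call inside the block default to that device and removes the repeated device="cuda" arguments. A small runnable sketch:

    import torch

    # torch.device as a context manager (PyTorch >= 2.0): factory calls
    # inside the block allocate on this device by default.
    with torch.device("cuda" if torch.cuda.is_available() else "cpu"):
        input_ids = torch.zeros((8,), dtype=torch.int32)

    print(input_ids.device)  # cuda:0 when a GPU is available, else cpu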