Re-introduce get_cuda_graph_seq_len_fill_value (#1783)

This commit is contained in:
Lianmin Zheng
2024-10-24 13:30:11 -07:00
committed by GitHub
parent 605972195b
commit 384d85ba35
5 changed files with 19 additions and 2 deletions

View File

@@ -134,7 +134,11 @@ class CudaGraphRunner:
self.max_bs = max(self.capture_bs)
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
self.seq_len_fill_value = 1
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
if self.use_torch_compile:
@@ -287,7 +291,7 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(self.seq_len_fill_value)
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
# Common inputs