Re-introduce get_cuda_graph_seq_len_fill_value (#1783)
This commit is contained in:
@@ -134,7 +134,11 @@ class CudaGraphRunner:
|
||||
self.max_bs = max(self.capture_bs)
|
||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
||||
|
||||
self.seq_len_fill_value = 1
|
||||
self.seq_len_fill_value = (
|
||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
)
|
||||
|
||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||
self.encoder_len_fill_value = 0
|
||||
|
||||
if self.use_torch_compile:
|
||||
@@ -287,7 +291,7 @@ class CudaGraphRunner:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens.fill_(1)
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
|
||||
Reference in New Issue
Block a user