[Fix] Fix NaN issues by fixing the cuda graph padding values for flashinfer (#1779)

This commit is contained in:
Lianmin Zheng
2024-10-24 04:16:59 -07:00
committed by GitHub
parent 72e7b57a75
commit 0089c4bc96

View File

@@ -290,7 +290,7 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
-self.seq_lens.fill_(self.seq_len_fill_value)
+self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
# Common inputs