[Fix] Fix NaN issues by fixing the cuda graph padding values for flashinfer (#1779)
This commit is contained in:
@@ -290,7 +290,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(self.seq_len_fill_value)
+            self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
Reference in New Issue
Block a user