Fix cuda graph with flashinfer (#675)

This commit is contained in:
Lianmin Zheng
2024-07-20 02:43:55 -07:00
committed by GitHub
parent 06487f126e
commit 490a1f39dd
2 changed files with 3 additions and 3 deletions

View File

@@ -150,8 +150,8 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index]
if bs != raw_bs:
self.seq_lens.zero_()
self.position_ids_offsets.fill_(1)
self.seq_lens.fill_(1)
self.position_ids_offsets.zero_()
self.out_cache_loc.zero_()
# Common inputs