Fix cuda graph with flashinfer (#675)
This commit is contained in:
@@ -150,8 +150,8 @@ class CudaGraphRunner:
|
||||
index = bisect.bisect_left(self.batch_size_list, raw_bs)
|
||||
bs = self.batch_size_list[index]
|
||||
if bs != raw_bs:
|
||||
- self.seq_lens.zero_()
|
||||
- self.position_ids_offsets.fill_(1)
|
||||
+ self.seq_lens.fill_(1)
|
||||
+ self.position_ids_offsets.zero_()
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
|
||||
Reference in New Issue
Block a user