[Fix] Fix flashinfer cpu <-> gpu synchronization (#8340)
This commit is contained in:
@@ -729,10 +729,12 @@ class CudaGraphRunner:
|
||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||
|
||||
seq_lens_cpu = None
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||
|
||||
if pp_proxy_tensors:
|
||||
for key in self.pp_proxy_tensors.keys():
|
||||
@@ -766,7 +768,7 @@ class CudaGraphRunner:
|
||||
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
|
||||
self.capture_forward_mode,
|
||||
forward_batch.spec_info,
|
||||
seq_lens_cpu=self.seq_lens_cpu[:bs],
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
)
|
||||
|
||||
# Store fields
|
||||
|
||||
Reference in New Issue
Block a user