[Fix] Fix flashinfer cpu <-> gpu synchronization (#8340)

This commit is contained in:
DarkSharpness
2025-08-09 20:11:40 -07:00
committed by GitHub
parent 19bc77f05c
commit 7ba5ad5766
3 changed files with 65 additions and 24 deletions

View File

@@ -729,10 +729,12 @@ class CudaGraphRunner:
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
seq_lens_cpu = None
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
seq_lens_cpu = self.seq_lens_cpu[:bs]
if pp_proxy_tensors:
for key in self.pp_proxy_tensors.keys():
@@ -766,7 +768,7 @@ class CudaGraphRunner:
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
self.capture_forward_mode,
forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu[:bs],
seq_lens_cpu=seq_lens_cpu,
)
# Store fields