[Fix] Fix seq_lens_sum for cuda graph runner in padded cases (#1789)

Lianmin Zheng
2024-10-24 21:26:05 -07:00
committed by GitHub
parent 1701b0db31
commit 86a2c473b7

@@ -307,7 +307,7 @@ class CudaGraphRunner:
bs,
self.req_pool_indices,
self.seq_lens,
- forward_batch.seq_lens_sum,
+ forward_batch.seq_lens_sum + (bs - raw_bs),
self.encoder_lens,
)
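
The arithmetic behind the `+ (bs - raw_bs)` term can be sketched as follows. This is a minimal illustration, not the runner's actual code. It assumes that when a batch of `raw_bs` requests is replayed on a graph captured for a larger batch size `bs`, each padded slot in `seq_lens` holds a fill value of 1. The diff does not state that fill value, and the helper `padded_seq_lens_sum` and its `fill_value` parameter are hypothetical names introduced here.

```python
def padded_seq_lens_sum(seq_lens, bs, fill_value=1):
    """Pad seq_lens up to batch size bs and return (padded list, its sum).

    fill_value=1 is an assumption about how padded slots are filled;
    it is not confirmed by the diff itself.
    """
    raw_bs = len(seq_lens)
    if raw_bs > bs:
        raise ValueError("captured batch size must be >= raw batch size")
    padded = list(seq_lens) + [fill_value] * (bs - raw_bs)
    return padded, sum(padded)


raw_seq_lens = [5, 9, 3]            # three real requests
raw_bs = len(raw_seq_lens)
seq_lens_sum = sum(raw_seq_lens)    # sum over real requests only
bs = 4                              # padded batch size used for replay

padded, padded_sum = padded_seq_lens_sum(raw_seq_lens, bs)

# With a fill value of 1, each padded slot adds exactly 1 to the sum,
# so the padded sum equals the raw sum plus the number of padded slots.
assert padded_sum == seq_lens_sum + (bs - raw_bs)
```

Under that assumption, passing the unpadded `forward_batch.seq_lens_sum` would undercount the padded `self.seq_lens` by `bs - raw_bs`, which is the mismatch the one-line change corrects.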