[Fix] Fix seq_lens_sum for cuda graph runner in padded cases (#1789)
This commit is contained in:
@@ -307,7 +307,7 @@ class CudaGraphRunner:
|
||||
bs,
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
forward_batch.seq_lens_sum + (bs - raw_bs),
|
||||
self.encoder_lens,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user