Use seq_len_fill_value in the cuda graph runners (#7233)
This commit is contained in:
@@ -612,7 +612,7 @@ class CudaGraphRunner:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
- self.seq_lens.fill_(1)
|
||||
+ self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
@@ -624,7 +624,7 @@ class CudaGraphRunner:
|
||||
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
- self.seq_lens_cpu.fill_(1)
|
||||
+ self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
|
||||
if pp_proxy_tensors:
|
||||
@@ -652,7 +652,7 @@ class CudaGraphRunner:
|
||||
bs,
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
- forward_batch.seq_lens_sum + (bs - raw_bs),
|
||||
+ forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
|
||||
self.encoder_lens,
|
||||
forward_batch.forward_mode,
|
||||
forward_batch.spec_info,
|
||||
|
||||
Reference in New Issue
Block a user