Fix a bug in cuda graph runner (#1094)
This commit is contained in:
@@ -98,8 +98,8 @@ class CudaGraphRunner:
|
||||
self.req_pool_indices = torch.zeros(
|
||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
|
||||
self.position_ids_offsets = torch.zeros(
|
||||
self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
||||
self.position_ids_offsets = torch.ones(
|
||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.out_cache_loc = torch.zeros(
|
||||
@@ -201,7 +201,7 @@ class CudaGraphRunner:
|
||||
out_cache_loc=out_cache_loc,
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=0,
|
||||
positions=(seq_lens - 1).to(torch.int64),
|
||||
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
|
||||
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
||||
)
|
||||
|
||||
@@ -225,8 +225,8 @@ class CudaGraphRunner:
|
||||
index = bisect.bisect_left(self.batch_size_list, raw_bs)
|
||||
bs = self.batch_size_list[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(1)
|
||||
self.position_ids_offsets.zero_()
|
||||
self.seq_lens.zero_()
|
||||
self.position_ids_offsets.fill_(1)
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
|
||||
Reference in New Issue
Block a user