Fix CUDA illegal memory access issues in speculative decoding (#10892)
This commit is contained in:
@@ -302,6 +302,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.out_cache_loc.zero_()
|
||||
self.positions.zero_()
|
||||
|
||||
num_tokens = bs * self.num_tokens_per_bs
|
||||
|
||||
|
||||
@@ -332,6 +332,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
if bs * self.num_tokens_per_bs != num_tokens:
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.out_cache_loc.zero_()
|
||||
self.positions.zero_()
|
||||
self.accept_length.fill_(1)
|
||||
self.extend_seq_lens.fill_(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user