Fix CUDA illegal memory access issues in speculative decoding (#10892)
This commit is contained in:
@@ -302,6 +302,7 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
 
         num_tokens = bs * self.num_tokens_per_bs
 
@@ -332,6 +332,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         if bs * self.num_tokens_per_bs != num_tokens:
             self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
             self.accept_length.fill_(1)
             self.extend_seq_lens.fill_(1)
 
Reference in New Issue
Block a user