[Eagle] reduce one draft forward (#3468)
This commit is contained in:
@@ -947,7 +947,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
triton.next_power_of_2(bs),
|
triton.next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps - 1):
|
||||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||||
: seq_lens_sum * self.topk + bs * (i + 1)
|
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||||
|
|||||||
@@ -234,6 +234,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
token_list.append(tree_info[1])
|
token_list.append(tree_info[1])
|
||||||
parents_list.append(tree_info[2])
|
parents_list.append(tree_info[2])
|
||||||
|
|
||||||
|
# we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
|
||||||
|
if i == self.speculative_num_steps - 1:
|
||||||
|
break
|
||||||
|
|
||||||
# Set inputs
|
# Set inputs
|
||||||
forward_batch.input_ids = input_ids
|
forward_batch.input_ids = input_ids
|
||||||
forward_batch.out_cache_loc = out_cache_loc[
|
forward_batch.out_cache_loc = out_cache_loc[
|
||||||
|
|||||||
Reference in New Issue
Block a user