[EAGLE] Refactor code for page size > 1 & more simplifications (#7213)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:04:29 -07:00
committed by GitHub
parent 21615cc3fe
commit b1286a116a
8 changed files with 647 additions and 156 deletions

View File

@@ -1049,14 +1049,13 @@ class FlashInferMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
assert forward_batch.spec_info is not None

View File

@@ -789,6 +789,7 @@ class FlashInferMLAMultiStepDraftBackend:
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.page_size = model_runner.server_args.page_size
def common_template(
self,
@@ -809,14 +810,13 @@ class FlashInferMLAMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
assert forward_batch.spec_info is not None

View File

@@ -2,9 +2,6 @@ from __future__ import annotations
"""
Support attention backend for FlashMLA.
#TODO
Enable speculative sampling in FlashMLA
"""
from dataclasses import dataclass

View File

@@ -784,14 +784,13 @@ class TritonMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
for i in range(self.speculative_num_steps):