[Feat] Support async_scheduler and disable_padded_drafter_batch in eagle (#4893)

### What this PR does / why we need it?
We refactored the eagle_proposer.py to adapt the framework of eagle.py
in vllm-v0.12.0, to support the logit of padded drafter batch and
async-scheduler.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: drslark <slarksblood@qq.com>
This commit is contained in:
anon189Ty
2025-12-16 22:06:40 +08:00
committed by GitHub
parent cee521bad5
commit 5b1da4e914
6 changed files with 577 additions and 403 deletions

View File

@@ -730,6 +730,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
if get_ascend_device_type() == AscendDeviceType._910_95:
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
# If it's necessary, the slots should be sliced.
torch_npu.npu_scatter_pa_kv_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens].contiguous(),
@@ -742,7 +745,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
value=value[:attn_metadata.num_actual_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
slot_indices=slots[:attn_metadata.num_actual_tokens])
return key, value
def forward_impl(

View File

@@ -119,6 +119,35 @@ class AscendCommonAttentionMetadata:
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
# TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int,
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
# This only use to eagle now. It will be use to enforce_eager in future.
return AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.
num_computed_tokens_cpu[:num_actual_reqs],
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
decode_token_per_req=self.decode_token_per_req,
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
slot_mapping=self.slot_mapping[:num_actual_tokens],
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
positions=self.positions[:num_actual_tokens],
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=self.is_only_prefill,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=num_actual_tokens,
prefill_context_parallel_metadata=self.
prefill_context_parallel_metadata,
)
def filter_chunked_req_indices(
seq_len: torch.Tensor,