[Bugfix] eagle and eagle3 spec decode failures and enable e2e test (#2979)

### What this PR does / why we need it?
- Fix the bug https://github.com/vllm-project/vllm-ascend/issues/2978
- Enable the e2e test.
- Adapt to scenarios where Speculative tokens are greater than 2,
- Fix the bug that causes Eagle3 inference failures under high
concurrency and improve the acceptance rate of draft models (via
https://github.com/vllm-project/vllm-ascend/pull/2794).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with newly added and existing tests.

Co-authored-by: hukongyi
[hukongyi@cmbchina.com](mailto:hukongyi@cmbchina.com)
Co-authored-by: guanyuzhu
[zhuguanyu@huawei.com](mailto:zhuguanyu@huawei.com)
Co-authored-by: liumail680
[liumail680@163.com](mailto:liumail680@163.com)


- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

---------

Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
Icey
2025-09-25 14:39:12 +08:00
committed by GitHub
parent ac1c2cd9ac
commit 2a9d02e080
2 changed files with 40 additions and 25 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
import random
from typing import Any
@@ -9,6 +10,8 @@ from vllm import LLM, SamplingParams
from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@pytest.fixture
def test_prompts():
@@ -99,7 +102,6 @@ def test_ngram_correctness(
assert matches > int(0.7 * len(ref_outputs))
@pytest.mark.skipif(True, reason="oom in CI, fix me")
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
test_prompts: list[list[dict[str, Any]]],
@@ -111,8 +113,6 @@ def test_eagle_correctness(
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
if not use_eagle3:
pytest.skip("Not current support for the test.")
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -121,7 +121,6 @@ def test_eagle_correctness(
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(
model_name,
trust_remote_code=True,
enable_chunked_prefill=True,
max_num_seqs=1,
max_num_batched_tokens=2048,

View File

@@ -22,6 +22,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import vllm_version_is
PADDING_SLOT_ID = -1
@@ -139,8 +140,6 @@ class EagleProposer(Proposer):
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
if self.name == SpecDcodeType.EAGLE:
raise NotImplementedError("Eagle Is Not Supported Yet.")
attn_metadata = self._get_eagle_atten_dict(scheduler_output)
next_token_ids: list[int] = []
@@ -355,8 +354,12 @@ class EagleProposer(Proposer):
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
attn_metadata_i = self.runner.attn_metadata_builder.build(
common_attn_metadata, self.runner.get_model())
if vllm_version_is("0.10.2"):
builder = self.runner.attn_groups[0][0].metadata_builder
else:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_i = builder.build(0, common_attn_metadata,
self.runner.get_model())
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@@ -418,16 +421,19 @@ class EagleProposer(Proposer):
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids[0]
self.input_ids[last_token_indices] = next_token_ids
seq_lens = (target_positions[last_token_indices] + 1).int()
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, target_positions, self.vllm_config.model_config.dtype,
self.device)
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:batch_size + 1],
query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size +
1],
seq_lens_cpu=self.runner.seq_lens_cpu,
query_start_loc=cu_num_tokens.to(device),
query_start_loc_cpu=cu_num_tokens,
seq_lens_cpu=seq_lens.cpu(),
max_query_len=max_query_len,
num_reqs=batch_size,
num_actual_tokens=num_tokens,
@@ -436,15 +442,19 @@ class EagleProposer(Proposer):
get_device_tensor(),
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
attn_mask=attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
attn_metadata = self.runner.attn_metadata_builder.build(
common_attn_metadata, self.runner.model)
if vllm_version_is("0.10.2"):
builder = self.runner.attn_groups[0][0].metadata_builder
else:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata,
self.runner.get_model())
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -471,7 +481,10 @@ class EagleProposer(Proposer):
hidden_states=self.hidden_states[:num_input_tokens],
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if vllm_version_is("0.10.2"):
logits = self.model.compute_logits(sample_hidden_states, None)
else:
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
@@ -501,9 +514,8 @@ class EagleProposer(Proposer):
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
if self.vllm_config.speculative_config.num_speculative_tokens > 2:
raise ValueError("Speculative tokens > 2 are not supported yet.")
query_lens.fill_(1)
attn_metadata.query_lens = query_lens
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
for now_speculative in range(
@@ -558,9 +570,8 @@ class EagleProposer(Proposer):
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
positions = positions_cpu.to(device)
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
attn_metadata.seq_lens, positions,
attn_metadata.seq_lens, positions_cpu,
self.vllm_config.model_config.dtype, self.device)
attn_metadata.attn_mask = attn_mask
@@ -577,8 +588,12 @@ class EagleProposer(Proposer):
hidden_states=self.hidden_states[:input_batch_size],
)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
if vllm_version_is("0.10.2"):
logits = self.model.compute_logits(
last_hidden_states[:batch_size], None)
else:
logits = self.model.compute_logits(
last_hidden_states[:batch_size])
# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
@@ -652,7 +667,8 @@ class EagleProposer(Proposer):
dtype=torch.int32,
device=out_tensor.device) + offset_tensor
values_to_store = torch.tensor(
index_start, dtype=torch.int32,
index_start + global_start_offset,
dtype=torch.int32,
device=out_tensor.device) + offset_tensor
mask = (target_indices >= start_pos) & \
(target_indices < end_pos) & \