From 2a9d02e08039749cf811c5fb190d4c8a950d792d Mon Sep 17 00:00:00 2001
From: Icey <1790571317@qq.com>
Date: Thu, 25 Sep 2025 14:39:12 +0800
Subject: [PATCH] [Bugfix] Fix eagle and eagle3 spec decode failures and
 enable e2e test (#2979)

### What this PR does / why we need it?
- Fix the bug https://github.com/vllm-project/vllm-ascend/issues/2978
- Enable the eagle/eagle3 e2e test.
- Support scenarios where the number of speculative tokens is greater than 2.
- Fix the bug that caused Eagle3 inference failures under high concurrency
  and improve the draft model's acceptance rate, based on
  https://github.com/vllm-project/vllm-ascend/pull/2794

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with newly added and existing tests.
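The `eagle_proposer.py` changes below gate two vLLM API differences on the
installed version: how the attention metadata builder is obtained from
`attn_groups`, and the `compute_logits` signature. A minimal sketch of that
pattern for reviewers (the wrapper functions are illustrative only, not part
of this patch; `vllm_version_is` is the existing helper imported from
`vllm_ascend.utils`):

```python
# Illustrative sketch, not part of the patch: the version gate applied at
# every touched call site below.
from vllm_ascend.utils import vllm_version_is


def get_metadata_builder(runner):
    # Hypothetical helper. vLLM v0.10.2 exposes the builder as an attribute
    # on the attention group; newer vLLM main wraps it behind an accessor.
    group = runner.attn_groups[0][0]
    if vllm_version_is("0.10.2"):
        return group.metadata_builder
    return group.get_metadata_builder()


def compute_logits(model, hidden_states):
    # Hypothetical wrapper. v0.10.2 still accepted a second
    # sampling-metadata argument (passed as None); newer vLLM removed it.
    if vllm_version_is("0.10.2"):
        return model.compute_logits(hidden_states, None)
    return model.compute_logits(hidden_states)
```

Keeping both branches at each call site (rather than a shared helper) is what
the patch actually does, so each site stays self-contained.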
Co-authored-by: hukongyi <hukongyi@cmbchina.com>
Co-authored-by: guanyuzhu <zhuguanyu@huawei.com>
Co-authored-by: liumail680 <liumail680@163.com>

- vLLM version: v0.10.2
- vLLM main: https://github.com/vllm-project/vllm/commit/f225ea7dd98e9f29752e5c032cd4a8ee1d712f16

---------

Signed-off-by: Icey <1790571317@qq.com>
---
 .../spec_decode_v1/test_v1_spec_decode.py |  7 +--
 vllm_ascend/spec_decode/eagle_proposer.py | 58 ++++++++++++-------
 2 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
index 9a1bfb8..97ecbf1 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+import os
 import random
 from typing import Any
 
@@ -9,6 +10,8 @@
 from vllm import LLM, SamplingParams
 
 from tests.e2e.conftest import VllmRunner
 
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
 @pytest.fixture
 def test_prompts():
@@ -99,7 +102,6 @@ def test_ngram_correctness(
     assert matches > int(0.7 * len(ref_outputs))
 
 
-@pytest.mark.skipif(True, reason="oom in CI, fix me")
 @pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
 def test_eagle_correctness(
     test_prompts: list[list[dict[str, Any]]],
@@ -111,8 +113,6 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
-    if not use_eagle3:
-        pytest.skip("Not current support for the test.")
 
     ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -121,7 +121,6 @@ def test_eagle_correctness(
     spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
     with VllmRunner(
             model_name,
-            trust_remote_code=True,
             enable_chunked_prefill=True,
             max_num_seqs=1,
             max_num_batched_tokens=2048,
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index f993e3a..e342227 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -22,6 +22,7 @@
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm_ascend.utils import vllm_version_is
 
 PADDING_SLOT_ID = -1
@@ -139,8 +140,6 @@ class EagleProposer(Proposer):
                            hidden_states: torch.Tensor = None,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
-        if self.name == SpecDcodeType.EAGLE:
-            raise NotImplementedError("Eagle Is Not Supported Yet.")
 
         attn_metadata = self._get_eagle_atten_dict(scheduler_output)
         next_token_ids: list[int] = []
@@ -355,8 +354,12 @@ class EagleProposer(Proposer):
             decode_token_per_req=self.runner.decode_token_per_req,
             num_computed_tokens_cpu=None,
             seq_lens=None)
-        attn_metadata_i = self.runner.attn_metadata_builder.build(
-            common_attn_metadata, self.runner.get_model())
+        if vllm_version_is("0.10.2"):
+            builder = self.runner.attn_groups[0][0].metadata_builder
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        attn_metadata_i = builder.build(0, common_attn_metadata,
+                                        self.runner.get_model())
 
         for layer_name in kv_cache_group_spec.layer_names:
             attn_metadata[layer_name] = attn_metadata_i
@@ -418,16 +421,19 @@ class EagleProposer(Proposer):
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        self.input_ids[last_token_indices] = next_token_ids[0]
+        self.input_ids[last_token_indices] = next_token_ids
+        seq_lens = (target_positions[last_token_indices] + 1).int()
 
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
 
+        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
+            seq_lens, target_positions, self.vllm_config.model_config.dtype,
+            self.device)
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=self.runner.query_start_loc[:batch_size + 1],
-            query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size +
-                                                                1],
-            seq_lens_cpu=self.runner.seq_lens_cpu,
+            query_start_loc=cu_num_tokens.to(device),
+            query_start_loc_cpu=cu_num_tokens,
+            seq_lens_cpu=seq_lens.cpu(),
             max_query_len=max_query_len,
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
@@ -436,15 +442,19 @@ class EagleProposer(Proposer):
             get_device_tensor(),
             slot_mapping=target_slot_mapping,
             positions=target_positions,
-            attn_mask=self.runner.attn_mask,
+            attn_mask=attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
             num_computed_tokens_cpu=None,
             seq_lens=None)
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        attn_metadata = self.runner.attn_metadata_builder.build(
-            common_attn_metadata, self.runner.model)
+        if vllm_version_is("0.10.2"):
+            builder = self.runner.attn_groups[0][0].metadata_builder
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        attn_metadata = builder.build(0, common_attn_metadata,
+                                      self.runner.get_model())
         if self.use_cuda_graph and \
             num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -471,7 +481,10 @@ class EagleProposer(Proposer):
             hidden_states=self.hidden_states[:num_input_tokens],
         )
         sample_hidden_states = last_hidden_states[last_token_indices]
-        logits = self.model.compute_logits(sample_hidden_states, None)
+        if vllm_version_is("0.10.2"):
+            logits = self.model.compute_logits(sample_hidden_states, None)
+        else:
+            logits = self.model.compute_logits(sample_hidden_states)
         draft_token_ids = logits.argmax(dim=-1)
 
         # Early exit if there is only one draft token to be generated.
@@ -501,9 +514,8 @@ class EagleProposer(Proposer):
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-
-        if self.vllm_config.speculative_config.num_speculative_tokens > 2:
-            raise ValueError("Speculative tokens > 2 are not supported yet.")
+        query_lens.fill_(1)
+        attn_metadata.query_lens = query_lens
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
 
         for now_speculative in range(
@@ -558,9 +570,8 @@ class EagleProposer(Proposer):
                 self.input_ids[:batch_size] = input_ids
                 self.positions[:batch_size] = clamped_positions
                 self.hidden_states[:batch_size] = hidden_states
-                positions = positions_cpu.to(device)
                 attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-                    attn_metadata.seq_lens, positions,
+                    attn_metadata.seq_lens, positions_cpu,
                     self.vllm_config.model_config.dtype, self.device)
                 attn_metadata.attn_mask = attn_mask
 
@@ -577,8 +588,12 @@ class EagleProposer(Proposer):
                 hidden_states=self.hidden_states[:input_batch_size],
             )
             hidden_states = hidden_states[:batch_size]
-            logits = self.model.compute_logits(last_hidden_states[:batch_size],
-                                               None)
+            if vllm_version_is("0.10.2"):
+                logits = self.model.compute_logits(
+                    last_hidden_states[:batch_size], None)
+            else:
+                logits = self.model.compute_logits(
+                    last_hidden_states[:batch_size])
 
             # TODO(wenlong): get more than one token for tree attention
             draft_token_ids = logits.argmax(dim=-1)
@@ -652,7 +667,8 @@ class EagleProposer(Proposer):
             dtype=torch.int32,
             device=out_tensor.device) + offset_tensor
         values_to_store = torch.tensor(
-            index_start, dtype=torch.int32,
+            index_start + global_start_offset,
+            dtype=torch.int32,
             device=out_tensor.device) + offset_tensor
         mask = (target_indices >= start_pos) & \
                (target_indices < end_pos) & \