From cb42564942b2b2eed11d26c8e86b4c869d5f4c4b Mon Sep 17 00:00:00 2001
From: zhaomingyu13
Date: Sat, 6 Dec 2025 17:31:26 +0800
Subject: [PATCH] [BugFix] Fix eagle3 accuracy problem when enforce_eager=True (#4521)

### What this PR does / why we need it?
Fixes an accuracy problem with eagle3 speculative decoding when `enforce_eager=True`. The eagle proposer now builds its `AttentionMaskBuilder` on the target device, reuses the runner's attention mask instead of regenerating a split-fuse mask, and populates the additional attention metadata fields (`query_start_loc_list`, `seq_lens_list`, `actual_seq_lengths_q`, and the prefill/decode token counts) so that each speculative decode step sees consistent length information.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Tested manually with the following script:

from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM with eagle3 speculative decoding in eager mode.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        tensor_parallel_size=1,
        speculative_config={
            "method": "eagle3",
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            "num_speculative_tokens": 3
        },
        enforce_eager=True,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    print(f"Outputs: {outputs}")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: zhaomingyu
Co-authored-by: wangxiyuan
---
 .../spec_decode_v1/test_v1_spec_decode.py |  2 +-
 vllm_ascend/spec_decode/eagle_proposer.py | 15 +++++++++++----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
index 0902fe6d..5d74b5d4 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
@@ -110,7 +110,7 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
-    pytest.skip("exist OOM error")
+    pytest.skip("To be aligned with GPU")
     ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm

diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 6b47e6bf..af7d3689 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -79,7 +79,7 @@ class EagleProposer(Proposer):
                                    dtype=torch.int32)
         attn_mask_len = self.vllm_config.model_config.max_model_len
         self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype)
+            attn_mask_len, self.vllm_config.model_config.dtype, device=device)
 
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
@@ -430,9 +430,7 @@ class EagleProposer(Proposer):
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
 
-        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, target_positions, self.vllm_config.model_config.dtype,
-            self.device)
+        attn_mask = self.runner.attn_mask
 
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
@@ -507,9 +505,15 @@ class EagleProposer(Proposer):
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
+            1:].tolist()
+        attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
+        attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
         query_lens.fill_(1)
         attn_metadata.query_lens = query_lens
+        attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
+        attn_metadata.seq_lens_list = seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
 
         for now_speculative in range(
                 self.vllm_config.speculative_config.num_speculative_tokens -
@@ -536,6 +540,9 @@ class EagleProposer(Proposer):
 
         # TODO: Increment the sequence lengths.
         attn_metadata.seq_lens += 1
+        attn_metadata.seq_lens_list = [
+            _ + 1 for _ in attn_metadata.seq_lens_list
+        ]
         # TODO: Consider max model length.
         # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
         #                                 self.max_model_len)