[BugFix] Fix eagle3 accuracy problem when enforce_eager=True (#4521)

### What this PR does / why we need it?

With `enforce_eager=True`, eagle3 speculative decoding returned degraded results. The proposer built its own split-fuse attention mask instead of reusing the one the model runner had already prepared, and several decode-step fields of the Ascend attention metadata were never populated. This PR passes the device into `AttentionMaskBuilder`, reads the attention mask from the runner, and fills in the missing metadata (`query_start_loc_list`, the prefill/decode counts, `actual_seq_lengths_q`, `seq_lens_list`), keeping `seq_lens_list` in step as speculative tokens are generated.

### Does this PR introduce _any_ user-facing change?

No. This is an accuracy bugfix with no API or configuration changes.

### How was this patch tested?

Verified manually with the script below:

from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM.
    llm = LLM(
            model="meta-llama/Llama-3.1-8B-Instruct",
            tensor_parallel_size=1,
            speculative_config={
                "method": "eagle3",
                "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
                "num_speculative_tokens": 3
            },
            enforce_eager=True,
        )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    print(f"Outputs: {outputs}")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Commit: cb42564942 (parent 3480094d7c)
Author: zhaomingyu13 <zhaomingyu13@h-partners.com>
Date: 2025-12-06 17:31:26 +08:00
Committed by: GitHub
2 changed files with 12 additions and 5 deletions


@@ -110,7 +110,7 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
-    pytest.skip("exist OOM error")
+    pytest.skip("To be aligned with GPU")
     ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm


@@ -79,7 +79,7 @@ class EagleProposer(Proposer):
                                   dtype=torch.int32)
         attn_mask_len = self.vllm_config.model_config.max_model_len
         self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype)
+            attn_mask_len, self.vllm_config.model_config.dtype, device=device)

     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
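
The first proposer change constructs the `AttentionMaskBuilder` on the proposer's device, so the causal mask lives on the NPU from the start rather than being built host-side. A minimal sketch of the idea with assumed internals; the real builder in vllm-ascend differs in detail:

import torch

class MaskBuilderSketch:
    # Hypothetical stand-in for AttentionMaskBuilder: an additive causal
    # mask built once, directly on the target device.
    def __init__(self, max_len: int, dtype: torch.dtype, device=None):
        allow = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool,
                                      device=device))
        self.mask = torch.zeros(max_len, max_len, dtype=dtype, device=device)
        self.mask.masked_fill_(~allow, float(torch.finfo(dtype).min))

    def slice(self, q_len: int, kv_len: int) -> torch.Tensor:
        # A view into the prebuilt mask: no per-step allocation or transfer.
        return self.mask[:q_len, :kv_len]

builder = MaskBuilderSketch(2048, torch.bfloat16)  # device=None falls back to CPU
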
@@ -430,9 +430,7 @@ class EagleProposer(Proposer):
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
-        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, target_positions, self.vllm_config.model_config.dtype,
-            self.device)
+        attn_mask = self.runner.attn_mask

         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
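
This hunk is the core of the fix: instead of rebuilding a split-fuse mask inside the proposer, the draft model reuses the mask the runner already prepared for the target model's forward pass, so both models attend identically under `enforce_eager=True`. The sharing pattern, sketched with assumed runner internals:

class RunnerSketch:
    # Stand-in for the model runner; it computes attn_mask once per step.
    def __init__(self):
        self.attn_mask = None

class ProposerSketch:
    def __init__(self, runner: RunnerSketch):
        self.runner = runner

    def build_metadata(self) -> dict:
        # Reuse, don't recompute: the draft step sees exactly the mask the
        # target step used, removing the eager-mode mismatch.
        return {"attn_mask": self.runner.attn_mask}
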
@@ -507,9 +505,15 @@
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
+            1:].tolist()
+        attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
+        attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
         query_lens.fill_(1)
         attn_metadata.query_lens = query_lens
+        attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
+        attn_metadata.seq_lens_list = seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(
                 self.vllm_config.speculative_config.num_speculative_tokens -
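
The third hunk populates metadata the Ascend attention backend expects during the per-token speculative loop. Every request contributes exactly one query token at this point, so each field follows from `batch_size` and the current sequence lengths. The arithmetic, run standalone with assumed values:

import torch

batch_size = 4
seq_lens = torch.tensor([7, 12, 3, 9], dtype=torch.int32)  # per-request KV lengths

query_start_loc = torch.arange(batch_size + 1, dtype=torch.int32)  # [0, 1, 2, 3, 4]
query_start_loc_list = query_start_loc[1:].tolist()                # [1, 2, 3, 4]

# One query token per request; the diff books all of them as prefill.
num_decodes, num_prefills = 0, batch_size
num_decode_tokens, num_prefill_tokens = 0, batch_size
num_actual_tokens_pcp_padded = num_decode_tokens + num_prefill_tokens  # 4

actual_seq_lengths_q = [1 + i for i in range(batch_size)]  # [1, 2, 3, 4]
seq_lens_list = seq_lens.tolist()  # host-side mirror: [7, 12, 3, 9]
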
@@ -536,6 +540,9 @@
             # TODO: Increment the sequence lengths.
             attn_metadata.seq_lens += 1
+            attn_metadata.seq_lens_list = [
+                _ + 1 for _ in attn_metadata.seq_lens_list
+            ]
             # TODO: Consider max model length.
             # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
             #     self.max_model_len)
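
The final hunk keeps the host-side `seq_lens_list` in lockstep with the `seq_lens` tensor: each speculative step extends every sequence by one token, and a stale list would feed the backend wrong lengths on later iterations. The pattern in isolation:

import torch

seq_lens = torch.tensor([7, 12, 3, 9], dtype=torch.int32)
seq_lens_list = seq_lens.tolist()

num_speculative_tokens = 3
for _ in range(num_speculative_tokens - 1):         # remaining draft steps
    seq_lens += 1                                   # device-side tensor
    seq_lens_list = [n + 1 for n in seq_lens_list]  # host-side mirror

assert seq_lens_list == seq_lens.tolist()  # both advanced together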