[BugFix] Fix eagle3 accuracy problem when enforce_eager=True (#4521)
### What this PR does / why we need it?
Fixes an EAGLE3 accuracy problem that appears when `enforce_eager=True`. In eager mode the eagle proposer built its own splitfuse attention mask instead of reusing the mask the model runner had already prepared, and several decode-path `attn_metadata` fields (`query_start_loc_list`, `seq_lens_list`, `actual_seq_lengths_q`, and the prefill/decode token counts) were never populated, so speculative decoding returned inaccurate outputs.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
With the standalone script below, running EAGLE3 speculative decoding in eager mode:
```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM with EAGLE3 speculative decoding in eager mode.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        tensor_parallel_size=1,
        speculative_config={
            "method": "eagle3",
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            "num_speculative_tokens": 3,
        },
        enforce_eager=True,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    print(f"Outputs: {outputs}")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
```diff
@@ -110,7 +110,7 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
-    pytest.skip("exist OOM error")
+    pytest.skip("To be aligned with GPU")
     ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
```
```diff
@@ -79,7 +79,7 @@ class EagleProposer(Proposer):
                                  dtype=torch.int32)
         attn_mask_len = self.vllm_config.model_config.max_model_len
         self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype)
+            attn_mask_len, self.vllm_config.model_config.dtype, device=device)

     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
```
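The builder now receives the target device explicitly, so the proposer's mask is allocated on the same NPU as the runner's tensors instead of on the default device. A sketch of the general pattern, not vllm-ascend's actual `AttentionMaskBuilder` implementation:

```python
import torch

def build_causal_mask(max_len: int, dtype: torch.dtype,
                      device: torch.device) -> torch.Tensor:
    # Allocate the mask directly on the target device so later use
    # never triggers a host-to-device copy.
    mask = torch.full((max_len, max_len), float("-inf"), dtype=dtype,
                      device=device)
    # Keep -inf strictly above the diagonal; zeros elsewhere (additive mask).
    return torch.triu(mask, diagonal=1)

mask = build_causal_mask(8, torch.float16, torch.device("cpu"))
assert mask[0, 1] == float("-inf") and mask[1, 0] == 0
```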
```diff
@@ -430,9 +430,7 @@ class EagleProposer(Proposer):

         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
-        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, target_positions, self.vllm_config.model_config.dtype,
-            self.device)
+        attn_mask = self.runner.attn_mask

         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
```
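This hunk is the core of the accuracy fix: instead of rebuilding a splitfuse mask inside the proposer, which under `enforce_eager=True` could diverge from the mask the main model ran with, the proposer now reads the mask the runner already computed. The sharing pattern, reduced to its essence (`Runner` and `Proposer` here are illustrative stand-ins, not the real classes):

```python
import torch

class Runner:
    def __init__(self, max_len: int, device: torch.device):
        # Built once; every consumer sees this exact tensor.
        self.attn_mask = torch.triu(
            torch.full((max_len, max_len), float("-inf"), device=device),
            diagonal=1,
        )

class Proposer:
    def __init__(self, runner: Runner):
        self.runner = runner

    def propose_step(self) -> torch.Tensor:
        # Reuse the runner's mask; do not rebuild it per step.
        return self.runner.attn_mask

runner = Runner(max_len=2048, device=torch.device("cpu"))
proposer = Proposer(runner)
assert proposer.propose_step() is runner.attn_mask  # same tensor, no copy
```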
```diff
@@ -507,9 +505,15 @@ class EagleProposer(Proposer):
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
+            1:].tolist()
+        attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
+        attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
         query_lens.fill_(1)
         attn_metadata.query_lens = query_lens

+        attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
+        attn_metadata.seq_lens_list = seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(
                 self.vllm_config.speculative_config.num_speculative_tokens -
```
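Each speculative decode step processes exactly one new token per request, which is why the new bookkeeping collapses to simple ranges: the hunk classifies all `batch_size` single-token queries as prefills (`num_decodes = 0`), matching the `ChunkedPrefill` attention state it sets. An illustration of the values these fields take for a batch of four (plain tensors, outside the proposer):

```python
import torch

batch_size = 4
# One token per request per speculative step.
query_start_loc = torch.arange(batch_size + 1, dtype=torch.int32)

print(query_start_loc.tolist())                  # [0, 1, 2, 3, 4]
print(query_start_loc[1:].tolist())              # [1, 2, 3, 4] -> query_start_loc_list
print([1 + i for i in range(batch_size)])        # [1, 2, 3, 4] -> actual_seq_lengths_q
print(0 + batch_size)                            # 4 -> num_actual_tokens_pcp_padded
```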
```diff
@@ -536,6 +540,9 @@ class EagleProposer(Proposer):
             # TODO: Increment the sequence lengths.

             attn_metadata.seq_lens += 1
+            attn_metadata.seq_lens_list = [
+                _ + 1 for _ in attn_metadata.seq_lens_list
+            ]
             # TODO: Consider max model length.
             # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
             #                                 self.max_model_len)
```
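The tensor and list views of the sequence lengths have to advance together after each accepted speculative token; the added lines keep `seq_lens_list` in lockstep with `seq_lens`. A trivial standalone check of that invariant:

```python
import torch

seq_lens = torch.tensor([7, 9], dtype=torch.int32)
seq_lens_list = seq_lens.tolist()

# Advance both views by one token, as the hunk above does per step.
seq_lens += 1
seq_lens_list = [s + 1 for s in seq_lens_list]

assert seq_lens.tolist() == seq_lens_list
```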