[Feat] Support async_scheduler and disable_padded_drafter_batch in eagle (#4893)
### What this PR does / why we need it?
We refactored eagle_proposer.py to follow the eagle.py framework from vLLM v0.12.0, adding support for padded-drafter-batch logits and async scheduling.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
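
For context, here is a minimal sketch of how the two options touched by this PR are enabled through vLLM's `LLM` entry point. The model names are placeholders, not part of this PR; the knobs mirror the test below.

```python
from vllm import LLM, SamplingParams

# Sketch only: both model names below are placeholders.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    # Overlap CPU-side scheduling with device execution.
    async_scheduling=True,
    speculative_config={
        "method": "eagle3",
        "model": "path/to/eagle3-drafter",
        "num_speculative_tokens": 2,
        # False exercises the padded drafter-batch path whose logits
        # this PR handles; True falls back to the unpadded path.
        "disable_padded_drafter_batch": False,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
```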
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: drslark <slarksblood@qq.com>
@@ -7,9 +7,10 @@ import random
 
 from typing import Any
 
 import pytest
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 
-from tests.e2e.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory
 
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
@@ -115,41 +116,67 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
     pytest.skip("To be aligned with GPU")
-    ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
-    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-    del ref_llm
 
-    # NOTE: e2e of eagle has many problems before.
-    # We first check whether it is functioning properly.
-    # Should fix the e2e with VllmRunner in future.
     spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
-    with VllmRunner(
-            model_name,
-            max_num_seqs=1,
-            max_num_batched_tokens=2048,
-            gpu_memory_utilization=0.6,
-            speculative_config={
-                "method": "eagle3" if use_eagle3 else "eagle",
-                "model": spec_model_name,
-                "num_speculative_tokens": 2,
-                "max_model_len": 128,
-            },
-            max_model_len=128,
-            enforce_eager=False,
-    ) as runner:
-        spec_outputs = runner.model.chat(test_prompts, sampling_config)
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    prompts = [{
+        "role": "user",
+        "content": "Hello, my name is"
+    }, {
+        "role": "user",
+        "content": "The president of the United States is"
+    }, {
+        "role": "user",
+        "content": "The capital of France is"
+    }, {
+        "role": "user",
+        "content": "The future of AI is"
+    }]
+    prompts = [
+        tokenizer.apply_chat_template(
+            [prompt],
+            tokenize=False,
+            add_generation_prompt=True,
+        ) for prompt in prompts
+    ]
 
-    matches = 0
-    misses = 0
-    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-        if ref_output.outputs[0].text == spec_output.outputs[0].text:
-            matches += 1
-        else:
-            misses += 1
-            print(f"ref_output: {ref_output.outputs[0].text}")
-            print(f"spec_output: {spec_output.outputs[0].text}")
+    sampling_params = SamplingParams(
+        max_tokens=300,
+        temperature=0.0,
+        ignore_eos=False,
+    )
 
-    # Heuristic: expect at least 66% of the prompts to match exactly
-    # Upon failure, inspect the outputs to check for inaccuracy.
-    assert matches > int(0.66 * len(ref_outputs))
+    # Create an LLM.
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=1,
+        data_parallel_size=1,
+        disable_log_stats=False,
+        max_model_len=4096,
+        seed=1024,
+        async_scheduling=True,
+        compilation_config={
+            "level": 3,
+            "cudagraph_mode": "FULL_DECODE_ONLY",
+            "cudagraph_num_of_warmups": 1,
+            "cudagraph_capture_sizes": [12],
+        },
+        speculative_config={
+            "disable_padded_drafter_batch": False,
+            "method": "eagle3" if use_eagle3 else "eagle",
+            "model": spec_model_name,
+            "num_speculative_tokens": 2,
+            "max_model_len": 128,
+            "draft_vocab_size": 128256,
+        },
+    )
+    llm.generate(prompts, sampling_params)
+    cleanup_dist_env_and_memory()
+    del llm
 
 
 @pytest.mark.skip(