init
New files in this commit:

tests/spec_decode/e2e/__init__.py (new file, 0 lines)

tests/spec_decode/e2e/conftest.py (new file, 305 lines)
@@ -0,0 +1,305 @@
import asyncio
import time
from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union

import pytest
import ray
import torch
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                    nvmlInit)

from tests.conftest import cleanup
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid


class AsyncLLM:
    """AsyncLLM

    Note: the current LLM class in vllm does not support async mode. For test
    purposes we implement an async one here; it could be moved to
    vllm/entrypoints/llm.py in the future.

    The AsyncLLM below is borrowed directly from vllm/entrypoints/llm.py, with
    changes to make it work in async mode.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        self.engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            engine_use_ray=True,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.request_counter = Counter()

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:

        llm_engine = AsyncLLMEngine.from_engine_args(
            self.engine_args, usage_context=UsageContext.LLM_CLASS)

        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]

        if prompts is not None:
            num_requests = len(prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        elif isinstance(sampling_params,
                        list) and len(sampling_params) != num_requests:
            raise ValueError("The lengths of prompts and "
                             "sampling_params must be the same.")

        async def get_output(prompt, sampling_param) -> str:
            request_id = random_uuid()
            results_generator = llm_engine.generate(prompt, sampling_param,
                                                    request_id)
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            return final_output

        outputs = []
        try:
            for i in range(num_requests):
                prompt = prompts[i] if prompts is not None else None
                res = asyncio.run(get_output(prompt, sampling_params))
                outputs.append(res)
        finally:
            ray.shutdown()
        return outputs


@pytest.fixture
def baseline_llm_generator(request, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           seed):
    return create_llm_generator("baseline", request, common_llm_kwargs,
                                per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator("test", request, common_llm_kwargs,
                                per_test_common_llm_kwargs, test_llm_kwargs,
                                seed)


def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                         per_test_common_llm_kwargs, distinct_llm_kwargs,
                         seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }
    test_name = request.node.name

    def generator_inner():

        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        use_async = False
        if "use_async" in kwargs:
            use_async = kwargs.pop("use_async")
        print(f'{use_async=}')

        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    def generator_outer():
        for llm in generator_inner():
            yield llm
            del llm

    return generator_outer

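# Illustrative note on the kwarg layering above (a minimal sketch with
# hypothetical values, not part of the fixtures themselves): the three layers
# are merged with plain dict unpacking, so later layers override earlier ones.
#
#     common = {"model": "JackFram/llama-68m", "enforce_eager": True}
#     per_test = {"enforce_eager": False}
#     distinct = {"num_speculative_tokens": 5}
#     merged = {**common, **per_test, **distinct}
#     assert merged["enforce_eager"] is False  # per_test wins over common
#
# This is why baseline_llm_kwargs/test_llm_kwargs can selectively override
# settings shared via common_llm_kwargs and per_test_common_llm_kwargs.
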

def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]]]:
    tokens = []
    token_ids = []
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]
        del llm

    return tokens, token_ids


def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Returns a dict of (token_id: Logprob) for each generated position, for
    each sequence in the batch.
    """
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        del llm

    return logprobs
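# For reference, a hypothetical (illustrative) shape of the value returned
# above, assuming Logprob carries at least `logprob` and `rank` fields as used
# by the logprobs tests:
#
#     logprobs[seq_idx][position] == {
#         token_id_a: Logprob(logprob=-0.12, rank=1),
#         token_id_b: Logprob(logprob=-2.31, rank=2),
#     }
#
# i.e. one dict of candidate token ids per generated position, per sequence.
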


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore the eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
    )

    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    (baseline_batch_tokens,
     baseline_batch_token_ids) = get_output_from_llm_generator(
         baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)

    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
            spec_tokens) in enumerate(
                zip(baseline_batch_token_ids, baseline_batch_tokens,
                    spec_batch_token_ids, spec_batch_tokens)):
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=} {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids


def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from the torch
    # cuda context.
    nvmlInit()
    start_time = time.time()
    while True:
        output = {}
        output_raw = {}
        for device in devices:
            dev_handle = nvmlDeviceGetHandleByIndex(device)
            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
            gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
tests/spec_decode/e2e/test_compatibility.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import pytest

from vllm import SamplingParams

from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Expect failure as spec decode is not supported by the Ray
            # backend.
            "worker_use_ray": True,
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
    """Verify that speculative decoding with Ray fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    try:
        with pytest.raises(
                AssertionError,
                match="Speculative decoding not yet supported for "):
            get_output_from_llm_generator(test_llm_generator, prompts,
                                          sampling_params)
    finally:
        # We need to free up Ray resources so that later tests can use the
        # GPU we allocated here.
        import ray
        ray.shutdown()


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "enable_chunked_prefill": True,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
    """Verify that speculative decoding with chunked prefill fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError,
                       match="Speculative decoding and chunked prefill"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Speculative max model len > overridden max model len should raise.
            "max_model_len": 128,
            "speculative_max_model_len": 129,
        },
        {
            # Speculative max model len > draft max model len should raise.
            # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
            "speculative_max_model_len": 2048 + 1,
        },
        {
            # Speculative max model len > target max model len should raise.
            # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
            "speculative_max_model_len": 4096 + 1,
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
    """Verify that speculative decoding validates speculative_max_model_len.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError, match="cannot be larger than"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "JackFram/llama-68m",
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
    """Verify that speculative decoding with block manager v1 fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError,
                       match="Speculative decoding requires usage of the V2"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
tests/spec_decode/e2e/test_logprobs.py (new file, 335 lines)
@@ -0,0 +1,335 @@
import math
from itertools import cycle

import pytest

from vllm import SamplingParams

from .conftest import get_logprobs_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify output logprobs are equal with and without speculative decoding.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int,
                           num_logprobs: int):
    """Verify output logprobs are equal with and without spec decode.
    This specifies a number of logprobs >1.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True,
                                         logprob_rank=num_logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}, {
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 6,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify logprob greedy equality with different speculation lengths.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,

        # Artificially limit the draft model max model len; this forces vLLM
        # to skip speculation once the sequences grow beyond 32-k tokens.
        "speculative_max_model_len": 32,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_when_skip_speculation(baseline_llm_generator,
                                        test_llm_generator, batch_size: int,
                                        output_len: int):
    """Verify logprobs greedy equality when some sequences skip speculation.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Verify at least one logprob result has num_logprobs+1, which tests the
    case where the sampled token is not in the top-k logprobs.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as a future improvement.
    """
    batch_size = 8
    max_output_len = output_len
    force_output_len = True
    logprob_rank = 5

    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore the eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    num_returned_logprobs = [
        len(logprob_dict) for seq_logprobs in spec_batch_logprobs
        for logprob_dict in seq_logprobs
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any([
        num_returned > logprob_rank for num_returned in num_returned_logprobs
    ])


def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         logprob_rank: int = 1):
    """Helper method that compares the logprobs outputs of both the baseline
    LLM and the test LLM. It asserts greedy equality of the logprobs when the
    temperature is zero.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore the eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_logprobs = get_logprobs_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_logprobs) == len(prompts)
    assert len(spec_batch_logprobs) == len(prompts)

    # For each sequence in the batch.
    for i, (baseline_logprobs, spec_logprobs) in enumerate(
            zip(baseline_batch_logprobs, spec_batch_logprobs)):
        assert len(spec_logprobs) == len(baseline_logprobs)

        # For each generated position of the sequence.
        for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
                zip(spec_logprobs, baseline_logprobs)):

            # Map rank to token/logprob in spec output.
            spec_rank_to_token_id = {
                value.rank: key
                for key, value in spec_pos_logprobs.items()
            }
            spec_rank_to_logprob = {
                value.rank: value.logprob
                for key, value in spec_pos_logprobs.items()
            }

            # Map rank to token/logprob in baseline output.
            baseline_rank_to_token_id = {
                value.rank: key
                for key, value in baseline_pos_logprobs.items()
            }
            baseline_rank_to_logprob = {
                value.rank: value.logprob
                for key, value in baseline_pos_logprobs.items()
            }

            # Assert the sets of returned ranks are equal.
            assert set(spec_rank_to_token_id.keys()) == set(
                baseline_rank_to_token_id.keys())

            # Assert each logprob/token id is correct, keyed by rank.
            for rank in sorted(set(spec_rank_to_token_id.keys())):
                assert spec_rank_to_token_id[
                    rank] == baseline_rank_to_token_id[rank], f"{rank}"
                assert math.isclose(
                    a=spec_rank_to_logprob[rank],
                    b=baseline_rank_to_logprob[rank],
                    abs_tol=1e-1,
                )
tests/spec_decode/e2e/test_multistep_correctness.py (new file, 579 lines)
@@ -0,0 +1,579 @@
"""The tests in this file verify end-to-end speculative decoding correctness.

This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.

For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).

NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.

@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to
non-determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it
breaks greedy-equality tests for those batch sizes/prompts.
"""
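# For intuition, a minimal sketch of the rejection-sampling acceptance rule
# that underpins the distribution guarantee above (illustrative only; vLLM's
# actual rejection sampler differs in detail and is vectorized):
#
#     import random
#
#     def accept_draft_token(p_target_x: float, p_draft_x: float) -> bool:
#         # Accept the draft token x with probability min(1, p_target/p_draft);
#         # on rejection, resample from the normalized residual
#         # max(0, p_target - p_draft).
#         return random.random() < min(1.0, p_target_x / p_draft_x)
#
# Under greedy (temperature=0) decoding the target distribution collapses to
# its argmax, which is why spec decode output must exactly match the baseline.
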

from itertools import cycle

import pytest
from transformers import AutoTokenizer

from vllm import SamplingParams

from .conftest import (get_output_from_llm_generator,
                       run_greedy_equality_correctness_test)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        # Note this is repeated in the test body to initialize a tokenizer.
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
        },
        {
            # Verify the detokenizer assertions in the test work when spec
            # decode is disabled.
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
                                             batch_size: int):
    """Run generation with speculative decoding on a batch. Verify the engine
    generates the correct number of tokens (via ignore_eos=True), and that the
    detokenization matches HF transformers.
    """
    output_len = 32
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    batch_tokens, batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Expect a generation for each prompt in the batch.
    assert len(batch_token_ids) == len(prompts)

    # Expect each generation to have the expected number of tokens (note
    # ignore_eos is True).
    assert [len(token_ids)
            for token_ids in batch_token_ids] == ([output_len] * batch_size)

    # Expect detokenized string to match.
    tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
    for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
        expected_tokens = tok.decode(actual_token_ids)
        print(f"{actual_token_ids=}")
        assert actual_tokens.strip() == expected_tokens.strip()


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        # Note this is repeated in the test body to initialize a tokenizer.
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Use the AsyncLLM engine.
        "use_async": True,
    }])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
                                           baseline_llm_generator,
                                           batch_size: int):
    """Verify spec decode works well with the async LLM engine.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=32,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model": "JackFram/llama-68m",
        },
        {
            "model": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use long output len for the small model test.
        1536,
    ])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality on a tiny model with batch size of one.

    Since this test is cheaper than other e2e correctness tests, we generate
    with a higher output_len.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model": "JackFram/llama-68m",
        },
        {
            "model": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality on a tiny model and large batch size.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model": "JackFram/llama-68m",
        },
        {
            "model": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("max_output_len", [
    256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        max_output_len: int):
    """Verify greedy equality on a tiny model, with a large batch size, and
    when sampling respects the EOS token.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len=False)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A "real" model (not tiny).
        "model": "meta-llama/Llama-2-7b-chat-hf",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use decently long output len for a high quality test.
        256,
    ])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality on a "real" model and batch size of 1. This is
    separate from large BS tests to make identifying the source of bugs easier.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A "real" model (not tiny).
        "model": "meta-llama/Llama-2-7b-chat-hf",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        64,
    ])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality with a "real" model on a nontrivial batch size.
    This is the closest test to a real production workload.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-160m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_with_preemption(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality, even when some sequences are preempted
    mid-generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
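# Worked arithmetic behind the preemption test above (these numbers are only
# meant to force KV-cache pressure, not to be exact):
#
#     block_size = 8
#     output_len = 256
#     blocks_per_seq = 2 + output_len // block_size  # ~2 prompt blocks + 32 output blocks
#
# With num_gpu_blocks_override = 34, a single sequence barely fits, so a batch
# of 4 sequences cannot all be resident at once and the scheduler must
# preempt and later resume some of them mid-generation.
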


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # As of this writing, vLLM only compiles with these 3 block sizes by
        # default.
        {
            "block_size": 8,
        },
        {
            "block_size": 16,
        },
        {
            "block_size": 32,
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_different_block_size(baseline_llm_generator,
                                          test_llm_generator, batch_size: int,
                                          output_len: int):
    """Verify greedy equality over different block sizes.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,

            # Artificially limit the draft model max model len; this forces vLLM
            # to skip speculation once the sequences grow beyond 32-k tokens.
            "speculative_max_model_len": 32,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # This must be a good bit larger than speculative_max_model_len so that
        # we can test the case where all seqs are skipped, but still small to
        # ensure fast test.
        64,
    ])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
                          batch_size: int, output_len: int):
    """Verify greedy equality when some (or all) sequences skip speculation.
    We do this by setting the max model len of the draft model to an
    artificially low value, such that when the sequences grow beyond it, they
    are skipped in speculative decoding.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": k,
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                output_len: int):
    """Verify that speculative decoding produces exact equality to without spec
    decode with many different values of k.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
tests/spec_decode/e2e/test_ngram_correctness.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

The idea behind ngram lookup comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and it has been merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
Since no model is needed to generate the proposals, the test cases can be much
simpler than the multi-step draft-model ones.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various ngram sizes / speculative sizes

With these tests, we can say at least that ngram spec decode does not break the
correctness of the target model outputs.
"""
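# A minimal sketch of the prompt-lookup ("[ngram]") proposal idea referenced
# above (illustrative only; vLLM's ngram worker is more involved and operates
# on batched token-id tensors rather than Python lists):
#
#     def propose(tokens: list[int], n: int, k: int) -> list[int]:
#         """Find the most recent earlier occurrence of the last n tokens and
#         propose the (up to) k tokens that followed it."""
#         tail = tokens[-n:]
#         for start in range(len(tokens) - n - 1, -1, -1):
#             if tokens[start:start + n] == tail:
#                 return tokens[start + n:start + n + k]
#         return []
#
# The proposals are then verified by the target model exactly as with a draft
# model, so the correctness checked below does not depend on proposal quality.
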

import pytest

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-68m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize("output_len", [
    256,
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
    """Verify greedy equality on a tiny model with different batch sizes."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-160m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                      test_llm_generator,
                                                      batch_size: int,
                                                      output_len: int):
    """Verify greedy equality, even when some sequences are preempted
    mid-generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 3,
        }
        # Try a range of common k values.
        for k in [1, 3, 5]
    ] + [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 1,
        }
        # Try a range of common k values.
        for k in [1, 3, 5]
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify that ngram speculative decoding produces exact equality to
    without spec decode, with many different values of k and different
    ngram_prompt_lookup_max.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)