Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -1,124 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.conftest import VllmRunner
from vllm import SamplingParams
from vllm.logprobs import FlatLogprobs
MODELS = ["facebook/opt-125m"]
MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
NUM_TOP_LOGPROBS = 5
NUM_PROMPT_LOGPROBS = 7
MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
# NOTE(review): from here on this text looks like a rendered diff hunk whose
# +/- markers and indentation were stripped. Lines of what appears to be a
# removed test (`test_get_prompt_logprobs`: uses `hf_runner`,
# `chunked_prefill_token_size`, `num_top_logprobs`) are interleaved with lines
# of what appears to be an added test (`test_ranks`: uses `greedy`,
# `flat_logprobs`). The code below is left untouched; the notes only describe
# what each run of lines does. Confirm against the real file before relying on
# the old/new attribution.
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
# NOTE(review): the trailing comment on the next line mentions 32000, but the
# parametrized value is 6 — the comment looks stale.
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
# NOTE(review): first of two `def` headers in this span; the decorators and
# header of `test_ranks` cut into its parameter list.
def test_get_prompt_logprobs(
hf_runner,
@pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flat_logprobs", [True, False])
def test_ranks(
vllm_runner,
model,
dtype,
chunked_prefill_token_size: int,
num_top_logprobs: int,
greedy,
flat_logprobs,
example_prompts,
):
# Chunked-prefill setup: a size of -1 leaves chunked prefill disabled with
# max_num_seqs = 256 and no batched-token limit; any other size enables it,
# lowers max_num_seqs to min(size, 256) and uses the size as
# max_num_batched_tokens.
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size
# `test_ranks` side: the runner is opened with max_logprobs=MAX_LOGPROBS (the
# larger of NUM_TOP_LOGPROBS and NUM_PROMPT_LOGPROBS), the prompts are
# tokenized with the model's own tokenizer, and generation requests both
# sample and prompt logprobs.
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
tokenizer = vllm_model.llm.get_tokenizer()
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
# temperature is 0.0 when `greedy` is true, otherwise 1.0.
sampling_params = SamplingParams(
temperature=0.0 if greedy else 1.0,
top_p=1.0,
max_tokens=MAX_TOKENS,
logprobs=NUM_TOP_LOGPROBS,
prompt_logprobs=NUM_PROMPT_LOGPROBS,
flat_logprobs=flat_logprobs,
)
results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
# `test_get_prompt_logprobs` side: build the reference logprobs with the
# `hf_runner` model, then drop the reference to that model.
max_tokens = 5
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
del hf_model
# `test_ranks` side: one result per prompt; each result is unpacked into four
# fields, of which the second is ignored.
assert len(results) == len(example_prompt_tokens)
for i, (result, prompt_tokens) in enumerate(zip(results, example_prompt_tokens)):
decode_tokens, _, decode_logprobs, prompt_logprobs = result
# `test_get_prompt_logprobs` side: run vLLM with the chunked-prefill settings
# computed above, requesting `num_top_logprobs` for both sampled and prompt
# tokens at temperature 0.0.
vllm_model = vllm_runner(
model,
dtype=dtype,
max_logprobs=num_top_logprobs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_top_logprobs,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
# Ensure the return type of logprobs is accurate:
# FlatLogprobs when `flat_logprobs` is set, a plain list otherwise.
assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)
# Test whether logprobs are included in the results.
for result in vllm_results:
assert result.prompt_logprobs is not None
assert result.outputs[0].logprobs is not None
assert len(result.outputs[0].logprobs) == max_tokens
for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs
# Rebuild the output text from the first entry of each position's logprob
# dict and require it to equal the returned text.
output_text = result.outputs[0].text
output_string_from_most_likely_tokens = []
for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append(
top_logprob.decoded_token)
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")
# The first prompt logprob is always None
assert result.prompt_logprobs[0] is None
for prompt_logprobs in result.prompt_logprobs[1:]:
# If the prompt token is not among the top X logprobs,
# one extra entry can be returned
assert (len(prompt_logprobs) == num_top_logprobs
or len(prompt_logprobs) == num_top_logprobs + 1)
# Test whether prompt logprobs are consistent with HF
# (values are compared with atol=1e-2 and rtol=1e-2).
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
########################
# Check prompt logprobs
# The first prompt logprob is always None, so we compare it from 1:.
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
torch.testing.assert_close(logprob.logprob,
hf_logprob[0][i][token_id].item(),
atol=1e-2,
rtol=1e-2)
# Same comparison for the sampled positions; each entry must also carry an
# already-decoded token string.
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
for i, top_logprobs in enumerate(vllm_sample_logprobs):
for token_id, sample_logprob in top_logprobs.items():
logprob = sample_logprob.logprob
torch.testing.assert_close(logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is returned "
" to the user.")
########################
# `test_ranks` side: one prompt-logprob entry per prompt token.
assert len(prompt_tokens) == len(prompt_logprobs)
# No logprob for first prompt token
assert not prompt_logprobs[0]
for position, (token, logprobs) in enumerate(
zip(prompt_tokens[1:], prompt_logprobs[1:]), start=1
):
# The logprob of the actual prompt token must always be returned
logprob = logprobs.get(token)
assert logprob is not None
assert logprob.rank >= 1
# The number of returned logprobs must be
# either NUM_PROMPT_LOGPROBS or NUM_PROMPT_LOGPROBS+1
assert NUM_PROMPT_LOGPROBS <= len(logprobs) <= NUM_PROMPT_LOGPROBS + 1
# Ranks 1..NUM_PROMPT_LOGPROBS must all be present
assert set(range(1, NUM_PROMPT_LOGPROBS + 1)).issubset(
{logprob.rank for logprob in logprobs.values()}
)
# Test if prompt logprobs are correctly set:
# every prompt token after the first must appear in its own logprob dict.
for vllm_result in vllm_results:
token_ids = vllm_result.prompt_token_ids
prompt_logprobs = vllm_result.prompt_logprobs
# The first token doesn't have logprob.
assert prompt_logprobs[0] is None
for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
assert token_id in logprob_dict
def test_max_logprobs():
    """Check that the `max_logprobs` limit of the runner is enforced.

    A `VllmRunner` is created with ``max_logprobs=1``. A request for exactly
    one logprob must run without raising, while a request for two logprobs
    must raise ``ValueError``.
    """
    prompt = "Hello world"
    engine = VllmRunner("facebook/opt-125m", max_logprobs=1)

    # A request at the configured limit should pass.
    within_limit = SamplingParams(logprobs=1)
    engine.generate([prompt], sampling_params=within_limit)

    # A request one above the limit must be rejected.
    over_limit = SamplingParams(logprobs=2)
    with pytest.raises(ValueError):
        engine.generate([prompt], sampling_params=over_limit)
########################
# Check sample logprobs
########################
# NOTE(review): this run of lines appears to be the tail of `test_ranks`; it
# reads `decode_tokens`, `decode_logprobs` and `greedy`, which are bound
# earlier in that test. Indentation was lost, so the nesting shown here is
# not reliable — confirm against the real file.
# One logprob entry is expected per generated token.
assert len(decode_tokens) == len(decode_logprobs)
for position, (token, logprobs) in enumerate(
zip(decode_tokens, decode_logprobs)
):
# The logprob of the chosen token must always be returned
logprob = logprobs.get(token)
assert logprob is not None
if greedy:
# For greedy sampling, every chosen token should be top ranked
assert logprob.rank == 1
else:
assert logprob.rank >= 1
# The number of returned logprobs must be
# either NUM_TOP_LOGPROBS or NUM_TOP_LOGPROBS+1
assert NUM_TOP_LOGPROBS <= len(logprobs) <= NUM_TOP_LOGPROBS + 1
# Ranks 1..NUM_TOP_LOGPROBS must all be present
assert set(range(1, NUM_TOP_LOGPROBS + 1)).issubset(
{logprob.rank for logprob in logprobs.values()}
)