init
This commit is contained in:
50
tests/samplers/test_ranks.py
Normal file
50
tests/samplers/test_ranks.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
# Model identifiers that test_ranks is parametrized over.
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    """Check the ``rank`` attached to the logprob of every chosen token.

    The example prompts are generated twice with the same model:

    * with ``temperature=0.0`` (greedy), every chosen token must be present
      in its logprobs entry and have ``rank == 1``;
    * with ``temperature=1.0`` (non-greedy), every chosen token must be
      present in its logprobs entry and have ``rank >= 1``.

    Args:
        vllm_runner: fixture called as ``vllm_runner(model, dtype=...,
            max_logprobs=...)`` to build the model under test.
        model: model identifier, parametrized over ``MODELS``.
        dtype: dtype string passed through to ``vllm_runner``.
        example_prompts: fixture providing the prompts to generate from.
    """
    max_tokens = 5
    num_top_logprobs = 5
    num_prompt_logprobs = 5

    vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)

    # NOTE(review): in each result, index 0 appears to hold the chosen token
    # ids and index 2 the matching per-token logprob mappings -- confirm
    # against generate_w_logprobs.

    ## Test greedy logprobs ranks
    vllm_sampling_params = SamplingParams(temperature=0.0,
                                          top_p=1.0,
                                          max_tokens=max_tokens,
                                          logprobs=num_top_logprobs,
                                          prompt_logprobs=num_prompt_logprobs)
    vllm_results = vllm_model.generate_w_logprobs(example_prompts,
                                                  vllm_sampling_params)
    for result in vllm_results:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks = 1
        for token, logprobs in zip(result[0], result[2]):
            assert token in logprobs
            assert logprobs[token].rank == 1

    ## Test non-greedy logprobs ranks
    sampling_params = SamplingParams(temperature=1.0,
                                     top_p=1.0,
                                     max_tokens=max_tokens,
                                     logprobs=num_top_logprobs,
                                     prompt_logprobs=num_prompt_logprobs)
    res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
    for result in res:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks
        for token, logprobs in zip(result[0], result[2]):
            # Mirror the greedy check: a missing token should fail as a clear
            # assertion rather than as a KeyError on the lookup below.
            assert token in logprobs
            assert logprobs[token].rank >= 1
|
||||
Reference in New Issue
Block a user