forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
33
vllm-v0.6.2/tests/samplers/test_ignore_eos.py
Normal file
33
vllm-v0.6.2/tests/samplers/test_ignore_eos.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Make sure ignore_eos works.
|
||||
|
||||
Run `pytest tests/samplers/test_ignore_eos.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
# We also test with llama because it has generation_config to specify EOS
|
||||
# (past regression).
|
||||
MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [512])
|
||||
def test_ignore_eos(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
ignore_eos=True)
|
||||
|
||||
for prompt in example_prompts:
|
||||
ignore_eos_output = vllm_model.model.generate(
|
||||
prompt, sampling_params=sampling_params)
|
||||
output_length = len(ignore_eos_output[0].outputs[0].token_ids)
|
||||
assert output_length == max_tokens
|
||||
Reference in New Issue
Block a user