diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index f2bb8a4..0d29cb1 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -259,8 +259,7 @@ jobs:
           # TODO: switch hf to modelscope
           VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \
           pytest -sv tests/e2e/singlecard/test_ilama_lora.py
-          # TODO(sss): guided decoding doesn't work, fix it later
-          # pytest -sv tests/e2e/singlecard/test_guided_decoding.py
+          pytest -sv tests/e2e/singlecard/test_guided_decoding.py
           pytest -sv tests/e2e/singlecard/test_camem.py
           pytest -sv tests/e2e/singlecard/ \
           --ignore=tests/e2e/singlecard/test_offline_inference.py \
@@ -278,8 +277,7 @@ jobs:
           # TODO: switch hf to modelscope
           VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \
           pytest -sv tests/e2e/singlecard/test_ilama_lora.py
-          # guided decoding doesn't work, fix it later
-          # pytest -sv tests/e2e/singlecard/test_guided_decoding.py
+          pytest -sv tests/e2e/singlecard/test_guided_decoding.py
           pytest -sv tests/e2e/singlecard/test_camem.py
           pytest -sv tests/e2e/singlecard/test_prompt_embedding.py
           pytest -sv tests/e2e/singlecard/ \
diff --git a/tests/e2e/singlecard/test_guided_decoding.py b/tests/e2e/singlecard/test_guided_decoding.py
index 0725812..9d103a5 100644
--- a/tests/e2e/singlecard/test_guided_decoding.py
+++ b/tests/e2e/singlecard/test_guided_decoding.py
@@ -28,13 +28,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from tests.conftest import VllmRunner
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
-MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GuidedDecodingBackendV0 = [
-    "outlines",
-    "lm-format-enforcer",
-    "xgrammar",
-]
-GuidedDecodingBackendV1 = ["xgrammar", "guidance:disable-any-whitespace"]
+MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
+
+GuidedDecodingBackendV0 = ["outlines", "lm-format-enforcer", "xgrammar"]
+GuidedDecodingBackendV1 = ["xgrammar", "guidance"]
 GuidedDecodingBackend = list(
     set(GuidedDecodingBackendV0 + GuidedDecodingBackendV1))
 
@@ -87,26 +84,25 @@ def sample_json_schema():
     }
 
 
+def check_backend(guided_decoding_backend: str):
+    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
+            "VLLM_USE_V1") == "0":
+        pytest.skip(f"{guided_decoding_backend} does not support v0, skip it.")
+    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
+            "VLLM_USE_V1") == "1":
+        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it.")
+
+
 @pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
 def test_guided_json_completion(guided_decoding_backend: str,
                                 sample_json_schema):
-    if guided_decoding_backend == "xgrammar":
-        # xgrammar does not support json schema, will fall back to outlines, skip it
-        pytest.skip(
-            f"{guided_decoding_backend} will fall back to outlines, skip it")
-    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
-            "VLLM_USE_V1") == "0":
-        # guidance does not support on v0, skip it
-        pytest.skip(
-            f"{guided_decoding_backend} does not support on v0, skip it")
-    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
-            "VLLM_USE_V1") == "1":
-        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
+    check_backend(guided_decoding_backend)
     sampling_params = SamplingParams(
         temperature=1.0,
-        max_tokens=1000,
+        max_tokens=500,
         guided_decoding=GuidedDecodingParams(json=sample_json_schema))
+
     with VllmRunner(
             MODEL_NAME,
             seed=0,
@@ -138,19 +134,13 @@ def test_guided_json_completion(guided_decoding_backend: str,
 
 
 @pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
 def test_guided_regex(guided_decoding_backend: str, sample_regex):
-    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
-            "VLLM_USE_V1") == "0":
-        # guidance does not support on v0, skip it
-        pytest.skip(
-            f"{guided_decoding_backend} does not support on v0, skip it")
-    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
-            "VLLM_USE_V1") == "1":
-        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
+    check_backend(guided_decoding_backend)
+
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex, ))
     with VllmRunner(
             MODEL_NAME,
             seed=0,
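
For context on what the re-enabled tests exercise, here is a minimal standalone sketch of the guided JSON decoding path, using vLLM's plain offline `LLM` entry point instead of the test harness's `VllmRunner`. The model name, `seed=0`, `temperature=1.0`, and `max_tokens=500` mirror the patched test; the schema and prompt are illustrative stand-ins, not the `sample_json_schema` fixture.

```python
# Hedged sketch of the guided decoding path, assuming vLLM's offline API;
# the schema and prompt below are stand-ins, not values from the test suite.
import json

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Hypothetical schema standing in for the sample_json_schema fixture.
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", seed=0)
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=500,
    # Constrains decoding so the sampled tokens stay parseable
    # against the JSON schema.
    guided_decoding=GuidedDecodingParams(json=schema))

outputs = llm.generate(
    [f"Give an example JSON object matching this schema: {json.dumps(schema)}"],
    sampling_params)

# With a working guided decoding backend, this parses without error.
print(json.loads(outputs[0].outputs[0].text))
```

If a specific backend needs to be pinned (as the parametrized tests do across the V0/V1 backend lists), vLLM exposes a `guided_decoding_backend` engine argument, e.g. `LLM(..., guided_decoding_backend="xgrammar")`.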