diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 76ac04c7..0899fa33 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -16,6 +16,7 @@ # import pytest +import os from tests.e2e.singlecard.utils import (PROMPTS_LONG, PROMPTS_SHORT, LLMTestCase, gen_and_valid) @@ -133,3 +134,34 @@ def test_npugraph_ex_res_consistency(cur_case: LLMTestCase, monkeypatch): prompts=cur_case.prompts, sampling_params=cur_case.sampling_params, golden_answers=cur_case.golden_answers) + +# The accuracy has already been verified in the previous test case. +# This test case is used to check whether the functionality works properly +# after enabling the static kernel and whether it is uninstalled as expected. +@pytest.mark.parametrize("cur_case", [CASE_QWEN_EX]) +def test_npugraph_ex_with_static_kernel(cur_case: LLMTestCase, monkeypatch): + monkeypatch.delenv("HCCL_OP_EXPANSION_MODE", raising=False) + runner_kwargs = { + "model_name": cur_case.model, + "quantization": cur_case.quantization, + "max_model_len": 1024, + "compilation_config": { + "cudagraph_capture_sizes": [4, 8], + "cudagraph_mode": "FULL_DECODE_ONLY" + }, + "additional_config": { + "npugraph_ex_config": { + "enable": True, + "enable_static_kernel": True, + } + }, + } + gen_and_valid(runner_kwargs=runner_kwargs, + prompts=cur_case.prompts, + sampling_params=cur_case.sampling_params, + golden_answers=cur_case.golden_answers) + + # Check whether the static kernel is properly uninstall + ascend_home_path = os.environ["ASCEND_HOME_PATH"] + static_kernel_install_path = os.path.join(ascend_home_path, 'opp/static_kernel/ai_core') + assert not os.path.exists(static_kernel_install_path) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 7a029121..cefac33d 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -90,10 +90,10 @@ def npugraph_ex_compile( # affecting program execution. num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0 uniform_decode_query_len = num_spec_tokens + 1 - max_num_tokens = vllm_config.scheduler_config.max_num_seq * uniform_decode_query_len + max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len decode_cudagraph_batch_sizes = [ x - for x in vllm_config.compilation_config.cudagraph_capture_size + for x in vllm_config.compilation_config.cudagraph_capture_sizes if max_num_tokens >= x >= uniform_decode_query_len ] config.experimental_config.aclgraph._aclnn_static_shape_kernel_sym_value_range = decode_cudagraph_batch_sizes