[Triton][Config] Add muls_add triton kernel and refactor AscendCompilationConfig (#5518)

### What this PR does / why we need it?
Add muls_add triton kernel with related fusion pass. What's more, this
PR refactors `AscendCompilationConfig` and delete `NpugraphExConfig`.

### Does this PR introduce _any_ user-facing change?
None

### How was this patch tested?
CI passed with new added test.


- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2026-03-02 17:54:25 +08:00
committed by GitHub
parent 8547520726
commit 16c879cdf7
14 changed files with 290 additions and 98 deletions

View File

@@ -28,7 +28,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
"cudagraph_mode": "FULL_DECODE_ONLY",
"pass_config": {"enable_sp": False}
},
additional_config={"npugraph_ex_config": {"enable": False}}
additional_config={"ascend_compilation_config": {"enable_npugraph_ex": False}}
) as runner:
no_sp_outputs = runner.model.generate(prompts, sampling_params)
@@ -41,7 +41,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
"cudagraph_mode": "FULL_DECODE_ONLY",
"pass_config": {"enable_sp": True}
},
additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}}
additional_config={"sp_threshold": 10, "ascend_compilation_config": {"enable_npugraph_ex": False}}
) as runner:
sp_outputs = runner.model.generate(
prompts, sampling_params)

View File

@@ -0,0 +1,34 @@
import pytest
import torch
from vllm_ascend.ops.triton.muls_add import muls_add_triton
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
@pytest.mark.parametrize(
("shape", "dtype", "scale"),
[
((1, 2048), torch.float16, 1.25),
((4000, 2048), torch.float16, 0.75),
((4, 2048), torch.bfloat16, 1.0),
],
)
@torch.inference_mode()
def test_muls_add_triton_correctness(shape, dtype, scale):
"""compare the correctness of muls_add_triton with the PyTorch baseline implementation."""
init_device_properties_triton()
device = "npu"
torch.manual_seed(0)
x = torch.randn(*shape, dtype=dtype, device=device)
y = torch.randn(*shape, dtype=dtype, device=device)
out_triton = muls_add_triton(x, y, scale)
out_ref = x * scale + y
rtol, atol = 1e-3, 1e-3
assert out_triton.shape == out_ref.shape
assert out_triton.dtype == out_ref.dtype
assert torch.allclose(out_triton, out_ref, rtol=rtol, atol=atol)

View File

@@ -153,7 +153,7 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
"max_model_len": 1024,
"compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
"quantization": cur_case.quantization,
"additional_config": {"npugraph_ex_config": {"enable": False}},
"additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": False}},
}
gen_and_valid(
runner_kwargs=runner_kwargs,
@@ -171,7 +171,7 @@ def test_npugraph_ex_res_consistency(cur_case: LLMTestCase, monkeypatch):
"quantization": cur_case.quantization,
"max_model_len": 1024,
"compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
"additional_config": {"npugraph_ex_config": {"enable": True}},
"additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": True}},
}
gen_and_valid(
runner_kwargs=runner_kwargs,
@@ -193,8 +193,8 @@ def test_npugraph_ex_with_static_kernel(cur_case: LLMTestCase, monkeypatch):
"max_model_len": 1024,
"compilation_config": {"cudagraph_capture_sizes": [4, 8], "cudagraph_mode": "FULL_DECODE_ONLY"},
"additional_config": {
"npugraph_ex_config": {
"enable": True,
"ascend_compilation_config": {
"enable_npugraph_ex": True,
"enable_static_kernel": True,
}
},