[Triton][Config] Add muls_add triton kernel and refactor AscendCompilationConfig (#5518)
### What this PR does / why we need it?
Add a `muls_add` Triton kernel together with a matching fusion pass. In addition, this
PR refactors `AscendCompilationConfig` and deletes `NpugraphExConfig`.
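For context, a minimal sketch of what a fused multiply-scale-add ("muls_add") kernel looks like in Triton, computing `out = x * scale + y` in a single pass over memory. This is illustrative only, not the kernel shipped in this PR; the names and the block size of 1024 are assumptions:

```python
# Illustrative sketch of a fused muls_add kernel (out = x * scale + y).
# Not the kernel added by this PR; names and tuning are assumptions.
import torch
import triton
import triton.language as tl


@triton.jit
def _muls_add_kernel(x_ptr, y_ptr, out_ptr, scale, n_elements,
                     BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # One fused pass: scale x, add y, store the result.
    tl.store(out_ptr + offsets, x * scale + y, mask=mask)


def muls_add(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    _muls_add_kernel[grid](x, y, out, scale, n_elements, BLOCK_SIZE=1024)
    return out
```

Fusing the two elementwise ops avoids materializing the intermediate `x * scale`, saving one full read and write of the tensor compared with dispatching `mul` and `add` separately. The "related fusion pass" exploits this at the graph level by rewriting matching `mul`/`add` pairs to call the fused kernel. A toy `torch.fx` version of such a rewrite is sketched below (the actual pass integrates with vLLM's pass infrastructure, which differs from plain `fx`):

```python
# Toy torch.fx rewrite: add(mul(x, scalar), y) -> muls_add(x, y, scalar).
# Illustrative only; vLLM's real pass machinery is different.
import operator
import torch.fx as fx


def fuse_muls_add(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is operator.add:
            lhs, rhs = node.args
            # Match a mul-by-Python-scalar feeding only this add.
            if (isinstance(lhs, fx.Node) and lhs.op == "call_function"
                    and lhs.target is operator.mul
                    and isinstance(lhs.args[1], (int, float))
                    and len(lhs.users) == 1):
                x, scale = lhs.args
                with gm.graph.inserting_before(node):
                    fused = gm.graph.call_function(muls_add, (x, rhs, scale))
                node.replace_all_uses_with(fused)
                gm.graph.erase_node(node)
                gm.graph.erase_node(lhs)
    gm.graph.lint()
    gm.recompile()
    return gm
```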
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
CI passed with the newly added test.
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
```diff
@@ -28,7 +28,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
             "cudagraph_mode": "FULL_DECODE_ONLY",
             "pass_config": {"enable_sp": False}
         },
-        additional_config={"npugraph_ex_config": {"enable": False}}
+        additional_config={"ascend_compilation_config": {"enable_npugraph_ex": False}}
     ) as runner:
         no_sp_outputs = runner.model.generate(prompts, sampling_params)
@@ -41,7 +41,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
             "cudagraph_mode": "FULL_DECODE_ONLY",
             "pass_config": {"enable_sp": True}
         },
-        additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}}
+        additional_config={"sp_threshold": 10, "ascend_compilation_config": {"enable_npugraph_ex": False}}
     ) as runner:
         sp_outputs = runner.model.generate(
             prompts, sampling_params)
```
```diff
@@ -0,0 +1,34 @@
+import pytest
+import torch
+
+from vllm_ascend.ops.triton.muls_add import muls_add_triton
+from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+
+
+@pytest.mark.parametrize(
+    ("shape", "dtype", "scale"),
+    [
+        ((1, 2048), torch.float16, 1.25),
+        ((4000, 2048), torch.float16, 0.75),
+        ((4, 2048), torch.bfloat16, 1.0),
+    ],
+)
+@torch.inference_mode()
+def test_muls_add_triton_correctness(shape, dtype, scale):
+    """compare the correctness of muls_add_triton with the PyTorch baseline implementation."""
+    init_device_properties_triton()
+    device = "npu"
+
+    torch.manual_seed(0)
+    x = torch.randn(*shape, dtype=dtype, device=device)
+    y = torch.randn(*shape, dtype=dtype, device=device)
+
+    out_triton = muls_add_triton(x, y, scale)
+    out_ref = x * scale + y
+
+    rtol, atol = 1e-3, 1e-3
+
+    assert out_triton.shape == out_ref.shape
+    assert out_triton.dtype == out_ref.dtype
+    assert torch.allclose(out_triton, out_ref, rtol=rtol, atol=atol)
```
```diff
@@ -153,7 +153,7 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
         "quantization": cur_case.quantization,
-        "additional_config": {"npugraph_ex_config": {"enable": False}},
+        "additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": False}},
     }
     gen_and_valid(
         runner_kwargs=runner_kwargs,
@@ -171,7 +171,7 @@ def test_npugraph_ex_res_consistency(cur_case: LLMTestCase, monkeypatch):
         "quantization": cur_case.quantization,
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
-        "additional_config": {"npugraph_ex_config": {"enable": True}},
+        "additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": True}},
     }
     gen_and_valid(
         runner_kwargs=runner_kwargs,
@@ -193,8 +193,8 @@ def test_npugraph_ex_with_static_kernel(cur_case: LLMTestCase, monkeypatch):
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8], "cudagraph_mode": "FULL_DECODE_ONLY"},
         "additional_config": {
-            "npugraph_ex_config": {
-                "enable": True,
+            "ascend_compilation_config": {
+                "enable_npugraph_ex": True,
                 "enable_static_kernel": True,
             }
         },
```
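For users migrating, the rename shown throughout this diff maps the old `npugraph_ex_config` section with its `enable` flag onto `ascend_compilation_config.enable_npugraph_ex`. A hedged sketch of the new-style invocation follows; the model name is a placeholder, and passing `additional_config` through `vllm.LLM` is assumed from how the tests above construct their runners:

```python
# Hedged usage sketch of the refactored config key; model name is a placeholder.
from vllm import LLM

llm = LLM(
    model="some/model",  # placeholder, not from this PR
    additional_config={
        "ascend_compilation_config": {
            "enable_npugraph_ex": True,    # was: npugraph_ex_config.enable
            "enable_static_kernel": True,  # same flag, new parent key
        }
    },
)
```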