[Triton][Config] Add muls_add triton kernel and refactor AscendCompilationConfig (#5518)
### What this PR does / why we need it?
Add muls_add triton kernel with related fusion pass. What's more, this
PR refactors `AscendCompilationConfig` and delete `NpugraphExConfig`.
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
CI passed with new added test.
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ops.triton.muls_add import muls_add_triton
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("shape", "dtype", "scale"),
|
||||
[
|
||||
((1, 2048), torch.float16, 1.25),
|
||||
((4000, 2048), torch.float16, 0.75),
|
||||
((4, 2048), torch.bfloat16, 1.0),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_muls_add_triton_correctness(shape, dtype, scale):
|
||||
"""compare the correctness of muls_add_triton with the PyTorch baseline implementation."""
|
||||
init_device_properties_triton()
|
||||
device = "npu"
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(*shape, dtype=dtype, device=device)
|
||||
y = torch.randn(*shape, dtype=dtype, device=device)
|
||||
|
||||
out_triton = muls_add_triton(x, y, scale)
|
||||
out_ref = x * scale + y
|
||||
|
||||
rtol, atol = 1e-3, 1e-3
|
||||
|
||||
assert out_triton.shape == out_ref.shape
|
||||
assert out_triton.dtype == out_ref.dtype
|
||||
assert torch.allclose(out_triton, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
Reference in New Issue
Block a user