[Triton][Config] Add muls_add triton kernel and refactor AscendCompilationConfig (#5518)
### What this PR does / why we need it?
Add muls_add triton kernel with related fusion pass. What's more, this
PR refactors `AscendCompilationConfig` and delete `NpugraphExConfig`.
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
CI passed with new added test.
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
57
vllm_ascend/ops/triton/muls_add.py
Normal file
57
vllm_ascend/ops/triton/muls_add.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
|
||||
@triton.jit
|
||||
def muls_add_kernel(
|
||||
x_ptr, # *Pointer* to first input vector.
|
||||
y_ptr, # *Pointer* to second input vector.
|
||||
output_ptr, # *Pointer* to output vector.
|
||||
scale, # Scale factor.
|
||||
n_elements, # Size of the vector.
|
||||
n_blocks, # Total number of blocks.
|
||||
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_programs = tl.num_programs(axis=0)
|
||||
for block_id in range(pid, n_blocks, num_programs):
|
||||
block_start = block_id * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x * scale + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
def muls_add_triton(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
|
||||
assert x.shape == y.shape, "Input tensors must have the same shape."
|
||||
hidden_size = x.shape[-1]
|
||||
|
||||
n_elements = x.numel()
|
||||
output = torch.empty_like(x)
|
||||
|
||||
# Determine the number of vector cores available
|
||||
num_cores = get_vectorcore_num()
|
||||
|
||||
# Define block size
|
||||
BLOCK_SIZE = max(hidden_size // 2, 1024)
|
||||
|
||||
# Calculate the number of programs to launch
|
||||
num_blocks = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
|
||||
num_programs = min(num_blocks, num_cores)
|
||||
|
||||
# Launch the Triton kernel
|
||||
muls_add_kernel[(num_programs,)](
|
||||
x,
|
||||
y,
|
||||
output,
|
||||
scale,
|
||||
n_elements,
|
||||
num_blocks,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return output
|
||||
Reference in New Issue
Block a user