### What this PR does / why we need it?
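Formatting-only cleanup of the Triton operator modules listed below. The hunks shown on this page only re-join wrapped lines and normalize whitespace; they do not change any logic.

**Scope of Changes**: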
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -67,11 +67,9 @@ def matmul_bias_persistent_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
k_start = k * BLOCK_K
|
||||
# Calculate pointer offsets for x (row-major)
|
||||
x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] +
|
||||
k_start) * stride_xk
|
||||
x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
|
||||
# Calculate pointer offsets for y (row-major)
|
||||
y_ptrs = y_ptr + (rk[:, None] +
|
||||
k_start) * stride_yk + rn[None, :] * stride_yn
|
||||
y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn
|
||||
|
||||
# Create masks to prevent out-of-bounds access
|
||||
x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
|
||||
@@ -89,14 +87,12 @@ def matmul_bias_persistent_kernel(
|
||||
# Load bias values (broadcast to all rows)
|
||||
bias_ptrs = bias_ptr + rn * stride_bias
|
||||
bias_mask = rn < N
|
||||
bias_vals = tl.load(bias_ptrs, mask=bias_mask,
|
||||
other=0.0).to(tl.float32)
|
||||
bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
|
||||
# Add bias to accumulator (automatic broadcasting)
|
||||
acc += bias_vals[None, :]
|
||||
|
||||
# Calculate output pointer positions
|
||||
out_ptrs = output_ptr + rm[:,
|
||||
None] * stride_outm + rn[None, :] * stride_outn
|
||||
out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
|
||||
out_mask = (rm[:, None] < M) & (rn[None, :] < N)
|
||||
|
||||
# Store result to global memory
|
||||
@@ -106,28 +102,28 @@ def matmul_bias_persistent_kernel(
|
||||
def matmul_persistent(x, y, bias=None):
|
||||
"""
|
||||
Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
|
||||
|
||||
|
||||
Parameters:
|
||||
x: torch.Tensor, shape [M, K]
|
||||
y: torch.Tensor, shape [K, N]
|
||||
bias: torch.Tensor, shape [N] or None
|
||||
|
||||
|
||||
Returns:
|
||||
output: torch.Tensor, shape [M, N]
|
||||
"""
|
||||
# Validate input shapes
|
||||
assert x.dim() == 2, "x must be a 2D tensor"
|
||||
assert y.dim() == 2, "y must be a 2D tensor"
|
||||
assert x.shape[1] == y.shape[
|
||||
0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
|
||||
assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
|
||||
|
||||
M, K = x.shape
|
||||
_, N = y.shape
|
||||
# Validate bias shape (if not None)
|
||||
if bias is not None:
|
||||
assert bias.dim() == 1, "bias must be a 1D tensor"
|
||||
assert y.shape[1] == bias.shape[
|
||||
0], f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
|
||||
assert y.shape[1] == bias.shape[0], (
|
||||
f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
|
||||
)
|
||||
|
||||
# Allocate output tensor (same data type as x)
|
||||
output = torch.empty((M, N), dtype=x.dtype, device=x.device)
|
||||
@@ -176,24 +172,24 @@ def matmul_persistent(x, y, bias=None):
|
||||
|
||||
@triton.jit
|
||||
def linear_persistent_kernel(
|
||||
a_ptr, # Pointer to tensor a, shape [M, K]
|
||||
b_ptr, # Pointer to tensor b, shape [N, K]
|
||||
c_ptr, # Pointer to output tensor c, shape [M, N]
|
||||
M, # Number of rows in tensor a
|
||||
N, # Number of rows in tensor b (number of columns in output c)
|
||||
K, # Number of columns in both tensor a and tensor b
|
||||
stride_am, # Stride of tensor a along dimension M (typically K)
|
||||
stride_ak, # Stride of tensor a along dimension K (typically 1)
|
||||
stride_bn, # Stride of tensor b along dimension N (typically K)
|
||||
stride_bk, # Stride of tensor b along dimension K (typically 1)
|
||||
stride_cm, # Stride of tensor c along dimension M (typically N)
|
||||
stride_cn, # Stride of tensor c along dimension N (typically 1)
|
||||
BLOCK_M: tl.constexpr, # Block size for M dimension
|
||||
BLOCK_N: tl.constexpr, # Block size for N dimension
|
||||
BLOCK_K: tl.constexpr, # Block size for K dimension
|
||||
NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension
|
||||
NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension
|
||||
GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size
|
||||
a_ptr, # Pointer to tensor a, shape [M, K]
|
||||
b_ptr, # Pointer to tensor b, shape [N, K]
|
||||
c_ptr, # Pointer to output tensor c, shape [M, N]
|
||||
M, # Number of rows in tensor a
|
||||
N, # Number of rows in tensor b (number of columns in output c)
|
||||
K, # Number of columns in both tensor a and tensor b
|
||||
stride_am, # Stride of tensor a along dimension M (typically K)
|
||||
stride_ak, # Stride of tensor a along dimension K (typically 1)
|
||||
stride_bn, # Stride of tensor b along dimension N (typically K)
|
||||
stride_bk, # Stride of tensor b along dimension K (typically 1)
|
||||
stride_cm, # Stride of tensor c along dimension M (typically N)
|
||||
stride_cn, # Stride of tensor c along dimension N (typically 1)
|
||||
BLOCK_M: tl.constexpr, # Block size for M dimension
|
||||
BLOCK_N: tl.constexpr, # Block size for N dimension
|
||||
BLOCK_K: tl.constexpr, # Block size for K dimension
|
||||
NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension
|
||||
NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension
|
||||
GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size
|
||||
):
|
||||
# Get current program's 1D index (1D grid)
|
||||
pid = tl.program_id(0)
|
||||
@@ -226,18 +222,12 @@ def linear_persistent_kernel(
|
||||
k_mask = k_indices < K
|
||||
|
||||
# Load block of tensor a: shape [BLOCK_M, BLOCK_K]
|
||||
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[
|
||||
None, :] * stride_ak
|
||||
a_vals = tl.load(a_ptrs,
|
||||
mask=m_mask[:, None] & k_mask[None, :],
|
||||
other=0.0)
|
||||
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
|
||||
a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
|
||||
|
||||
# Load block of tensor b: shape [BLOCK_N, BLOCK_K]
|
||||
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[
|
||||
None, :] * stride_bk
|
||||
b_vals = tl.load(b_ptrs,
|
||||
mask=n_mask[:, None] & k_mask[None, :],
|
||||
other=0.0)
|
||||
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
|
||||
b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
|
||||
|
||||
# Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
|
||||
b_vals_transposed = tl.trans(b_vals)
|
||||
@@ -246,8 +236,7 @@ def linear_persistent_kernel(
|
||||
product = tl.dot(a_vals, b_vals_transposed)
|
||||
acc += product
|
||||
# Store result to output tensor c
|
||||
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[
|
||||
None, :] * stride_cn
|
||||
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
|
||||
tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
|
||||
|
||||
|
||||
@@ -255,19 +244,18 @@ def linear_persistent(x, y):
|
||||
"""
|
||||
Implement matrix multiplication with Triton: x @ y^T
|
||||
Uses a fixed-size 1D grid
|
||||
|
||||
|
||||
Parameters:
|
||||
x: torch.Tensor, shape [M, K]
|
||||
y: torch.Tensor, shape [N, K]
|
||||
|
||||
|
||||
Returns:
|
||||
output: torch.Tensor, shape [M, N]
|
||||
"""
|
||||
# Validate input shapes
|
||||
assert x.dim() == 2, "x must be a 2D tensor"
|
||||
assert y.dim() == 2, "y must be a 2D tensor"
|
||||
assert x.shape[1] == y.shape[
|
||||
1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
|
||||
assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
|
||||
|
||||
M, K = x.shape
|
||||
N, _ = y.shape
|
||||
@@ -283,9 +271,8 @@ def linear_persistent(x, y):
|
||||
num_blocks_n = triton.cdiv(N, BLOCK_N)
|
||||
|
||||
# Set fixed 1D grid size
|
||||
grid_size = driver.active.utils.get_device_properties(
|
||||
torch.npu.current_device())["num_vectorcore"] // 2
|
||||
grid = (grid_size, )
|
||||
grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
|
||||
grid = (grid_size,)
|
||||
|
||||
# Launch kernel
|
||||
linear_persistent_kernel[grid](
|
||||
@@ -330,8 +317,7 @@ def bmm_batch_invariant(a, b, *, out=None):
|
||||
return out
|
||||
return result
|
||||
else:
|
||||
raise ValueError(f"bmm_batch_invariant expects 3D tensors, "
|
||||
f"got shapes {a.shape} and {b.shape}")
|
||||
raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")
|
||||
|
||||
|
||||
def addmm_batch_invariant(bias, a, b):
|
||||
@@ -392,7 +378,8 @@ def matmul_batch_invariant(a, b, *, out=None):
|
||||
raise ValueError(
|
||||
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
|
||||
f"3D x 2D, 2D x 3D, and 4D x 4D, "
|
||||
f"got shapes {a.shape} and {b.shape}")
|
||||
f"got shapes {a.shape} and {b.shape}"
|
||||
)
|
||||
|
||||
|
||||
def linear_batch_invariant(input_, weight, bias=None):
|
||||
|
||||
Reference in New Issue
Block a user