### What this PR does / why we need it?
Reformats the Triton kernel sources listed below: statements that were previously wrapped mid-expression are joined onto single lines, multi-part assertion messages are consolidated, and tuple spacing is normalized. There are no functional changes.

**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
### Does this PR introduce _any_ user-facing change?
No. The changes are formatting-only; kernel behavior is unchanged.
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Excerpt of the diff (the `batch_invariant` kernels):

**`vllm_ascend/ops/triton/batch_invariant/matmul.py`**

```diff
@@ -67,11 +67,9 @@ def matmul_bias_persistent_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         k_start = k * BLOCK_K
         # Calculate pointer offsets for x (row-major)
-        x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] +
-                                                    k_start) * stride_xk
+        x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
         # Calculate pointer offsets for y (row-major)
-        y_ptrs = y_ptr + (rk[:, None] +
-                          k_start) * stride_yk + rn[None, :] * stride_yn
+        y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn
 
         # Create masks to prevent out-of-bounds access
         x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
@@ -89,14 +87,12 @@ def matmul_bias_persistent_kernel(
         # Load bias values (broadcast to all rows)
         bias_ptrs = bias_ptr + rn * stride_bias
         bias_mask = rn < N
-        bias_vals = tl.load(bias_ptrs, mask=bias_mask,
-                            other=0.0).to(tl.float32)
+        bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
         # Add bias to accumulator (automatic broadcasting)
         acc += bias_vals[None, :]
 
     # Calculate output pointer positions
-    out_ptrs = output_ptr + rm[:,
-                               None] * stride_outm + rn[None, :] * stride_outn
+    out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
     out_mask = (rm[:, None] < M) & (rn[None, :] < N)
 
     # Store result to global memory
```
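For readers new to Triton's flat pointer arithmetic, the NumPy sketch below (illustrative only; `rm`, `rk`, and `k_start` mirror the kernel's index vectors) confirms that the `row * stride + col * stride` offsets above address exactly one `[BLOCK_M, BLOCK_K]` tile of a row-major matrix:

```python
import numpy as np

M, K = 8, 16
BLOCK_M, BLOCK_K = 4, 4
x = np.arange(M * K, dtype=np.float32).reshape(M, K)  # row-major buffer
flat = x.ravel()

# Element strides for a contiguous [M, K] float32 array: (K, 1)
stride_xm, stride_xk = x.strides[0] // 4, x.strides[1] // 4
rm = np.arange(BLOCK_M)      # tile row indices
rk = np.arange(BLOCK_K)      # tile column indices within the K-block
k_start = 1 * BLOCK_K        # second K-iteration, as in `k_start = k * BLOCK_K`

# Same offset expression as the kernel's x_ptrs computation
offsets = rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
tile = flat[offsets]

assert (tile == x[:BLOCK_M, k_start:k_start + BLOCK_K]).all()
```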
```diff
@@ -106,28 +102,28 @@ def matmul_bias_persistent_kernel(
 def matmul_persistent(x, y, bias=None):
     """
     Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
 
     Parameters:
         x: torch.Tensor, shape [M, K]
         y: torch.Tensor, shape [K, N]
         bias: torch.Tensor, shape [N] or None
 
     Returns:
         output: torch.Tensor, shape [M, N]
     """
     # Validate input shapes
     assert x.dim() == 2, "x must be a 2D tensor"
     assert y.dim() == 2, "y must be a 2D tensor"
-    assert x.shape[1] == y.shape[
-        0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
+    assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
 
     M, K = x.shape
     _, N = y.shape
     # Validate bias shape (if not None)
     if bias is not None:
         assert bias.dim() == 1, "bias must be a 1D tensor"
-        assert y.shape[1] == bias.shape[
-            0], f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
+        assert y.shape[1] == bias.shape[0], (
+            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
+        )
 
     # Allocate output tensor (same data type as x)
     output = torch.empty((M, N), dtype=x.dtype, device=x.device)
```
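As a sanity check for the wrapper's contract (a sketch, not part of the PR: it assumes an Ascend host with `torch_npu` installed, and the tolerances are illustrative):

```python
import torch
import torch_npu  # noqa: F401  # assumed available on Ascend hosts

from vllm_ascend.ops.triton.batch_invariant.matmul import matmul_persistent

x = torch.randn(128, 256, dtype=torch.float16, device="npu")
y = torch.randn(256, 64, dtype=torch.float16, device="npu")
bias = torch.randn(64, dtype=torch.float16, device="npu")

out = matmul_persistent(x, y, bias)  # Triton path
ref = x @ y + bias                   # eager reference
torch.testing.assert_close(out, ref, rtol=3e-2, atol=3e-2)
```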
```diff
@@ -176,24 +172,24 @@ def matmul_persistent(x, y, bias=None):
 
 @triton.jit
 def linear_persistent_kernel(
-        a_ptr,  # Pointer to tensor a, shape [M, K]
-        b_ptr,  # Pointer to tensor b, shape [N, K]
-        c_ptr,  # Pointer to output tensor c, shape [M, N]
-        M,  # Number of rows in tensor a
-        N,  # Number of rows in tensor b (number of columns in output c)
-        K,  # Number of columns in both tensor a and tensor b
-        stride_am,  # Stride of tensor a along dimension M (typically K)
-        stride_ak,  # Stride of tensor a along dimension K (typically 1)
-        stride_bn,  # Stride of tensor b along dimension N (typically K)
-        stride_bk,  # Stride of tensor b along dimension K (typically 1)
-        stride_cm,  # Stride of tensor c along dimension M (typically N)
-        stride_cn,  # Stride of tensor c along dimension N (typically 1)
-        BLOCK_M: tl.constexpr,  # Block size for M dimension
-        BLOCK_N: tl.constexpr,  # Block size for N dimension
-        BLOCK_K: tl.constexpr,  # Block size for K dimension
-        NUM_BLOCKS_M: tl.constexpr,  # New: Number of blocks in M dimension
-        NUM_BLOCKS_N: tl.constexpr,  # New: Number of blocks in N dimension
-        GRID_SIZE: tl.constexpr,  # New: Fixed 1D grid size
+    a_ptr,  # Pointer to tensor a, shape [M, K]
+    b_ptr,  # Pointer to tensor b, shape [N, K]
+    c_ptr,  # Pointer to output tensor c, shape [M, N]
+    M,  # Number of rows in tensor a
+    N,  # Number of rows in tensor b (number of columns in output c)
+    K,  # Number of columns in both tensor a and tensor b
+    stride_am,  # Stride of tensor a along dimension M (typically K)
+    stride_ak,  # Stride of tensor a along dimension K (typically 1)
+    stride_bn,  # Stride of tensor b along dimension N (typically K)
+    stride_bk,  # Stride of tensor b along dimension K (typically 1)
+    stride_cm,  # Stride of tensor c along dimension M (typically N)
+    stride_cn,  # Stride of tensor c along dimension N (typically 1)
+    BLOCK_M: tl.constexpr,  # Block size for M dimension
+    BLOCK_N: tl.constexpr,  # Block size for N dimension
+    BLOCK_K: tl.constexpr,  # Block size for K dimension
+    NUM_BLOCKS_M: tl.constexpr,  # New: Number of blocks in M dimension
+    NUM_BLOCKS_N: tl.constexpr,  # New: Number of blocks in N dimension
+    GRID_SIZE: tl.constexpr,  # New: Fixed 1D grid size
 ):
     # Get current program's 1D index (1D grid)
     pid = tl.program_id(0)
```
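The `NUM_BLOCKS_*`/`GRID_SIZE` constexprs suggest the usual persistent-kernel scheme: a fixed number of programs stride over all output tiles. A host-side sketch of that mapping (illustrative; the kernel's actual tile loop is not shown in this excerpt):

```python
# Sketch: a fixed 1D grid of GRID_SIZE programs covers all
# NUM_BLOCKS_M * NUM_BLOCKS_N output tiles by striding over tile ids.
GRID_SIZE = 20                     # e.g. half the device's vector cores
NUM_BLOCKS_M, NUM_BLOCKS_N = 7, 9

covered = set()
for pid in range(GRID_SIZE):       # one pass per program
    for tile in range(pid, NUM_BLOCKS_M * NUM_BLOCKS_N, GRID_SIZE):
        block_m, block_n = divmod(tile, NUM_BLOCKS_N)
        covered.add((block_m, block_n))  # one [BLOCK_M, BLOCK_N] output tile

# Every tile is visited exactly once, regardless of GRID_SIZE.
assert len(covered) == NUM_BLOCKS_M * NUM_BLOCKS_N
```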
```diff
@@ -226,18 +222,12 @@ def linear_persistent_kernel(
         k_mask = k_indices < K
 
         # Load block of tensor a: shape [BLOCK_M, BLOCK_K]
-        a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[
-            None, :] * stride_ak
-        a_vals = tl.load(a_ptrs,
-                         mask=m_mask[:, None] & k_mask[None, :],
-                         other=0.0)
+        a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
+        a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
 
         # Load block of tensor b: shape [BLOCK_N, BLOCK_K]
-        b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[
-            None, :] * stride_bk
-        b_vals = tl.load(b_ptrs,
-                         mask=n_mask[:, None] & k_mask[None, :],
-                         other=0.0)
+        b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
+        b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
 
         # Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
         b_vals_transposed = tl.trans(b_vals)
@@ -246,8 +236,7 @@ def linear_persistent_kernel(
         product = tl.dot(a_vals, b_vals_transposed)
         acc += product
     # Store result to output tensor c
-    c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[
-        None, :] * stride_cn
+    c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
     tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
 
 
```
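The load-transpose-dot pattern computes `x @ y^T` one `BLOCK_K` slice at a time; a NumPy sketch of the blocked accumulation (illustrative):

```python
import numpy as np

M, N, K, BLOCK_K = 16, 32, 64, 8
a = np.random.rand(M, K).astype(np.float32)
b = np.random.rand(N, K).astype(np.float32)  # b is [N, K], like F.linear weights

acc = np.zeros((M, N), dtype=np.float32)
for k_start in range(0, K, BLOCK_K):          # mirrors the kernel's K loop
    a_blk = a[:, k_start:k_start + BLOCK_K]   # [BLOCK_M, BLOCK_K] load
    b_blk = b[:, k_start:k_start + BLOCK_K]   # [BLOCK_N, BLOCK_K] load
    acc += a_blk @ b_blk.T                    # tl.dot(a_vals, tl.trans(b_vals))

np.testing.assert_allclose(acc, a @ b.T, rtol=1e-5)
```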
```diff
@@ -255,19 +244,18 @@ def linear_persistent(x, y):
     """
     Implement matrix multiplication with Triton: x @ y^T
     Uses a fixed-size 1D grid
 
     Parameters:
         x: torch.Tensor, shape [M, K]
         y: torch.Tensor, shape [N, K]
 
     Returns:
         output: torch.Tensor, shape [M, N]
     """
     # Validate input shapes
     assert x.dim() == 2, "x must be a 2D tensor"
     assert y.dim() == 2, "y must be a 2D tensor"
-    assert x.shape[1] == y.shape[
-        1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
+    assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
 
     M, K = x.shape
     N, _ = y.shape
@@ -283,9 +271,8 @@ def linear_persistent(x, y):
     num_blocks_n = triton.cdiv(N, BLOCK_N)
 
     # Set fixed 1D grid size
-    grid_size = driver.active.utils.get_device_properties(
-        torch.npu.current_device())["num_vectorcore"] // 2
-    grid = (grid_size, )
+    grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
+    grid = (grid_size,)
 
     # Launch kernel
     linear_persistent_kernel[grid](
```
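`linear_persistent` is effectively `F.linear` without bias; a usage sketch (again assuming an Ascend host with `torch_npu`, tolerances illustrative):

```python
import torch
import torch_npu  # noqa: F401  # assumed available on Ascend hosts

from vllm_ascend.ops.triton.batch_invariant.matmul import linear_persistent

x = torch.randn(64, 512, dtype=torch.float16, device="npu")   # activations [M, K]
w = torch.randn(128, 512, dtype=torch.float16, device="npu")  # weights [N, K]

out = linear_persistent(x, w)            # Triton path: x @ w.T
ref = torch.nn.functional.linear(x, w)   # eager reference
torch.testing.assert_close(out, ref, rtol=3e-2, atol=3e-2)
```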
```diff
@@ -330,8 +317,7 @@ def bmm_batch_invariant(a, b, *, out=None):
             return out
         return result
     else:
-        raise ValueError(f"bmm_batch_invariant expects 3D tensors, "
-                         f"got shapes {a.shape} and {b.shape}")
+        raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")
 
 
 def addmm_batch_invariant(bias, a, b):
@@ -392,7 +378,8 @@ def matmul_batch_invariant(a, b, *, out=None):
     raise ValueError(
         f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
         f"3D x 2D, 2D x 3D, and 4D x 4D, "
-        f"got shapes {a.shape} and {b.shape}")
+        f"got shapes {a.shape} and {b.shape}"
+    )
 
 
 def linear_batch_invariant(input_, weight, bias=None):
```
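Context for the `*_batch_invariant` naming: floating-point reductions are order-sensitive, so a matmul whose tiling changes with batch size can return bitwise-different results for the same row. A CPU-only illustration of the underlying effect:

```python
import numpy as np

np.random.seed(0)
x = np.random.rand(4096).astype(np.float32)

# The same mathematical sum, reduced with different trees:
s_seq = np.float32(0)
for v in x:                                   # strictly sequential reduction
    s_seq += v
s_pair = x.reshape(-1, 2).sum(axis=0).sum()   # a different reduction order

print(s_seq, s_pair, s_seq == s_pair)  # typically differ in the last bits
```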
**`vllm_ascend/ops/triton/batch_invariant/mean.py`**

```diff
@@ -86,8 +86,7 @@ def mean_dim(
         Tensor with mean values along specified dimension
     """
     # Validate inputs
-    assert -input_.ndim <= dim < input_.ndim, (
-        f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")
+    assert -input_.ndim <= dim < input_.ndim, f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions"
 
     # Handle negative dim
     if dim < 0:
@@ -123,7 +122,7 @@ def mean_dim(
         output_shape = shape.copy()
         output_shape[dim] = 1
     else:
-        output_shape = shape[:dim] + shape[dim + 1:]
+        output_shape = shape[:dim] + shape[dim + 1 :]
 
     # Create output tensor
     output = torch.empty(output_shape, dtype=dtype, device=input_.device)
@@ -135,7 +134,7 @@ def mean_dim(
     output_2d = output.reshape(M, K)
 
     # Launch kernel
-    grid = (M * K, )
+    grid = (M * K,)
     BLOCK_SIZE = 1024
 
     mean_kernel[grid](
```
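The `M * K` grid and the 2D output view reflect the usual trick of viewing the input as `[M, N, K]` with `N` the reduced extent, one program per output element; in plain PyTorch (illustrative, assuming that layout):

```python
import torch

x = torch.randn(3, 5, 7, 11)
dim = 2                          # reduce this dimension
M = x.shape[0] * x.shape[1]      # product of dims before `dim`
N = x.shape[dim]                 # reduced extent
K = x.shape[3]                   # product of dims after `dim`

x3d = x.reshape(M, N, K)
out = x3d.mean(dim=1)            # one mean per (m, k) pair -> M * K programs
torch.testing.assert_close(out.reshape(3, 5, 11), x.mean(dim=2))
```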
```diff
@@ -165,13 +164,10 @@ def mean_batch_invariant(
     if len(dim) == 1:
         return mean_dim(input_, dim[0], keepdim=keepdim)
     else:
-        assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32
-                                }, ("only float types supported for now")
+        assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32}, "only float types supported for now"
         if len(dim) == 0:
             dim = list(range(input_.ndim))
         n_elems = 1
         for d in dim:
             n_elems *= input_.shape[d]
-        return torch.sum(input_, dim=dim, keepdim=keepdim,
-                         dtype=torch.float32).to(dtype
-                                                 or input_.dtype) / n_elems
+        return torch.sum(input_, dim=dim, keepdim=keepdim, dtype=torch.float32).to(dtype or input_.dtype) / n_elems
```
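The rewritten return line makes the strategy explicit: accumulate in float32, then divide by the element count and cast back. A CPU check that this matches a float32-reference mean (tolerances illustrative):

```python
import torch

x = torch.randn(4, 6, 8, dtype=torch.float16)
dim = [1, 2]

n_elems = 1
for d in dim:
    n_elems *= x.shape[d]

# Sum in float32, cast back to the input dtype, then divide.
out = torch.sum(x, dim=dim, keepdim=False, dtype=torch.float32).to(x.dtype) / n_elems
ref = x.float().mean(dim=dim).to(x.dtype)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```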
**`vllm_ascend/ops/triton/batch_invariant/rmsnorm.py`**

```diff
@@ -100,8 +100,8 @@ def rms_norm(
     """
     assert weight.dim() == 1, "Weight must be 1-dimensional"
-    assert input_.shape[-1] == weight.shape[0], (
-        f"Input last dimension ({input_.shape[-1]}) must match "
-        f"weight dimension ({weight.shape[0]})")
+    assert input_.shape[-1] == weight.shape[0], (
+        f"Input last dimension ({input_.shape[-1]}) must match weight dimension ({weight.shape[0]})"
+    )
 
     # Flatten all dimensions except the last one
     original_shape = input_.shape
```
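For reference, a plain-PyTorch statement of the operation this kernel implements, assuming the standard RMSNorm formula with an epsilon term (the kernel's exact `eps` handling is not shown in this excerpt):

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # RMSNorm: scale each row by the reciprocal RMS of its elements.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(4, 16, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)
out = rms_norm_reference(x, w)
assert out.shape == x.shape
```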
```diff
@@ -113,10 +113,9 @@ def rms_norm(
 
     output = torch.empty_like(input_2d, dtype=input_.dtype)
     BLOCK_SIZE = 1024
-    max_grid_size = driver.active.utils.get_device_properties(
-        torch.npu.current_device())["num_vectorcore"]
+    max_grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"]
 
-    grid = (min(n_rows, max_grid_size), )
+    grid = (min(n_rows, max_grid_size),)
 
     _rms_norm_kernel[grid](
         input_2d,
```