[Lint]Style: Convert vllm-ascend/ to ruff format (Batch #12) (#6177)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
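
All of the edits are mechanical re-formatting (line wrapping, quote style, trailing commas) with no logic changes. A batch like this is typically produced and verified with `ruff format <path>` followed by `ruff check <path>`; the exact flags and CI wiring used here are an assumption, as they are not stated in the PR.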

### Does this PR introduce _any_ user-facing change?

No. This is a formatting-only change; runtime behavior is unchanged.
### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main: d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Author: SILONG ZENG (committed via GitHub)
Date: 2026-01-23 14:59:19 +08:00
Parent: 193acc2c19
Commit: 78af0c30a3
25 changed files with 760 additions and 996 deletions

File: `vllm_ascend/ops/triton/activation/swiglu_quant.py`

@@ -26,8 +26,7 @@ def _swiglu_quant_kernel(
else:
gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN)
gl_mask = gl_offsets < NUM_EXPERTS
group_list = tl.load(group_list_ptr + gl_offsets, gl_mask,
other=0).to(tl.int32)
group_list = tl.load(group_list_ptr + gl_offsets, gl_mask, other=0).to(tl.int32)
total_rows = tl.sum(group_list)
block_size = (total_rows - 1) // NUM_CORES + 1
@@ -41,14 +40,8 @@ def _swiglu_quant_kernel(
# swiglu
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
cur_x = tl.load(x_ptr + x_offsets)
x1 = tl.extract_slice(cur_x,
offsets=(0, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x2 = tl.extract_slice(cur_x,
offsets=(HALF_COLS, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
out = x1 * tl.sigmoid(x1) * x2
# quant
@@ -57,20 +50,13 @@ def _swiglu_quant_kernel(
# store scale
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
tmp_out = tl.extract_slice(out,
offsets=(col_blk_idx, ),
sizes=(COL_BLOCK_SIZE, ),
strides=(1, ))
tmp_out = (tmp_out.to(tl.float32) / scale).to(
x_ptr.dtype.element_ty)
tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")
o_offsets = (row_idx * HALF_COLS + col_blk_idx +
tl.arange(0, COL_BLOCK_SIZE))
o_offsets = row_idx * HALF_COLS + col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)
mask = (col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)) < HALF_COLS
tl.store(out_ptr + o_offsets,
tmp_out.to(out_ptr.dtype.element_ty),
mask=mask)
tl.store(out_ptr + o_offsets, tmp_out.to(out_ptr.dtype.element_ty), mask=mask)
else:
# store out
o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
@@ -80,12 +66,11 @@ def _swiglu_quant_kernel(
def swiglu_quant(x, group_list, group_list_type, need_quant=True):
# group_list_type must be 0 cusum or 1 count
if group_list_type not in [0, 1]:
raise ValueError(
f"group_list_type must be 0 or 1, but got {group_list_type}")
raise ValueError(f"group_list_type must be 0 or 1, but got {group_list_type}")
s, h = x.shape
out_dtype = torch.int8 if need_quant else x.dtype
out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device)
scale = torch.empty((s, ), dtype=torch.float32, device=x.device)
scale = torch.empty((s,), dtype=torch.float32, device=x.device)
num_experts = group_list.shape[0]
# ub must be 32-byte aligned on npu
if group_list.dtype == torch.int64:
@@ -93,12 +78,10 @@ def swiglu_quant(x, group_list, group_list_type, need_quant=True):
elif group_list.dtype == torch.int32:
num_experts_algin = (num_experts + 15) // 16 * 16
else:
raise ValueError(
f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
)
raise ValueError(f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}")
num_vectorcore = get_vectorcore_num()
_swiglu_quant_kernel[(num_vectorcore, )](
_swiglu_quant_kernel[(num_vectorcore,)](
x,
group_list,
out,

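For context, a minimal usage sketch of the `swiglu_quant` wrapper above (shapes are hypothetical, an Ascend NPU device is assumed, and the `(out, scale)` return value is an assumption based on the buffers it allocates):

```python
import torch

# 1024 rows routed across 3 experts; hidden size 4096 (gate and up halves fused).
x = torch.randn(1024, 4096, dtype=torch.bfloat16, device="npu")
group_list = torch.tensor([256, 256, 512], dtype=torch.int32, device="npu")

# group_list_type=1 means per-expert counts; 0 would mean cumulative sums.
out, scale = swiglu_quant(x, group_list, group_list_type=1, need_quant=True)
# out: int8 tensor of shape [1024, 2048]; scale: float32 tensor of shape [1024]
```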
File: `vllm_ascend/ops/triton/batch_invariant/matmul.py`

@@ -67,11 +67,9 @@ def matmul_bias_persistent_kernel(
for k in range(0, tl.cdiv(K, BLOCK_K)):
k_start = k * BLOCK_K
# Calculate pointer offsets for x (row-major)
x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] +
k_start) * stride_xk
x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
# Calculate pointer offsets for y (row-major)
y_ptrs = y_ptr + (rk[:, None] +
k_start) * stride_yk + rn[None, :] * stride_yn
y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn
# Create masks to prevent out-of-bounds access
x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
@@ -89,14 +87,12 @@ def matmul_bias_persistent_kernel(
# Load bias values (broadcast to all rows)
bias_ptrs = bias_ptr + rn * stride_bias
bias_mask = rn < N
bias_vals = tl.load(bias_ptrs, mask=bias_mask,
other=0.0).to(tl.float32)
bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
# Add bias to accumulator (automatic broadcasting)
acc += bias_vals[None, :]
# Calculate output pointer positions
out_ptrs = output_ptr + rm[:,
None] * stride_outm + rn[None, :] * stride_outn
out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
out_mask = (rm[:, None] < M) & (rn[None, :] < N)
# Store result to global memory
@@ -106,28 +102,28 @@ def matmul_bias_persistent_kernel(
def matmul_persistent(x, y, bias=None):
"""
Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
Parameters:
x: torch.Tensor, shape [M, K]
y: torch.Tensor, shape [K, N]
bias: torch.Tensor, shape [N] or None
Returns:
output: torch.Tensor, shape [M, N]
"""
# Validate input shapes
assert x.dim() == 2, "x must be a 2D tensor"
assert y.dim() == 2, "y must be a 2D tensor"
assert x.shape[1] == y.shape[
0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
M, K = x.shape
_, N = y.shape
# Validate bias shape (if not None)
if bias is not None:
assert bias.dim() == 1, "bias must be a 1D tensor"
assert y.shape[1] == bias.shape[
0], f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
assert y.shape[1] == bias.shape[0], (
f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
)
# Allocate output tensor (same data type as x)
output = torch.empty((M, N), dtype=x.dtype, device=x.device)
@@ -176,24 +172,24 @@ def matmul_persistent(x, y, bias=None):
@triton.jit
def linear_persistent_kernel(
a_ptr, # Pointer to tensor a, shape [M, K]
b_ptr, # Pointer to tensor b, shape [N, K]
c_ptr, # Pointer to output tensor c, shape [M, N]
M, # Number of rows in tensor a
N, # Number of rows in tensor b (number of columns in output c)
K, # Number of columns in both tensor a and tensor b
stride_am, # Stride of tensor a along dimension M (typically K)
stride_ak, # Stride of tensor a along dimension K (typically 1)
stride_bn, # Stride of tensor b along dimension N (typically K)
stride_bk, # Stride of tensor b along dimension K (typically 1)
stride_cm, # Stride of tensor c along dimension M (typically N)
stride_cn, # Stride of tensor c along dimension N (typically 1)
BLOCK_M: tl.constexpr, # Block size for M dimension
BLOCK_N: tl.constexpr, # Block size for N dimension
BLOCK_K: tl.constexpr, # Block size for K dimension
NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension
NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension
GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size
a_ptr, # Pointer to tensor a, shape [M, K]
b_ptr, # Pointer to tensor b, shape [N, K]
c_ptr, # Pointer to output tensor c, shape [M, N]
M, # Number of rows in tensor a
N, # Number of rows in tensor b (number of columns in output c)
K, # Number of columns in both tensor a and tensor b
stride_am, # Stride of tensor a along dimension M (typically K)
stride_ak, # Stride of tensor a along dimension K (typically 1)
stride_bn, # Stride of tensor b along dimension N (typically K)
stride_bk, # Stride of tensor b along dimension K (typically 1)
stride_cm, # Stride of tensor c along dimension M (typically N)
stride_cn, # Stride of tensor c along dimension N (typically 1)
BLOCK_M: tl.constexpr, # Block size for M dimension
BLOCK_N: tl.constexpr, # Block size for N dimension
BLOCK_K: tl.constexpr, # Block size for K dimension
NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension
NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension
GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size
):
# Get current program's 1D index (1D grid)
pid = tl.program_id(0)
@@ -226,18 +222,12 @@ def linear_persistent_kernel(
k_mask = k_indices < K
# Load block of tensor a: shape [BLOCK_M, BLOCK_K]
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[
None, :] * stride_ak
a_vals = tl.load(a_ptrs,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0)
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
# Load block of tensor b: shape [BLOCK_N, BLOCK_K]
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[
None, :] * stride_bk
b_vals = tl.load(b_ptrs,
mask=n_mask[:, None] & k_mask[None, :],
other=0.0)
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
# Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
b_vals_transposed = tl.trans(b_vals)
@@ -246,8 +236,7 @@ def linear_persistent_kernel(
product = tl.dot(a_vals, b_vals_transposed)
acc += product
# Store result to output tensor c
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[
None, :] * stride_cn
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
@@ -255,19 +244,18 @@ def linear_persistent(x, y):
"""
Implement matrix multiplication with Triton: x @ y^T
Uses a fixed-size 1D grid
Parameters:
x: torch.Tensor, shape [M, K]
y: torch.Tensor, shape [N, K]
Returns:
output: torch.Tensor, shape [M, N]
"""
# Validate input shapes
assert x.dim() == 2, "x must be a 2D tensor"
assert y.dim() == 2, "y must be a 2D tensor"
assert x.shape[1] == y.shape[
1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
M, K = x.shape
N, _ = y.shape
@@ -283,9 +271,8 @@ def linear_persistent(x, y):
num_blocks_n = triton.cdiv(N, BLOCK_N)
# Set fixed 1D grid size
grid_size = driver.active.utils.get_device_properties(
torch.npu.current_device())["num_vectorcore"] // 2
grid = (grid_size, )
grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
grid = (grid_size,)
# Launch kernel
linear_persistent_kernel[grid](
@@ -330,8 +317,7 @@ def bmm_batch_invariant(a, b, *, out=None):
return out
return result
else:
raise ValueError(f"bmm_batch_invariant expects 3D tensors, "
f"got shapes {a.shape} and {b.shape}")
raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")
def addmm_batch_invariant(bias, a, b):
@@ -392,7 +378,8 @@ def matmul_batch_invariant(a, b, *, out=None):
raise ValueError(
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
f"3D x 2D, 2D x 3D, and 4D x 4D, "
f"got shapes {a.shape} and {b.shape}")
f"got shapes {a.shape} and {b.shape}"
)
def linear_batch_invariant(input_, weight, bias=None):
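
For reference, a usage sketch of the two wrappers documented above (hypothetical shapes; an NPU device is assumed):

```python
import torch

x = torch.randn(128, 256, dtype=torch.float16, device="npu")
w = torch.randn(64, 256, dtype=torch.float16, device="npu")   # [N, K], as stored by a linear layer
bias = torch.randn(64, dtype=torch.float16, device="npu")

out1 = matmul_persistent(x, w.t().contiguous(), bias)  # x @ y + bias with y of shape [K, N]
out2 = linear_persistent(x, w)                         # x @ y^T with y of shape [N, K]
```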

File: `vllm_ascend/ops/triton/batch_invariant/mean.py`

@@ -86,8 +86,7 @@ def mean_dim(
Tensor with mean values along specified dimension
"""
# Validate inputs
assert -input_.ndim <= dim < input_.ndim, (
f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")
assert -input_.ndim <= dim < input_.ndim, f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions"
# Handle negative dim
if dim < 0:
@@ -123,7 +122,7 @@ def mean_dim(
output_shape = shape.copy()
output_shape[dim] = 1
else:
output_shape = shape[:dim] + shape[dim + 1:]
output_shape = shape[:dim] + shape[dim + 1 :]
# Create output tensor
output = torch.empty(output_shape, dtype=dtype, device=input_.device)
@@ -135,7 +134,7 @@ def mean_dim(
output_2d = output.reshape(M, K)
# Launch kernel
grid = (M * K, )
grid = (M * K,)
BLOCK_SIZE = 1024
mean_kernel[grid](
@@ -165,13 +164,10 @@ def mean_batch_invariant(
if len(dim) == 1:
return mean_dim(input_, dim[0], keepdim=keepdim)
else:
assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32
}, ("only float types supported for now")
assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32}, "only float types supported for now"
if len(dim) == 0:
dim = list(range(input_.ndim))
n_elems = 1
for d in dim:
n_elems *= input_.shape[d]
return torch.sum(input_, dim=dim, keepdim=keepdim,
dtype=torch.float32).to(dtype
or input_.dtype) / n_elems
return torch.sum(input_, dim=dim, keepdim=keepdim, dtype=torch.float32).to(dtype or input_.dtype) / n_elems
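
A quick sketch of the two paths above (hypothetical shapes; `dim` is a list, matching the `len(dim)` checks in the code). Accumulating the fallback in float32 presumably keeps the reduction batch-invariant, i.e. independent of batch-dependent reduction order:

```python
import torch

x = torch.randn(8, 16, 32, dtype=torch.bfloat16, device="npu")

m1 = mean_batch_invariant(x, dim=[1])     # single dim -> Triton mean_dim kernel
m2 = mean_batch_invariant(x, dim=[1, 2])  # multiple dims -> float32 torch.sum fallback / n_elems
```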

File: `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py`

@@ -100,8 +100,8 @@ def rms_norm(
"""
assert weight.dim() == 1, "Weight must be 1-dimensional"
assert input_.shape[-1] == weight.shape[0], (
f"Input last dimension ({input_.shape[-1]}) must match "
f"weight dimension ({weight.shape[0]})")
f"Input last dimension ({input_.shape[-1]}) must match weight dimension ({weight.shape[0]})"
)
# Flatten all dimensions except the last one
original_shape = input_.shape
@@ -113,10 +113,9 @@ def rms_norm(
output = torch.empty_like(input_2d, dtype=input_.dtype)
BLOCK_SIZE = 1024
max_grid_size = driver.active.utils.get_device_properties(
torch.npu.current_device())["num_vectorcore"]
max_grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"]
grid = (min(n_rows, max_grid_size), )
grid = (min(n_rows, max_grid_size),)
_rms_norm_kernel[grid](
input_2d,

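A usage sketch for `rms_norm` (only the assertions are visible in this hunk, so the exact signature and default `eps` are assumptions):

```python
import torch

x = torch.randn(32, 4096, dtype=torch.bfloat16, device="npu")
weight = torch.ones(4096, dtype=torch.bfloat16, device="npu")  # 1-D, matches the last dim of x

out = rms_norm(x, weight)  # shape preserved; normalized over the last dimension
```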
File: `vllm_ascend/ops/triton/fla/chunk.py`

@@ -9,7 +9,6 @@
# ruff: noqa: E501
# mypy: ignore-errors
import warnings
from typing import Optional
import torch
from einops import rearrange
@@ -25,22 +24,20 @@ from .utils import input_guard
from .wy_fast import recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=torch.float32)
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
@@ -75,20 +72,21 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False):
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
@@ -110,17 +108,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False):
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
@@ -186,41 +186,39 @@ def chunk_gated_delta_rule(q: torch.Tensor,
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(
beta.shape
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2)
q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
(q, k, v, beta, g))
stacklevel=2,
)
q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g))
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
use_qk_l2norm_in_kernel)
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel
)
if head_first:
o = rearrange(o, 'b t h ... -> b h t ...')
return o, final_state
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state
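
A call sketch for the public entry point above (hypothetical sizes; bfloat16 is required by the dtype assertion, `beta` must be `[B, T, H]` when `head_first=False`, and `scale` defaults to `k.shape[-1] ** -0.5`; the float32 dtype for `g` is an assumption):

```python
import torch

B, T, H, K, V = 1, 4096, 8, 128, 128
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="npu")
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="npu")
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="npu")
g = torch.randn(B, T, H, dtype=torch.float32, device="npu")     # gating values
beta = torch.rand(B, T, H, dtype=torch.bfloat16, device="npu")  # [B, T, H] per the assert

o, final_state = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
```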

File: `vllm_ascend/ops/triton/fla/chunk_delta_h.py`

@@ -8,23 +8,24 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp
_CONDITIONS = ("seq7168", )
_CONDITIONS = ("seq7168",)
@triton.heuristics({
"USE_G": lambda args: args["g"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
@@ -85,28 +86,20 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
if USE_INITIAL_STATE:
h0_ptr = h0 + i_nh * K * V
ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1
b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1,
other=0.0).to(tl.float32)
b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, other=0.0).to(tl.float32)
ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1
b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2,
other=0.0).to(tl.float32)
b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, other=0.0).to(tl.float32)
# main recurrence
for i_t in range(NT):
h_base = h + (boh + i_t) * H * K * V + i_h * K * V
p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1),
(128, 64), (1, 0))
tl.store(p_h1_bv1,
b_h1_bv1.to(p_h1_bv1.dtype.element_ty),
boundary_check=(0, 1))
p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
tl.store(p_h1_bv1, b_h1_bv1.to(p_h1_bv1.dtype.element_ty), boundary_check=(0, 1))
p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2),
(128, 64), (1, 0))
tl.store(p_h1_bv2,
b_h1_bv2.to(p_h1_bv2.dtype.element_ty),
boundary_check=(0, 1))
p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
tl.store(p_h1_bv2, b_h1_bv2.to(p_h1_bv2.dtype.element_ty), boundary_check=(0, 1))
offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None]
offs_k_wv = tl.arange(0, 128)[None, :]
@@ -117,8 +110,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_w = tl.load(ptr_w, mask=mask_w, other=0.0)
k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K
p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT),
(128, BT), (0, 1))
p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), (128, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
v_new_base = v_new + bos * H * V + i_h * V
@@ -144,12 +136,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype))
if SAVE_NEW_VALUE:
p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
(i_t * BT, v_start1), (BT, 64),
(1, 0))
tl.store(p_v_new1,
b_v_new1.to(p_v_new1.dtype.element_ty),
boundary_check=(0, 1))
p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start1), (BT, 64), (1, 0))
tl.store(p_v_new1, b_v_new1.to(p_v_new1.dtype.element_ty), boundary_check=(0, 1))
if USE_G:
b_v_new1 = b_v_new1 * b_g[:, None]
@@ -165,12 +153,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype))
if SAVE_NEW_VALUE:
p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
(i_t * BT, v_start2), (BT, 64),
(1, 0))
tl.store(p_v_new2,
b_v_new2.to(p_v_new2.dtype.element_ty),
boundary_check=(0, 1))
p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start2), (BT, 64), (1, 0))
tl.store(p_v_new2, b_v_new2.to(p_v_new2.dtype.element_ty), boundary_check=(0, 1))
if USE_G:
b_v_new2 = b_v_new2 * b_g[:, None]
@@ -183,29 +167,23 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
if STORE_FINAL_STATE:
ht_ptr = ht + i_nh * K * V
p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1),
(128, 64), (1, 0))
tl.store(p_ht1_bv1,
b_h1_bv1.to(p_ht1_bv1.dtype.element_ty),
boundary_check=(0, 1))
p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
tl.store(p_ht1_bv1, b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), boundary_check=(0, 1))
p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2),
(128, 64), (1, 0))
tl.store(p_ht1_bv2,
b_h1_bv2.to(p_ht1_bv2.dtype.element_ty),
boundary_check=(0, 1))
p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
tl.store(p_ht1_bv2, b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
g: torch.Tensor | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -213,8 +191,7 @@ def chunk_gated_delta_rule_fwd_h(
H = u.shape[-2]
BT = chunk_size
chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None else None)
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
@@ -227,8 +204,7 @@ def chunk_gated_delta_rule_fwd_h(
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V)
final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
if output_final_state else None)
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
v_new = torch.empty_like(u) if save_new_value else None
g = g.transpose(1, 2).contiguous()

File: `vllm_ascend/ops/triton/fla/chunk_o.py`

@@ -9,7 +9,6 @@
# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
@@ -17,11 +16,13 @@ from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_offsets, safe_exp
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.jit(do_not_specialize=['T'])
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
@@ -48,8 +49,7 @@ def chunk_fwd_kernel_o(
T_max = T
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
@@ -71,12 +71,9 @@ def chunk_fwd_kernel_o(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1),
(i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K),
(i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1),
(i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
@@ -102,10 +99,8 @@ def chunk_fwd_kernel_o(
m_A = o_i[:, None] >= o_i[None, :]
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
@@ -119,9 +114,9 @@ def chunk_fwd_o(
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
g: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
@@ -129,7 +124,7 @@ def chunk_fwd_o(
BT = chunk_size
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
if cu_seqlens is None:
@@ -141,7 +136,7 @@ def chunk_fwd_o(
)
def grid(meta):
return (triton.cdiv(V, meta['BV']), N * H)
return (triton.cdiv(V, meta["BV"]), N * H)
g = g.transpose(1, 2).contiguous()
chunk_fwd_kernel_o[grid](

File: `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py`

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
@@ -16,11 +15,13 @@ from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_indices, safe_exp
@triton.heuristics({
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
'USE_G': lambda args: args['g_cumsum'] is not None,
})
@triton.jit(do_not_specialize=['T'])
@triton.heuristics(
{
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"USE_G": lambda args: args["g_cumsum"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
beta, # [H, B, T]
@@ -44,10 +45,11 @@ def chunk_scaled_dot_kkt_fwd_kernel(
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t_i * 2).to(
tl.int32), tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t_i * 2).to(tl.int32),
tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
@@ -55,39 +57,37 @@ def chunk_scaled_dot_kkt_fwd_kernel(
o_t = tl.arange(0, BT)
o_t_fp32 = o_t.to(tl.float32)
p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T, ), (1, ),
(i_t * BT, ), (BT, ), (0, ))
b_beta = tl.load(p_beta, boundary_check=(0, ))
p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K,
(T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
(BT, BK), (1, 0))
p_k = tl.make_block_ptr(
k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, tl.trans(b_k))
if USE_G:
p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T, ),
(1, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A *= safe_exp(b_g_diff)
b_A *= b_beta[:, None]
b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""
Compute beta * K * K^T.
@@ -117,8 +117,7 @@ def chunk_scaled_dot_kkt_fwd(
BT = chunk_size
if cu_seqlens is not None:
cu_seqlens = cu_seqlens.cpu()
chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
if cu_seqlens is not None else None)
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
chunk_indices = chunk_indices.npu()
cu_seqlens = cu_seqlens.npu()
else:

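A call sketch for `chunk_scaled_dot_kkt_fwd` (the `k` and `beta` shapes are inferred from the kernel's pointer arithmetic and should be treated as assumptions):

```python
import torch

k = torch.randn(1, 4096, 8, 128, dtype=torch.bfloat16, device="npu")  # (B, T, Hg, K), assumed
beta = torch.rand(1, 4096, 8, dtype=torch.bfloat16, device="npu")     # (B, T, H), assumed

A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, chunk_size=64)  # beta * K @ K^T per 64-token chunk
```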
File: `vllm_ascend/ops/triton/fla/cumsum.py`

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
@@ -16,11 +15,10 @@ from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_indices
@triton.heuristics({
'HAS_SCALE': lambda args: args['scale'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.jit(do_not_specialize=['T'])
@triton.heuristics(
{"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
s,
o,
@@ -41,20 +39,19 @@ def chunk_local_cumsum_scalar_kernel(
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
if IS_VARLEN:
i_s, i_block = tl.load(chunk_indices + i_block * 2).to(
tl.int32), tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_s).to(
tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
i_s, i_block = (
tl.load(chunk_indices + i_block * 2).to(tl.int32),
tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if HEAD_FIRST:
ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1),
(0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1),
(0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE))
b_s = tl.trans(b_s, (2, 0, 1))
b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
@@ -63,11 +60,9 @@ def chunk_local_cumsum_scalar_kernel(
b_o = tl.trans(b_o, (2, 0, 1))
b_o = tl.reshape(b_o, (H, BLOCK_T))
else:
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1),
(i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1),
(i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
b_s = tl.trans(b_s, (1, 0, 2))
b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
@@ -76,7 +71,7 @@ def chunk_local_cumsum_scalar_kernel(
b_o = tl.trans(b_o, (1, 0, 2))
b_o = tl.reshape(b_o, (BLOCK_T, H))
tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, ))
tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0,))
return
@@ -85,61 +80,64 @@ def chunk_local_cumsum_scalar(
chunk_size,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: Optional[torch.Tensor] = torch.float,
output_dtype: torch.Tensor | None = torch.float,
):
if head_first:
B, H, T = g.shape
else:
B, T, H = g.shape
assert chunk_size == 2**(chunk_size.bit_length() -
1), "chunk_size must be a power of 2"
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
block_indices = prepare_chunk_indices(
cu_seqlens,
chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(
T, OPTIM_BLOCK_SIZE)
block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (num_blocks, B)
chunk_local_cumsum_scalar_kernel[grid](s=g_org,
o=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=block_indices,
T=T,
B=B,
H=H,
BLOCK_T=OPTIM_BLOCK_SIZE,
CHUNK_SIZE=chunk_size,
HEAD_FIRST=head_first,
REVERSE=reverse,
num_warps=8,
num_stages=3)
chunk_local_cumsum_scalar_kernel[grid](
s=g_org,
o=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=block_indices,
T=T,
B=B,
H=H,
BLOCK_T=OPTIM_BLOCK_SIZE,
CHUNK_SIZE=chunk_size,
HEAD_FIRST=head_first,
REVERSE=reverse,
num_warps=8,
num_stages=3,
)
return g
def chunk_local_cumsum(g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs) -> torch.Tensor:
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
**kwargs,
) -> torch.Tensor:
if cu_seqlens is not None:
assert g.shape[
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(g=g,
chunk_size=chunk_size,
reverse=reverse,
scale=scale,
cu_seqlens=cu_seqlens,
head_first=head_first,
output_dtype=output_dtype)
return chunk_local_cumsum_scalar(
g=g,
chunk_size=chunk_size,
reverse=reverse,
scale=scale,
cu_seqlens=cu_seqlens,
head_first=head_first,
output_dtype=output_dtype,
)
else:
raise ValueError(f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise")
raise ValueError(
f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise"
)
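
A usage sketch for the 3-D path above (when `cu_seqlens` is passed, the batch dimension must be 1, per the assertion; `chunk_size` must be a power of 2):

```python
import torch

g = torch.randn(1, 4096, 8, dtype=torch.float32, device="npu")  # (B, T, H), head_first=False
g_cum = chunk_local_cumsum(g, chunk_size=64)  # cumulative sum within each 64-token chunk
```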

File: `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py`

@@ -31,29 +31,30 @@ def fused_qkvzba_split_reshape_cat_kernel(
BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
q_end: tl.constexpr = HEAD_QK
blk_q_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
i_qk * QKVZ_DIM_T + tl.arange(0, q_end))
blk_q_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(0, q_end)
k_end: tl.constexpr = q_end + HEAD_QK
blk_k_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end))
blk_k_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end)
v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
blk_v_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end))
blk_v_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end)
z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
blk_z_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end))
blk_q_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
i_qk * HEAD_QK + tl.arange(0, HEAD_QK))
blk_k_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK +
tl.arange(0, HEAD_QK))
blk_v_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
NUM_HEADS_QK * HEAD_QK * 2 +
i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
blk_z_st_ptr = (z + i_bs * NUM_HEADS_V * HEAD_V +
i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
blk_z_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end)
blk_q_st_ptr = mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
blk_k_st_ptr = (
mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
)
blk_v_st_ptr = (
mixed_qkv
+ i_bs * NUM_HEADS_QK * QKV_DIM_T
+ NUM_HEADS_QK * HEAD_QK * 2
+ i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+ tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
)
blk_z_st_ptr = (
z
+ i_bs * NUM_HEADS_V * HEAD_V
+ i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+ tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
)
tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
@@ -66,8 +67,7 @@ def fused_qkvzba_split_reshape_cat_kernel(
tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
for i in tl.static_range(b_end, a_end):
blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
blk_a_st_ptr = (a + i_bs * NUM_HEADS_V +
i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end))
blk_a_st_ptr = a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))

File: `vllm_ascend/ops/triton/fla/l2norm.py`

@@ -15,8 +15,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
rindex = tl.arange(0, N)[None, :]
@@ -24,8 +23,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
xs = tl.load(X + (rindex + N * row_idx), mask=xmask,
other=0.0).to(tl.float32)
xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
square = xs * xs
square_sum = tl.sum(square, 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)
@@ -33,9 +31,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: torch.dtype | None = None):
def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None):
x_shape_og = x.shape
x = x.reshape(-1, x.shape[-1])
# allocate output
@@ -56,7 +52,7 @@ def l2norm_fwd(x: torch.Tensor,
num_core = get_vectorcore_num()
main_bs = triton.cdiv(T, num_core)
num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
grid = (num_core, )
grid = (num_core,)
l2norm_fwd_kernel2_loop[grid](
X=x,
Y=y,

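A usage sketch for `l2norm_fwd` (assuming it returns the normalized tensor reshaped to the input's original shape):

```python
import torch

x = torch.randn(4096, 128, dtype=torch.bfloat16, device="npu")
y = l2norm_fwd(x)  # each row scaled by rsqrt(sum(x * x) + eps) over the last dimension
```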
File: `vllm_ascend/ops/triton/fla/layernorm_guard.py`

@@ -12,10 +12,12 @@ from vllm.triton_utils import tl, triton
MAX_CORES = 65535
@triton.heuristics({
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
})
@triton.heuristics(
{
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
}
)
@triton.jit
def layer_norm_fwd_kernel(
X, # pointer to the input
@@ -49,13 +51,10 @@ def layer_norm_fwd_kernel(
n_iters = n_iters + 1
for i in tl.range(n_iters):
X_base = X + (i * BLOCK_ROWS *
stride_x_row) + row * stride_x_row + group * N
Y_base = Y + (i * BLOCK_ROWS *
stride_y_row) + row * stride_y_row + group * N
X_base = X + (i * BLOCK_ROWS * stride_x_row) + row * stride_x_row + group * N
Y_base = Y + (i * BLOCK_ROWS * stride_y_row) + row * stride_y_row + group * N
if HAS_Z:
Z_base = Z + (i * BLOCK_ROWS *
stride_z_row) + row * stride_z_row + group * N
Z_base = Z + (i * BLOCK_ROWS * stride_z_row) + row * stride_z_row + group * N
if not IS_RMS_NORM:
Mean_base = Mean + (i * BLOCK_ROWS) + group * M
Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
@@ -64,17 +63,17 @@ def layer_norm_fwd_kernel(
B_base = B + group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X_base + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean_base + row, mean)
xbar = tl.where(cols < N, x - mean, 0.)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd_base + row, rstd)
@@ -112,26 +111,24 @@ def _layer_norm_fwd(
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N, )
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N, )
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
if not is_rms_norm else None)
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M if M < MAX_CORES else MAX_CORES, ngroups)
@@ -160,7 +157,6 @@ def _layer_norm_fwd(
class LayerNormFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,

File: `vllm_ascend/ops/triton/fla/sigmoid_gating.py`

@@ -14,7 +14,7 @@ import os
import torch
from vllm.triton_utils import tl, tldevice, triton
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
div = tldevice.fast_dividef
exp = tldevice.fast_expf
log = tldevice.fast_logf
@@ -31,17 +31,15 @@ else:
log2 = tl.log2
@triton.heuristics({
'USE_INITIAL_STATE':
lambda args: args['h0'] is not None,
'IS_VARLEN':
lambda args: args['cu_seqlens'] is not None,
"IS_CONTINUOUS_BATCHING":
lambda args: args['ssm_state_indices'] is not None,
"IS_SPEC_DECODING":
lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
}
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
q,
k,
@@ -70,8 +68,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
stride_indices_tok: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
IS_BETA_HEADWISE: tl.
constexpr, # whether beta is headwise vector or scalar,
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
@@ -82,8 +79,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
all = T
T = eos - bos
else:
@@ -108,8 +104,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_init_state_token
p_h0 = (
h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token
)
else:
p_h0 = h0 + bos * HV * K * V
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
@@ -164,18 +161,21 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_final_state_token
p_ht = (
ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_final_state_token
)
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
@triton.heuristics({
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
A_log,
@@ -245,8 +245,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
idx = tl.load(h0_indices + i_n)
# if idx >= 0:
tmp0 = tl.where(idx < 0, 0, idx)
p_h0 = (h0_source + tmp0 * HV * K * V + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
p_h0 = h0_source + tmp0 * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
temp2 = tl.zeros_like(temp1)
value0 = tl.where(idx < 0, temp2, temp1)
@@ -314,8 +313,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n)
if idx >= 0:
p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
p_h0 = h0_source + idx * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
@@ -350,7 +348,7 @@ def fused_sigmoid_gating_delta_rule_update(
num_warps = 1
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"

File: `vllm_ascend/ops/triton/fla/solve_tril.py`

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
@@ -66,9 +65,13 @@ def solve_tril_16x16_kernel(
offs_cols_in_block = tl.arange(0, 16)
# 2 Calculate the pointer of each element
ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o +
offs_rows_in_block[:, None] * H * BT +
offs_cols_in_block[None, :])
ptr_A_subrec16 = (
A
+ row_start_o * H * BT
+ col_start_o
+ offs_rows_in_block[:, None] * H * BT
+ offs_cols_in_block[None, :]
)
# 3 Create a mask to prevent out-of-bounds access
global_rows = row_start_o + offs_rows_in_block[:, None]
@@ -76,14 +79,14 @@ def solve_tril_16x16_kernel(
load_mask = (global_rows < T) & (global_cols < BT)
# 4 Use mask to safely load data
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask,
other=0.0).to(tl.float32)
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
b_A = tl.insert_slice(
ful=b_A,
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
offsets=[blkid, 0, 0],
sizes=[1, 16, 16],
strides=[1, 1, 1])
strides=[1, 1, 1],
)
local_ori_A = tl.trans(b_A, (1, 0, 2))
local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))
@@ -97,9 +100,7 @@ def solve_tril_16x16_kernel(
# for loop to update N_BLOCKS row vector
for i in range(1, 16):
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0),
(1, 16 * N_BLOCKS),
(16 * N_BLOCKS, 1))
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
@@ -107,34 +108,27 @@ def solve_tril_16x16_kernel(
b_a = b_a + dot_product
b_a_new_expanded = b_a[:, None, :]
b_A = tl.insert_slice(ful=b_A,
sub=b_a_new_expanded,
offsets=[0, i, 0],
sizes=[N_BLOCKS, 1, 16],
strides=[1, 1, 1])
b_A = tl.insert_slice(
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
)
on_diagonal = (rows == cols)
on_diagonal = rows == cols
b_A = tl.where(on_diagonal, b_A + 1.0, b_A)
b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0),
(N_BLOCKS * 16, 16), (1, 0))
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), (N_BLOCKS * 16, 16), (1, 0))
# 1 Create in-block offset
offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
offs_cols_to_store = tl.arange(0, 16)
# 2 Calculate the pointer of each element
p_Ai = (Ad + base_t * H * 16 + 0 +
offs_rows_to_store[:, None] * H * 16 +
offs_cols_to_store[None, :])
p_Ai = Ad + base_t * H * 16 + 0 + offs_rows_to_store[:, None] * H * 16 + offs_cols_to_store[None, :]
# 3 Create a mask to prevent out-of-bounds access, only check rows
global_store_rows = base_t + offs_rows_to_store[:, None]
store_mask = global_store_rows < T
# 4 use mask to save data safely
tl.store(p_Ai,
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
mask=store_mask)
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=store_mask)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@@ -169,18 +163,12 @@ def merge_16x16_to_32x32_inverse_kernel(
Ad += (bos * H + i_h) * 16
Ai += (bos * H + i_h) * 32
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
(16, 16), (1, 0))
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0),
(16, 16), (1, 0))
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0),
(16, 16), (1, 0))
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0),
(16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16),
(16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
(16, 16), (1, 0))
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
@@ -313,26 +301,20 @@ def merge_16x16_to_64x64_inverse_kernel(
offs_n = tl.arange(0, 32)
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
tl.store(ptr_Ai,
Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
mask=mask_store)
tl.store(ptr_Ai, Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
# store Ai_22_32 to (i_t * 64 + 32, 32)
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
offs_n = 32 + tl.arange(0, 32)
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
tl.store(ptr_Ai,
Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
mask=mask_store)
tl.store(ptr_Ai, Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
# store Ai_21_32 to (i_t * 64 + 32, 32)
offs_n = tl.arange(0, 32)
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
tl.store(ptr_Ai,
Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
mask=mask_store)
tl.store(ptr_Ai, Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
# zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63)
offs_m = i_t * 64 + tl.arange(0, 32)
@@ -345,7 +327,7 @@ def merge_16x16_to_64x64_inverse_kernel(
def solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
cu_seqlens: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
"""
@@ -367,19 +349,12 @@ def solve_tril(
assert A.shape[-1] in [16, 32, 64]
B, T, H, BT = A.shape
Ad = torch.empty(B,
T,
H,
16,
device=A.device,
dtype=torch.float if BT != 16 else output_dtype)
Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
LARGE_BLOCK_T = 608 * 2
chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
if cu_seqlens is not None else None)
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(
T, LARGE_BLOCK_T)
chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if cu_seqlens is not None else None
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)
solve_tril_16x16_kernel[NT, B * H](
A=A,
@@ -398,10 +373,8 @@ def solve_tril(
return Ad
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
merge_fn = (merge_16x16_to_32x32_inverse_kernel
if BT == 32 else merge_16x16_to_64x64_inverse_kernel)
chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
if cu_seqlens is not None else None)
merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
merge_fn[NT, B * H](
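A shape-level usage sketch of solve_tril (illustrative sizes; assumes the NPU/triton runtime is available and A holds the per-chunk lower-triangular factors):
import torch
B, T, H, BT = 1, 256, 4, 64
A = torch.randn(B, T, H, BT)
Ai = solve_tril(A)                # blockwise inverse, shape (B, T, H, BT)
assert Ai.shape == (B, T, H, BT)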

View File

@@ -9,7 +9,7 @@
# ruff: noqa: E501
import contextlib
import functools
from typing import Callable
from collections.abc import Callable
import torch
from vllm.triton_utils import tl, triton
@@ -19,38 +19,24 @@ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
chunk_size: int) -> torch.LongTensor:
indices = torch.cat([
torch.arange(n)
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
1).to(cu_seqlens)
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
chunk_size: int) -> torch.LongTensor:
return torch.cat([
cu_seqlens.new_tensor([0]),
triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
]).cumsum(-1)
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
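A small worked example of what the two helpers produce, assuming cu_seqlens = [0, 5, 14] and chunk_size = 4:
import torch
import triton
cu_seqlens = torch.tensor([0, 5, 14])
lens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([5, 9])
triton.cdiv(lens, 4)                      # tensor([2, 3]) chunks per sequence
# prepare_chunk_indices -> one (sequence id, chunk id) pair per chunk:
#   [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
# prepare_chunk_offsets -> cumulative chunk counts: [0, 2, 5]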
def input_guard(
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else
i.contiguous() for i in args)
contiguous_kwargs = {
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
for k, v in kwargs.items()
}
contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
tensor = None
for arg in args:

View File

@@ -9,7 +9,6 @@
# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional, Tuple
import torch
from vllm.triton_utils import tl, triton
@@ -17,23 +16,39 @@ from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_indices
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.jit(do_not_specialize=['T'])
def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
T, H: tl.constexpr, Hg: tl.constexpr,
K: tl.constexpr, V: tl.constexpr,
BT: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, IS_VARLEN: tl.constexpr):
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
k,
v,
beta,
w,
u,
A,
g,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_max = T
i_t_o = tl.program_id(0)
for i_bh in range(H):
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t_o * 2).to(
tl.int32), tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t_o * 2).to(tl.int32),
tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
@@ -44,7 +59,7 @@ def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
offs_t_2d = global_offs_t[:, None]
offs_bt = tl.arange(0, BT)[None, :]
ptr_A = (A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1)
ptr_A = A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1
mask_A = mask_t[:, None]
b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
@@ -58,29 +73,25 @@ def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
offs_v = i_v * BV + tl.arange(0, BV)[None, :]
mask_v = (mask_t[:, None]) & (offs_v < V)
ptr_v = (v + (bos * H + i_h) * V + offs_t_2d * (H * V) +
offs_v * 1)
ptr_v = v + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32)
b_vb = (b_v * b_beta[:, None])
b_vb = b_v * b_beta[:, None]
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
ptr_u = (u + (bos * H + i_h) * V + offs_t_2d * (H * V) +
offs_v * 1)
ptr_u = u + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v)
for i_k in range(tl.cdiv(K, BK)):
offs_k = i_k * BK + tl.arange(0, BK)[None, :]
mask_k = (mask_t[:, None]) & (offs_k < K)
ptr_k = (k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d *
(Hg * K) + offs_k * 1)
ptr_k = k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d * (Hg * K) + offs_k * 1
b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32)
b_kb = (b_k * b_beta[:, None] * b_g[:, None])
b_kb = b_k * b_beta[:, None] * b_g[:, None]
b_w = tl.dot(b_A, b_kb)
ptr_w = (w + (bos * H + i_h) * K + offs_t_2d * (H * K) +
offs_k * 1)
ptr_w = w + (bos * H + i_h) * K + offs_t_2d * (H * K) + offs_k * 1
tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k)
@@ -90,14 +101,13 @@ def recompute_w_u_fwd(
beta: torch.Tensor,
g_cumsum: torch.Tensor,
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) \
if cu_seqlens is not None else None
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64

View File

@@ -30,15 +30,12 @@ def fused_gdn_gating_kernel(
):
i_b, i_s = tl.program_id(0), tl.program_id(1)
for row_idx in range(0, ROW_ITER):
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(
0, BLK_BATCHES)
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
for col_idx in range(0, COL_ITER):
head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
off = batch_off[:,
None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[
None, :]
off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
head_mask = head_off < NUM_HEADS
mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
@@ -48,17 +45,14 @@ def fused_gdn_gating_kernel(
blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
softplus_x = tl.where(beta * x <= threshold,
(1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
# compute beta_output = sigmoid(b)
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
tl.store(beta_output + off,
blk_beta_output.to(beta_output.dtype.element_ty),
mask=mask)
tl.store(beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask)
def fused_gdn_gating_patch(
@@ -85,17 +79,13 @@ def fused_gdn_gating_patch(
progs = num_cores
FACTOR = 8 * num_heads
row_per_core = triton.cdiv(batch, num_cores)
BLK_BATCHES = triton.next_power_of_2(
triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) //
a.element_size()) // 2
BLK_BATCHES = (
triton.next_power_of_2(triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) // a.element_size()) // 2
)
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1,
batch,
num_heads,
dtype=b.dtype,
device=b.device)
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
grid = (progs, seq_len)
fused_gdn_gating_kernel[grid](
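For reference, the gating math the kernel implements, restated as a hedged PyTorch sketch (same thresholded softplus; shapes assumed broadcastable):
import torch

def gdn_gating_ref(a, b, A_log, dt_bias, beta=1.0, threshold=20.0):
    x = a.float() + dt_bias.float()
    softplus_x = torch.where(beta * x <= threshold, (1 / beta) * torch.log1p(torch.exp(beta * x)), x)
    g = -torch.exp(A_log.float()) * softplus_x  # matches blk_g above
    beta_output = torch.sigmoid(b.float())      # matches blk_beta_output above
    return g, beta_output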

View File

@@ -9,6 +9,7 @@
import torch
from vllm.triton_utils import tl, triton
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
@@ -131,11 +132,7 @@ def layer_norm_fwd_npu(
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
MAX_FUSED_SIZE = 65536 // x.element_size()
@@ -168,4 +165,4 @@ def layer_norm_fwd_npu(
IS_RMS_NORM=is_rms_norm,
# Remove multibuffer if not needed
)
return out, mean, rstd

View File

@@ -14,7 +14,6 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional
import torch
import triton # type: ignore
@@ -61,22 +60,19 @@ def split_qkv_rmsnorm_rope_kernel(
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
valid_mask = col_indices < q_hidden_size
input_values = (tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0).to(tl.float32).reshape(
Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
input_values = (
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
.to(tl.float32)
.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
Q_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = (input_values * reciprocal_std
) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values +
bias_values).to(tl.bfloat16)
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(
tl.bfloat16)
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
@@ -93,8 +89,7 @@ def split_qkv_rmsnorm_rope_kernel(
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
@@ -127,22 +122,19 @@ def split_qkv_rmsnorm_rope_kernel(
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = (tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0).to(tl.float32).reshape(
KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
input_values = (
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
.to(tl.float32)
.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
KV_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = (input_values * reciprocal_std
) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values +
bias_values).to(tl.bfloat16)
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(
tl.bfloat16)
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
@@ -158,8 +150,7 @@ def split_qkv_rmsnorm_rope_kernel(
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
@@ -189,12 +180,8 @@ def split_qkv_rmsnorm_rope_kernel(
for _ in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0)
tl.store(v_ptr + output_offset + col_indices,
input_values,
mask=valid_mask)
input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask)
input_offset += input_offset_step
output_offset += output_offset_step
@@ -209,27 +196,18 @@ def split_qkv_rmsnorm_rope_impl(
kv_hidden_size: int,
head_dim: int,
eps: float,
q_bias: Optional[torch.Tensor] = None,
k_bias: Optional[torch.Tensor] = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
assert KV_BLOCK_SIZE == head_dim
assert head_dim == KV_BLOCK_SIZE
assert q_hidden_size % kv_hidden_size == 0
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
batch_size = input.shape[0]
total_hidden_size = q_hidden_size + kv_hidden_size * 2
q_output = torch.empty(batch_size,
q_hidden_size,
device=input.device,
dtype=input.dtype)
k_output = torch.empty(batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype)
v_output = torch.empty(batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype)
q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype)
k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
n_cols = kv_hidden_size // KV_BLOCK_SIZE
num_vectorcore = get_vectorcore_num()
assert num_vectorcore % n_cols == 0
@@ -271,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl_fake(
kv_hidden_size: int,
head_dim: int,
eps: float,
q_bias: Optional[torch.Tensor] = None,
k_bias: Optional[torch.Tensor] = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Fake implementation for shape inference during Dynamo/AOT tracing.
# Note: sin and cos are not used in shape computation, but must be present in signature.
@@ -298,8 +276,10 @@ def split_qkv_rmsnorm_rope_impl_fake(
return q_output, k_output, v_output
direct_register_custom_op(op_name="qkv_rmsnorm_rope",
op_func=split_qkv_rmsnorm_rope_impl,
fake_impl=split_qkv_rmsnorm_rope_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(
op_name="qkv_rmsnorm_rope",
op_func=split_qkv_rmsnorm_rope_impl,
fake_impl=split_qkv_rmsnorm_rope_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1",
)
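The per-head RMSNorm the kernel applies before the rotation, as a hedged PyTorch reference (x viewed as (..., num_heads, head_dim); bias is optional, as in the kernel):
import torch

def rmsnorm_per_head_ref(x, weight, eps, bias=None):
    x = x.float()
    var = x.pow(2).mean(dim=-1, keepdim=True)  # per-head mean of squares
    y = x * torch.rsqrt(var + eps) * weight
    if bias is not None:
        y = y + bias
    return y.to(torch.bfloat16)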

View File

@@ -7,24 +7,23 @@
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors
from typing import Any, Optional
from typing import Any
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
initial_states: torch.Tensor | None = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
final_states_out: torch.Tensor | None = None,
activation: str | None = "silu",
):
"""
x: (batch, dim, seqlen)
@@ -42,19 +41,14 @@ def causal_conv1d_ref(
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
@@ -66,13 +60,13 @@ def causal_conv1d_ref(
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
conv_states: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
metadata: Optional[Any] = None,
bias: torch.Tensor | None = None,
activation: str | None = "silu",
conv_states: torch.Tensor | None = None,
has_initial_state: torch.Tensor | None = None,
cache_indices: torch.Tensor | None = None,
query_start_loc: torch.Tensor | None = None,
metadata: Any | None = None,
pad_slot_id: int = PAD_SLOT_ID,
):
"""
@@ -126,10 +120,10 @@ def causal_conv1d_fn(
bias,
activation=activation,
return_final_states=True,
final_states_out=conv_states[cache_indices[i]][..., :(
width - 1)].unsqueeze(0),
initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
if has_initial_state[i] else None))
final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0),
initial_states=conv_states[cache_indices[i]][..., : (width - 1)] if has_initial_state[i] else None,
)
)
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
out_ref_tensor = torch.cat(out_ref, dim=0)
return out_ref_tensor
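A shape-level sketch of the reference path (illustrative sizes; the activation and return handling elided by the hunk are assumed to behave as the docstring states):
import torch
x = torch.randn(2, 64, 128)    # (batch, dim, seqlen)
w = torch.randn(64, 4)         # (dim, width)
y = causal_conv1d_ref(x, w)    # causal depthwise conv, same shape as x
assert y.shape == x.shape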
@@ -137,54 +131,50 @@ def causal_conv1d_fn(
@triton.jit
def _causal_conv1d_update_kernel_npu_tiled(
# Pointers
x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr, # (num_cache_lines, dim, state_len)
conv_state_indices_ptr,
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
block_idx_last_scheduled_token, # (batch,)
initial_state_idx, # (batch,)
o_ptr, # same shape as x_ptr
batch: tl.int32,
dim: tl.constexpr,
seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen
state_len: tl.constexpr, # effective state_len computed in wrapper
num_cache_lines: tl.constexpr,
# Strides
stride_x_seq: tl.constexpr,
stride_x_dim: tl.constexpr,
stride_x_token: tl.constexpr,
stride_w_dim: tl.constexpr,
stride_w_width: tl.constexpr,
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr, # <= 6
SILU_ACTIVATION: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_APC_ENABLED: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
# tiling
BLOCK_N: tl.constexpr, # channel tile (C_TILE)
B_TILE: tl.constexpr, # batch tile
T_CHUNK: tl.constexpr, # token chunk for state update
):
# program ids
pid_b = tl.program_id(0) # batch-tile id
@@ -197,37 +187,30 @@ def _causal_conv1d_update_kernel_npu_tiled(
# preload weights once per program (shared by B_TILE sequences)
w_base = w_ptr + idx_feats * stride_w_dim
# define to avoid "undefined" in branches
w_col0 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
w_col1 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
w_col2 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
w_col3 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
w_col4 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
w_col5 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
w_col0 = tl.zeros((BLOCK_N,), dtype=tl.float32)
w_col1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
w_col2 = tl.zeros((BLOCK_N,), dtype=tl.float32)
w_col3 = tl.zeros((BLOCK_N,), dtype=tl.float32)
w_col4 = tl.zeros((BLOCK_N,), dtype=tl.float32)
w_col5 = tl.zeros((BLOCK_N,), dtype=tl.float32)
if KERNEL_WIDTH >= 1:
w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w,
other=0.0).to(tl.float32)
w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
if KERNEL_WIDTH >= 2:
w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w,
other=0.0).to(tl.float32)
w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
if KERNEL_WIDTH >= 3:
w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w,
other=0.0).to(tl.float32)
w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
if KERNEL_WIDTH >= 4:
w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w,
other=0.0).to(tl.float32)
w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
if KERNEL_WIDTH >= 5:
w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w,
other=0.0).to(tl.float32)
w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
if KERNEL_WIDTH >= 6:
w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w,
other=0.0).to(tl.float32)
w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
# bias vector once per program
if HAS_BIAS:
acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w,
other=0.0).to(tl.float32)
acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32)
else:
acc_bias = tl.zeros((BLOCK_N, ), dtype=tl.float32)
acc_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
# token index vector for chunked copy
tok_vec = tl.arange(0, T_CHUNK) # [T_CHUNK]
@@ -241,36 +224,26 @@ def _causal_conv1d_update_kernel_npu_tiled(
# APC mapping (optional)
# -------------------------
if IS_APC_ENABLED:
conv_state_init = tl.load(initial_state_idx + b,
mask=lane_active,
other=0).to(tl.int32)
current_last_index = tl.load(block_idx_last_scheduled_token + b,
mask=lane_active,
other=0).to(tl.int32)
conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32)
current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32)
else:
conv_state_init = tl.full((), 0, tl.int32)
current_last_index = tl.full((), 0, tl.int32)
# input cache line
conv_states_input_coord = tl.load(conv_state_indices_ptr +
b * stride_state_indices +
conv_state_init,
mask=lane_active,
other=0).to(tl.int64)
conv_states_input_coord = tl.load(
conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0
).to(tl.int64)
if USE_PAD_SLOT:
lane_active = lane_active & (conv_states_input_coord
!= pad_slot_id)
lane_active = lane_active & (conv_states_input_coord != pad_slot_id)
# -------------------------
# varlen (optional): revise seqlen_run and state_len_run as the original kernel does
# -------------------------
if IS_VARLEN:
qs = tl.load(query_start_loc_ptr + b, mask=lane_active,
other=0).to(tl.int64)
qe = tl.load(query_start_loc_ptr + (b + 1),
mask=lane_active,
other=0).to(tl.int64)
qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64)
qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64)
seqlen_run = (qe - qs).to(tl.int32)
# revise effective state_len for shorter sequences (same formula as original)
state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32)
@@ -289,9 +262,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
# spec decoding offset (optional)
# -------------------------
if IS_SPEC_DECODING:
conv_state_token_offset = (
tl.load(num_accepted_tokens_ptr + b, mask=lane_active,
other=1).to(tl.int64) - 1)
conv_state_token_offset = tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1
shift = tl.full((), 1, tl.int32) # sliding by 1 in spec mode
else:
conv_state_token_offset = tl.full((), 0, tl.int64)
@@ -300,37 +271,37 @@ def _causal_conv1d_update_kernel_npu_tiled(
# -------------------------
# STEP 1: read initial history cols BEFORE state update (out==x safe)
# -------------------------
conv_states_base = (conv_state_ptr +
conv_states_input_coord * stride_conv_state_seq +
idx_feats * stride_conv_state_dim)
conv_states_base = (
conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim
)
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
# define history vectors as zeros then load conditionally
col0 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
col1 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
col2 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
col3 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
col4 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
col0 = tl.zeros((BLOCK_N,), dtype=tl.float16)
col1 = tl.zeros((BLOCK_N,), dtype=tl.float16)
col2 = tl.zeros((BLOCK_N,), dtype=tl.float16)
col3 = tl.zeros((BLOCK_N,), dtype=tl.float16)
col4 = tl.zeros((BLOCK_N,), dtype=tl.float16)
if KERNEL_WIDTH >= 2:
col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
tl.float16
)
if KERNEL_WIDTH >= 3:
col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
tl.float16
)
if KERNEL_WIDTH >= 4:
col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
tl.float16
)
if KERNEL_WIDTH >= 5:
col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
tl.float16
)
if KERNEL_WIDTH >= 6:
col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
tl.float16
)
# -------------------------
# STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block)
@@ -340,29 +311,25 @@ def _causal_conv1d_update_kernel_npu_tiled(
# dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run]
# -------------------------
# output cache line
conv_states_offset = tl.load(conv_state_indices_ptr +
b * stride_state_indices +
current_last_index,
mask=lane_active,
other=0).to(tl.int64)
conv_states_offset = tl.load(
conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0
).to(tl.int64)
use_shift = (seqlen_run < state_len_run)
use_tail = (seqlen_run >= state_len_run)
use_shift = seqlen_run < state_len_run
use_tail = seqlen_run >= state_len_run
zero_i32 = tl.full((), 0, tl.int32)
keep_shift = tl.where(use_shift, (state_len_run - seqlen_run),
zero_i32).to(tl.int32)
tail_start = tl.where(use_tail, (seqlen_run - state_len_run),
zero_i32).to(tl.int32)
keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32)
tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32)
# base pointers
state_src_base = (conv_state_ptr +
conv_states_input_coord * stride_conv_state_seq +
conv_state_token_offset * stride_conv_state_tok +
idx_feats * stride_conv_state_dim)
state_dst_base = (conv_state_ptr +
conv_states_offset * stride_conv_state_seq +
idx_feats * stride_conv_state_dim)
state_src_base = (
conv_state_ptr
+ conv_states_input_coord * stride_conv_state_seq
+ conv_state_token_offset * stride_conv_state_tok
+ idx_feats * stride_conv_state_dim
)
state_dst_base = conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim
x_base = x_ptr + x_offset + idx_feats * stride_x_dim
@@ -370,16 +337,16 @@ def _causal_conv1d_update_kernel_npu_tiled(
for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK]
src_tok = (dst_tok + shift).to(tl.int32) # [T_CHUNK]
m_tok = use_shift & (dst_tok < keep_shift) & (
src_tok < state_len_run) & (dst_tok < state_len_run)
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
conv_states_input_coord
< num_cache_lines) & (conv_states_offset < num_cache_lines)
m_tok = use_shift & (dst_tok < keep_shift) & (src_tok < state_len_run) & (dst_tok < state_len_run)
m = (
(lane_active & m_tok)[:, None]
& mask_w[None, :]
& (conv_states_input_coord < num_cache_lines)
& (conv_states_offset < num_cache_lines)
)
src_ptrs = state_src_base[
None, :] + src_tok[:, None] * stride_conv_state_tok
dst_ptrs = state_dst_base[
None, :] + dst_tok[:, None] * stride_conv_state_tok
src_ptrs = state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok
dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
vals = tl.load(src_ptrs, mask=m, other=0.0)
tl.store(dst_ptrs, vals, mask=m)
@@ -387,14 +354,11 @@ def _causal_conv1d_update_kernel_npu_tiled(
for t0 in tl.static_range(0, seqlen, T_CHUNK):
x_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK]
dst_tok = (keep_shift + x_tok).to(tl.int32) # [T_CHUNK]
m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok
< state_len_run)
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
conv_states_offset < num_cache_lines)
m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run)
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
dst_ptrs = state_dst_base[
None, :] + dst_tok[:, None] * stride_conv_state_tok
dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
x_vals = tl.load(x_ptrs, mask=m, other=0.0)
tl.store(dst_ptrs, x_vals, mask=m)
@@ -403,12 +367,10 @@ def _causal_conv1d_update_kernel_npu_tiled(
dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK]
x_tok = (tail_start + dst_tok).to(tl.int32) # [T_CHUNK]
m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run)
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
conv_states_offset < num_cache_lines)
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
dst_ptrs = state_dst_base[
None, :] + dst_tok[:, None] * stride_conv_state_tok
dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
x_vals = tl.load(x_ptrs, mask=m, other=0.0)
tl.store(dst_ptrs, x_vals, mask=m)
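A 1-D sketch of the two copy paths for one channel (illustrative lists; shift is assumed to equal seqlen_run outside spec-decode mode):
state, xs = [1, 2, 3, 4], [5, 6]           # state_len_run = 4, seqlen_run = 2
if len(xs) < len(state):                   # shift path
    new_state = state[len(xs):] + xs       # -> [3, 4, 5, 6]
else:                                      # tail path
    new_state = xs[len(xs) - len(state):]  # last state_len_run tokens of x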
@@ -433,17 +395,13 @@ def _causal_conv1d_update_kernel_npu_tiled(
if KERNEL_WIDTH == 1:
# only x[t] * w0
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
matrix_x = tl.load(x_ptrs_1d,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
matrix_w = w_col0
elif KERNEL_WIDTH == 2:
if j == 1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
matrix_x = tl.load(x_ptrs_1d,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
@@ -451,9 +409,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
matrix_x = tl.load(x_ptrs_1d,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
@@ -464,9 +420,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
matrix_x = tl.load(x_ptrs_1d,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
elif KERNEL_WIDTH == 5:
if j == 1:
matrix_w = w_col1
@@ -480,9 +434,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
elif j == 4:
matrix_w = w_col4
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
matrix_x = tl.load(x_ptrs_1d,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
elif KERNEL_WIDTH == 6:
if j == 1:
matrix_w = w_col1
@@ -499,9 +451,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
elif j == 5:
matrix_w = w_col5
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
matrix_x = tl.load(x_ptrs_1d,
mask=lane_active & mask_w,
other=0.0).to(tl.float16)
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
acc += matrix_x.to(tl.float32) * matrix_w # [BLOCK_N]
@@ -606,7 +556,7 @@ def causal_conv1d_update_npu(
x = x.unsqueeze(1)
if query_start_loc is None:
batch, seqlen, dim = x.shape
else:
assert conv_state_indices is not None
batch = conv_state_indices.size(0)
@@ -614,14 +564,14 @@ def causal_conv1d_update_npu(
seqlen = max_query_len
width, _ = weight.shape
num_cache_lines, state_len_total,_ = conv_state.size()
num_cache_lines, state_len_total, _ = conv_state.size()
# overwrite-on-x strategy same as original
out = x
stride_w_width, stride_w_dim = weight.stride()
if query_start_loc is None:
stride_x_seq, stride_x_token,stride_x_dim = x.stride()
stride_x_seq, stride_x_token, stride_x_dim = x.stride()
stride_o_seq, stride_o_token, stride_o_dim = out.stride()
else:
stride_x_token, stride_x_dim = x.stride()
@@ -629,10 +579,8 @@ def causal_conv1d_update_npu(
stride_o_token, stride_o_dim = out.stride()
stride_o_seq = 0
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
)
stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride()
stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0
# effective state_len exactly as original
if num_accepted_tokens is not None:
@@ -642,10 +590,10 @@ def causal_conv1d_update_npu(
np2_statelen = triton.next_power_of_2(eff_state_len)
# -------- tiling heuristic --------
#keep program count around ~[80..160]
# keep program count around ~[80..160]
# vector core 40
# TODO: use driver to get the vector core num
CORE_HINT = 40
# channel tile: 512 when dim large (reduce tasks), else 256
block_n = 512 if dim >= 512 else 256
g = triton.cdiv(dim, block_n)
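Worked numbers for the heuristic (illustrative; the B_TILE choice that follows is elided by the hunk):
block_n = 512 if 1280 >= 512 else 256  # dim = 1280 -> 512
g = (1280 + 512 - 1) // 512            # -> 3 channel tiles
# program count ~= ceil(batch / B_TILE) * g, steered toward the ~[80..160] range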
@@ -669,6 +617,7 @@ def causal_conv1d_update_npu(
triton.cdiv(batch, META["B_TILE"]),
triton.cdiv(dim, META["BLOCK_N"]),
)
_causal_conv1d_update_kernel_npu_tiled[grid](
x,
weight,

View File

@@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton(
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
for pos in tl.range(0, BLOCK_SIZE):
draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
draft_token_id1 = tl.get_element(draft_token_id, (pos,))
target_argmax1 = tl.get_element(target_argmax_id, (pos,))
position = block_idx * BLOCK_SIZE + pos
if draft_token_id1 == target_argmax1:
bonus_renew_1(
@@ -79,9 +79,7 @@ def bonus_renew(
num_tokens1,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(
output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
bonus_token_id)
tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id)
@triton.jit(do_not_specialize=["max_spec_len"])
@@ -106,17 +104,15 @@ def rejection_greedy_sample_triton(
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
is_greedy_mask = mask & (is_greedy != 0)
start_idx = tl.where(
offset == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
num_draft_tokens = end_idx - start_idx
for pos in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
num_tokens1 = tl.get_element(num_draft_tokens, (pos,))
rejected = False
start_idx1 = tl.get_element(start_idx, (pos, ))
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
start_idx1 = tl.get_element(start_idx, (pos,))
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos,))
position = block_idx * BLOCK_SIZE + pos
for i in range(num_tokens1):
if not rejected:
@@ -142,50 +138,44 @@ def rejection_greedy_sample_triton(
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
vec_len,
NO_DRAFT_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < vec_len
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
not_greedy_mask = is_greedy == 0
start_idxs = tl.where(
offsets == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
if not_greedy:
rejected = False
start_idx = tl.get_element(start_idxs, (req_i, ))
start_idx = tl.get_element(start_idxs, (req_i,))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx +
pos)
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
@@ -195,17 +185,13 @@ def rejection_random_sample_kernel(
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr +
start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
pos, token_id)
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens,
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)
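The per-token accept/reject rule the kernel applies, as a scalar sketch (the acceptance branch itself is elided by the hunk; this restates the standard rejection-sampling test):
def accept_draft(target_prob, draft_prob, uniform_prob):
    # a zero draft prob is rejected outright (avoids NaNs, per the NOTE above)
    return draft_prob > 0 and uniform_prob * draft_prob <= target_prob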
@@ -225,8 +211,7 @@ def expand_kernel(
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
len_mask = offset < vec_len
start_idx = tl.where(offset == 0, 0,
tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
num_tokens = end_idx - start_idx
@@ -234,13 +219,11 @@ def expand_kernel(
src_val = tl.where(src_val == replace_from, replace_to, src_val)
for i in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_tokens, (i, ))
start_idx1 = tl.get_element(start_idx, (i, ))
src_val1 = tl.get_element(src_val, (i, ))
num_tokens1 = tl.get_element(num_tokens, (i,))
start_idx1 = tl.get_element(start_idx, (i,))
src_val1 = tl.get_element(src_val, (i,))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1,
src_val1,
mask=offset1 < num_tokens1)
tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1)
@triton.jit
@@ -257,8 +240,7 @@ def sample_recovered_tokens_kernel(
SUB_BLOCK: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
@@ -272,27 +254,25 @@ def sample_recovered_tokens_kernel(
global_max_p = -1.0
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
0)
tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0)
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")
)
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
max_p = tl.get_element(new_p, (recovered_id,))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
@@ -300,25 +280,24 @@ def sample_recovered_tokens_kernel(
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=0
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")
)
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
max_p = tl.get_element(new_p, (recovered_id,))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
@@ -327,21 +306,25 @@ def sample_recovered_tokens_kernel(
if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)
tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob)
def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens,
cu_num_draft_tokens, draft_token_ids,
target_argmax, bonus_token_ids,
is_greedy, max_spec_len, grid,
block_size):
def rejection_greedy_sample_with_triton(
output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
grid,
block_size,
):
vec_len = output_token_ids.shape[0]
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and is_greedy is None:
rejection_greedy_sample_spec_len_1_triton[(grid, )](
if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and is_greedy is None:
rejection_greedy_sample_spec_len_1_triton[(grid,)](
output_token_ids,
draft_token_ids,
target_argmax,
@@ -350,7 +333,7 @@ def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens,
BLOCK_SIZE=block_size,
)
else:
rejection_greedy_sample_triton[(grid, )](
rejection_greedy_sample_triton[(grid,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -363,12 +346,11 @@ def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens,
)
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
replace_to, max_num_tokens):
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens):
vec_len = batch_size
grid, block_size = cal_grid_and_block_size(batch_size)
expand_kernel[(grid, )](
expand_kernel[(grid,)](
expanded_x,
x,
cu_num_tokens,
@@ -382,56 +364,50 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
vec_len,
NO_DRAFT_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < vec_len
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
not_greedy_mask = is_greedy == 0
start_idxs = tl.where(
offsets == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
if not_greedy:
rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
start_idx = tl.get_element(start_idxs, (req_i, ))
start_idx = tl.get_element(start_idxs, (req_i,))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
for pos in range(num_draft_tokens):
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
uniform_prob = uniform_prob * tmp_uniform_prob
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
pi = min(pi * target_prob / draft_prob, 1.0)
if draft_prob > 0 and pi >= uniform_prob:
@@ -443,19 +419,14 @@ def rejection_random_sample_block_verify_kernel(
if last_accepted_token_pos > -1:
for pos in range(last_accepted_token_pos + 1):
token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
pos, token_id)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
if rejected:
recovered_token_id = tl.load(recovered_token_ids_ptr +
start_idx +
last_accepted_token_pos + 1)
recovered_token_id = tl.load(recovered_token_ids_ptr + start_idx + last_accepted_token_pos + 1)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
last_accepted_token_pos + 1, recovered_token_id)
output_token_ids_ptr + req_idx * (max_spec_len + 1) + last_accepted_token_pos + 1,
recovered_token_id,
)
else:
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, bonus_token_id)
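The block-verify acceptance loop, restated as a hedged Python sketch (loop-exit and rejected-flag details are elided by the hunk; names mirror the kernel):
def block_verify_ref(target_p, draft_p, uniform):
    pi, u, last = 1.0, 1.0, -1
    for pos in range(len(draft_p)):
        u *= uniform[pos]                # running product of uniform samples
        if draft_p[pos] <= 0:
            break
        pi = min(pi * target_p[pos] / draft_p[pos], 1.0)
        if pi >= u:
            last = pos
    return last  # last accepted draft position, -1 if none accepted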

View File

@@ -14,9 +14,10 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from vllm.triton_utils import tl, triton
import torch
from typing import Tuple
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@@ -48,16 +49,16 @@ def _triton_rope(
This triton kernel applies rotary embedding on q and k.
It supports the rope_dim != head_dim scenario.
It supports both neox-style and non-neox-style rope computation.
Input tensor layout assumptions:
q size: (num_tokens, num_q_heads, head_dim)
q stride: (num_q_heads * head_dim, head_dim, 1)
k size: (num_tokens, num_kv_heads, head_dim)
k stride: (num_kv_heads * head_dim, head_dim, 1)
cos/sin size: (num_tokens, rope_dim/2)
cos/sin stride: (rope_dim/2, 1)
The compute pattern differs depending on IS_NEOX_STYLE:
if IS_NEOX_STYLE:
@@ -88,10 +89,8 @@ def _triton_rope(
cos_offsets = tl.arange(0, pad_rope_dim // 2)
cos_mask = cos_offsets < (rope_dim // 2)
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask,
other=0).to(tl.float32)
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask,
other=0).to(tl.float32)
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
# ####################################################################
# Load the left and right half of q and k for the current
@@ -99,28 +98,20 @@ def _triton_rope(
# ####################################################################
# left half of the head
if IS_NEOX_STYLE:
first_half_q_offsets = tl.arange(
0, pad_n_qh)[:, None] * hd + tl.arange(
0, pad_rope_dim // 2)[None, :]
first_half_k_offsets = tl.arange(
0, pad_n_kh)[:, None] * hd + tl.arange(
0, pad_rope_dim // 2)[None, :]
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
else:
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets,
mask=first_q_mask,
other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets,
mask=first_k_mask,
other=0).to(sin_row.dtype)
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
)
q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
# right half of the head
if IS_NEOX_STYLE:
@@ -131,41 +122,29 @@ def _triton_rope(
second_half_k_offsets = first_half_k_offsets + 1
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets,
mask=second_q_mask,
other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets,
mask=second_k_mask,
other=0).to(sin_row.dtype)
q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_start_ptr + first_half_q_offsets,
new_q_tile_1,
mask=first_q_mask)
tl.store(q_start_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_start_ptr + second_half_q_offsets,
new_q_tile_2,
mask=second_q_mask)
tl.store(q_start_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_start_ptr + first_half_k_offsets,
new_k_tile_1,
mask=first_k_mask)
tl.store(k_start_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_start_ptr + second_half_k_offsets,
new_k_tile_2,
mask=second_k_mask)
tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
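The rotation itself for one (cos, sin) pair, as a scalar sketch of the formula in the comment above (neox style pairs x1 with the element rope_dim // 2 away; the interleaved style pairs adjacent elements):
def rotate_pair(x1, x2, cos, sin):
    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    return x1 * cos - x2 * sin, x2 * cos + x1 * sin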
def rope_forward_triton(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_dim: int = -1,
is_neox_style: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_dim: int = -1,
is_neox_style: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
@@ -187,7 +166,7 @@ def rope_forward_triton(
num_vectorcore = get_vectorcore_num()
n_row = min(num_tokens, num_vectorcore)
_triton_rope[(n_row, )](
_triton_rope[(n_row,)](
q,
q.stride(0),
k,

View File

@@ -49,20 +49,15 @@ def prepare_inputs_padded_kernel(
other=0,
)
num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev,
cu_draft_curr)
num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev, cu_draft_curr)
valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets,
mask=mask)
valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets, mask=mask)
num_rejected = num_draft_tokens + 1 - valid_count
num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)
# query_start_loc[req_idx + 1] is the start position of the next request,
# which is one past the last token of this request.
q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1,
mask=mask) - 1
q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1, mask=mask) - 1
index_to_sample = q_last_tok_idx - num_rejected
tl.store(token_indices_to_sample_ptr + offsets,
index_to_sample,
mask=mask)
tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask)
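A worked instance of the index arithmetic (illustrative numbers):
num_rejected = 3 + 1 - 2     # 3 draft tokens, valid_count = 2 -> 2 rejected
index_to_sample = 9 - 2      # q_last_tok_idx = 9 -> sample at position 7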

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any
import torch
from vllm.triton_utils import HAS_TRITON, triton
@@ -10,9 +10,9 @@ _NUM_VECTORCORE = -1
def init_device_properties_triton():
global _NUM_AICORE, _NUM_VECTORCORE
if _NUM_AICORE == -1 and HAS_TRITON:
device_properties: Dict[str, Any] = (
triton.runtime.driver.active.utils.get_device_properties(
torch.npu.current_device()))
device_properties: dict[str, Any] = triton.runtime.driver.active.utils.get_device_properties(
torch.npu.current_device()
)
_NUM_AICORE = device_properties.get("num_aicore", -1)
_NUM_VECTORCORE = device_properties.get("num_vectorcore", -1)
assert _NUM_AICORE > 0 and _NUM_VECTORCORE > 0, "Failed to detect device properties."
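Typical call order, as a sketch (assuming an NPU device is current and both helpers are imported from this module):
init_device_properties_triton()  # caches _NUM_AICORE / _NUM_VECTORCORE once
n_vec = get_vectorcore_num()     # used above to size kernel launch grids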