### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -26,8 +26,7 @@ def _swiglu_quant_kernel(
|
||||
else:
|
||||
gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN)
|
||||
gl_mask = gl_offsets < NUM_EXPERTS
|
||||
group_list = tl.load(group_list_ptr + gl_offsets, gl_mask,
|
||||
other=0).to(tl.int32)
|
||||
group_list = tl.load(group_list_ptr + gl_offsets, gl_mask, other=0).to(tl.int32)
|
||||
total_rows = tl.sum(group_list)
|
||||
|
||||
block_size = (total_rows - 1) // NUM_CORES + 1
|
||||
@@ -41,14 +40,8 @@ def _swiglu_quant_kernel(
|
||||
# swiglu
|
||||
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
|
||||
cur_x = tl.load(x_ptr + x_offsets)
|
||||
x1 = tl.extract_slice(cur_x,
|
||||
offsets=(0, ),
|
||||
sizes=(HALF_COLS, ),
|
||||
strides=(1, ))
|
||||
x2 = tl.extract_slice(cur_x,
|
||||
offsets=(HALF_COLS, ),
|
||||
sizes=(HALF_COLS, ),
|
||||
strides=(1, ))
|
||||
x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
|
||||
x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
|
||||
out = x1 * tl.sigmoid(x1) * x2
|
||||
|
||||
# quant
|
||||
@@ -57,20 +50,13 @@ def _swiglu_quant_kernel(
|
||||
# store scale
|
||||
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
|
||||
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
|
||||
tmp_out = tl.extract_slice(out,
|
||||
offsets=(col_blk_idx, ),
|
||||
sizes=(COL_BLOCK_SIZE, ),
|
||||
strides=(1, ))
|
||||
tmp_out = (tmp_out.to(tl.float32) / scale).to(
|
||||
x_ptr.dtype.element_ty)
|
||||
tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
|
||||
tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
|
||||
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")
|
||||
|
||||
o_offsets = (row_idx * HALF_COLS + col_blk_idx +
|
||||
tl.arange(0, COL_BLOCK_SIZE))
|
||||
o_offsets = row_idx * HALF_COLS + col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)
|
||||
mask = (col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)) < HALF_COLS
|
||||
tl.store(out_ptr + o_offsets,
|
||||
tmp_out.to(out_ptr.dtype.element_ty),
|
||||
mask=mask)
|
||||
tl.store(out_ptr + o_offsets, tmp_out.to(out_ptr.dtype.element_ty), mask=mask)
|
||||
else:
|
||||
# store out
|
||||
o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
|
||||
@@ -80,12 +66,11 @@ def _swiglu_quant_kernel(
|
||||
def swiglu_quant(x, group_list, group_list_type, need_quant=True):
|
||||
# group_list_type must be 0 cusum or 1 count
|
||||
if group_list_type not in [0, 1]:
|
||||
raise ValueError(
|
||||
f"group_list_type must be 0 or 1, but got {group_list_type}")
|
||||
raise ValueError(f"group_list_type must be 0 or 1, but got {group_list_type}")
|
||||
s, h = x.shape
|
||||
out_dtype = torch.int8 if need_quant else x.dtype
|
||||
out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device)
|
||||
scale = torch.empty((s, ), dtype=torch.float32, device=x.device)
|
||||
scale = torch.empty((s,), dtype=torch.float32, device=x.device)
|
||||
num_experts = group_list.shape[0]
|
||||
# ub must be 32-byte aligned on npu
|
||||
if group_list.dtype == torch.int64:
|
||||
@@ -93,12 +78,10 @@ def swiglu_quant(x, group_list, group_list_type, need_quant=True):
|
||||
elif group_list.dtype == torch.int32:
|
||||
num_experts_algin = (num_experts + 15) // 16 * 16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
|
||||
)
|
||||
raise ValueError(f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}")
|
||||
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
_swiglu_quant_kernel[(num_vectorcore, )](
|
||||
_swiglu_quant_kernel[(num_vectorcore,)](
|
||||
x,
|
||||
group_list,
|
||||
out,
|
||||
|
||||
@@ -67,11 +67,9 @@ def matmul_bias_persistent_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
k_start = k * BLOCK_K
|
||||
# Calculate pointer offsets for x (row-major)
|
||||
x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] +
|
||||
k_start) * stride_xk
|
||||
x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
|
||||
# Calculate pointer offsets for y (row-major)
|
||||
y_ptrs = y_ptr + (rk[:, None] +
|
||||
k_start) * stride_yk + rn[None, :] * stride_yn
|
||||
y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn
|
||||
|
||||
# Create masks to prevent out-of-bounds access
|
||||
x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
|
||||
@@ -89,14 +87,12 @@ def matmul_bias_persistent_kernel(
|
||||
# Load bias values (broadcast to all rows)
|
||||
bias_ptrs = bias_ptr + rn * stride_bias
|
||||
bias_mask = rn < N
|
||||
bias_vals = tl.load(bias_ptrs, mask=bias_mask,
|
||||
other=0.0).to(tl.float32)
|
||||
bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
|
||||
# Add bias to accumulator (automatic broadcasting)
|
||||
acc += bias_vals[None, :]
|
||||
|
||||
# Calculate output pointer positions
|
||||
out_ptrs = output_ptr + rm[:,
|
||||
None] * stride_outm + rn[None, :] * stride_outn
|
||||
out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
|
||||
out_mask = (rm[:, None] < M) & (rn[None, :] < N)
|
||||
|
||||
# Store result to global memory
|
||||
@@ -106,28 +102,28 @@ def matmul_bias_persistent_kernel(
|
||||
def matmul_persistent(x, y, bias=None):
|
||||
"""
|
||||
Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
|
||||
|
||||
|
||||
Parameters:
|
||||
x: torch.Tensor, shape [M, K]
|
||||
y: torch.Tensor, shape [K, N]
|
||||
bias: torch.Tensor, shape [N] or None
|
||||
|
||||
|
||||
Returns:
|
||||
output: torch.Tensor, shape [M, N]
|
||||
"""
|
||||
# Validate input shapes
|
||||
assert x.dim() == 2, "x must be a 2D tensor"
|
||||
assert y.dim() == 2, "y must be a 2D tensor"
|
||||
assert x.shape[1] == y.shape[
|
||||
0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
|
||||
assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
|
||||
|
||||
M, K = x.shape
|
||||
_, N = y.shape
|
||||
# Validate bias shape (if not None)
|
||||
if bias is not None:
|
||||
assert bias.dim() == 1, "bias must be a 1D tensor"
|
||||
assert y.shape[1] == bias.shape[
|
||||
0], f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
|
||||
assert y.shape[1] == bias.shape[0], (
|
||||
f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
|
||||
)
|
||||
|
||||
# Allocate output tensor (same data type as x)
|
||||
output = torch.empty((M, N), dtype=x.dtype, device=x.device)
|
||||
@@ -176,24 +172,24 @@ def matmul_persistent(x, y, bias=None):
|
||||
|
||||
@triton.jit
|
||||
def linear_persistent_kernel(
|
||||
a_ptr, # Pointer to tensor a, shape [M, K]
|
||||
b_ptr, # Pointer to tensor b, shape [N, K]
|
||||
c_ptr, # Pointer to output tensor c, shape [M, N]
|
||||
M, # Number of rows in tensor a
|
||||
N, # Number of rows in tensor b (number of columns in output c)
|
||||
K, # Number of columns in both tensor a and tensor b
|
||||
stride_am, # Stride of tensor a along dimension M (typically K)
|
||||
stride_ak, # Stride of tensor a along dimension K (typically 1)
|
||||
stride_bn, # Stride of tensor b along dimension N (typically K)
|
||||
stride_bk, # Stride of tensor b along dimension K (typically 1)
|
||||
stride_cm, # Stride of tensor c along dimension M (typically N)
|
||||
stride_cn, # Stride of tensor c along dimension N (typically 1)
|
||||
BLOCK_M: tl.constexpr, # Block size for M dimension
|
||||
BLOCK_N: tl.constexpr, # Block size for N dimension
|
||||
BLOCK_K: tl.constexpr, # Block size for K dimension
|
||||
NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension
|
||||
NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension
|
||||
GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size
|
||||
a_ptr, # Pointer to tensor a, shape [M, K]
|
||||
b_ptr, # Pointer to tensor b, shape [N, K]
|
||||
c_ptr, # Pointer to output tensor c, shape [M, N]
|
||||
M, # Number of rows in tensor a
|
||||
N, # Number of rows in tensor b (number of columns in output c)
|
||||
K, # Number of columns in both tensor a and tensor b
|
||||
stride_am, # Stride of tensor a along dimension M (typically K)
|
||||
stride_ak, # Stride of tensor a along dimension K (typically 1)
|
||||
stride_bn, # Stride of tensor b along dimension N (typically K)
|
||||
stride_bk, # Stride of tensor b along dimension K (typically 1)
|
||||
stride_cm, # Stride of tensor c along dimension M (typically N)
|
||||
stride_cn, # Stride of tensor c along dimension N (typically 1)
|
||||
BLOCK_M: tl.constexpr, # Block size for M dimension
|
||||
BLOCK_N: tl.constexpr, # Block size for N dimension
|
||||
BLOCK_K: tl.constexpr, # Block size for K dimension
|
||||
NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension
|
||||
NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension
|
||||
GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size
|
||||
):
|
||||
# Get current program's 1D index (1D grid)
|
||||
pid = tl.program_id(0)
|
||||
@@ -226,18 +222,12 @@ def linear_persistent_kernel(
|
||||
k_mask = k_indices < K
|
||||
|
||||
# Load block of tensor a: shape [BLOCK_M, BLOCK_K]
|
||||
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[
|
||||
None, :] * stride_ak
|
||||
a_vals = tl.load(a_ptrs,
|
||||
mask=m_mask[:, None] & k_mask[None, :],
|
||||
other=0.0)
|
||||
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
|
||||
a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
|
||||
|
||||
# Load block of tensor b: shape [BLOCK_N, BLOCK_K]
|
||||
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[
|
||||
None, :] * stride_bk
|
||||
b_vals = tl.load(b_ptrs,
|
||||
mask=n_mask[:, None] & k_mask[None, :],
|
||||
other=0.0)
|
||||
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
|
||||
b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
|
||||
|
||||
# Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
|
||||
b_vals_transposed = tl.trans(b_vals)
|
||||
@@ -246,8 +236,7 @@ def linear_persistent_kernel(
|
||||
product = tl.dot(a_vals, b_vals_transposed)
|
||||
acc += product
|
||||
# Store result to output tensor c
|
||||
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[
|
||||
None, :] * stride_cn
|
||||
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
|
||||
tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
|
||||
|
||||
|
||||
@@ -255,19 +244,18 @@ def linear_persistent(x, y):
|
||||
"""
|
||||
Implement matrix multiplication with Triton: x @ y^T
|
||||
Uses a fixed-size 1D grid
|
||||
|
||||
|
||||
Parameters:
|
||||
x: torch.Tensor, shape [M, K]
|
||||
y: torch.Tensor, shape [N, K]
|
||||
|
||||
|
||||
Returns:
|
||||
output: torch.Tensor, shape [M, N]
|
||||
"""
|
||||
# Validate input shapes
|
||||
assert x.dim() == 2, "x must be a 2D tensor"
|
||||
assert y.dim() == 2, "y must be a 2D tensor"
|
||||
assert x.shape[1] == y.shape[
|
||||
1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
|
||||
assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
|
||||
|
||||
M, K = x.shape
|
||||
N, _ = y.shape
|
||||
@@ -283,9 +271,8 @@ def linear_persistent(x, y):
|
||||
num_blocks_n = triton.cdiv(N, BLOCK_N)
|
||||
|
||||
# Set fixed 1D grid size
|
||||
grid_size = driver.active.utils.get_device_properties(
|
||||
torch.npu.current_device())["num_vectorcore"] // 2
|
||||
grid = (grid_size, )
|
||||
grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
|
||||
grid = (grid_size,)
|
||||
|
||||
# Launch kernel
|
||||
linear_persistent_kernel[grid](
|
||||
@@ -330,8 +317,7 @@ def bmm_batch_invariant(a, b, *, out=None):
|
||||
return out
|
||||
return result
|
||||
else:
|
||||
raise ValueError(f"bmm_batch_invariant expects 3D tensors, "
|
||||
f"got shapes {a.shape} and {b.shape}")
|
||||
raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")
|
||||
|
||||
|
||||
def addmm_batch_invariant(bias, a, b):
|
||||
@@ -392,7 +378,8 @@ def matmul_batch_invariant(a, b, *, out=None):
|
||||
raise ValueError(
|
||||
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
|
||||
f"3D x 2D, 2D x 3D, and 4D x 4D, "
|
||||
f"got shapes {a.shape} and {b.shape}")
|
||||
f"got shapes {a.shape} and {b.shape}"
|
||||
)
|
||||
|
||||
|
||||
def linear_batch_invariant(input_, weight, bias=None):
|
||||
|
||||
@@ -86,8 +86,7 @@ def mean_dim(
|
||||
Tensor with mean values along specified dimension
|
||||
"""
|
||||
# Validate inputs
|
||||
assert -input_.ndim <= dim < input_.ndim, (
|
||||
f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")
|
||||
assert -input_.ndim <= dim < input_.ndim, f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions"
|
||||
|
||||
# Handle negative dim
|
||||
if dim < 0:
|
||||
@@ -123,7 +122,7 @@ def mean_dim(
|
||||
output_shape = shape.copy()
|
||||
output_shape[dim] = 1
|
||||
else:
|
||||
output_shape = shape[:dim] + shape[dim + 1:]
|
||||
output_shape = shape[:dim] + shape[dim + 1 :]
|
||||
|
||||
# Create output tensor
|
||||
output = torch.empty(output_shape, dtype=dtype, device=input_.device)
|
||||
@@ -135,7 +134,7 @@ def mean_dim(
|
||||
output_2d = output.reshape(M, K)
|
||||
|
||||
# Launch kernel
|
||||
grid = (M * K, )
|
||||
grid = (M * K,)
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
mean_kernel[grid](
|
||||
@@ -165,13 +164,10 @@ def mean_batch_invariant(
|
||||
if len(dim) == 1:
|
||||
return mean_dim(input_, dim[0], keepdim=keepdim)
|
||||
else:
|
||||
assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32
|
||||
}, ("only float types supported for now")
|
||||
assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32}, "only float types supported for now"
|
||||
if len(dim) == 0:
|
||||
dim = list(range(input_.ndim))
|
||||
n_elems = 1
|
||||
for d in dim:
|
||||
n_elems *= input_.shape[d]
|
||||
return torch.sum(input_, dim=dim, keepdim=keepdim,
|
||||
dtype=torch.float32).to(dtype
|
||||
or input_.dtype) / n_elems
|
||||
return torch.sum(input_, dim=dim, keepdim=keepdim, dtype=torch.float32).to(dtype or input_.dtype) / n_elems
|
||||
|
||||
@@ -100,8 +100,8 @@ def rms_norm(
|
||||
"""
|
||||
assert weight.dim() == 1, "Weight must be 1-dimensional"
|
||||
assert input_.shape[-1] == weight.shape[0], (
|
||||
f"Input last dimension ({input_.shape[-1]}) must match "
|
||||
f"weight dimension ({weight.shape[0]})")
|
||||
f"Input last dimension ({input_.shape[-1]}) must match weight dimension ({weight.shape[0]})"
|
||||
)
|
||||
|
||||
# Flatten all dimensions except the last one
|
||||
original_shape = input_.shape
|
||||
@@ -113,10 +113,9 @@ def rms_norm(
|
||||
|
||||
output = torch.empty_like(input_2d, dtype=input_.dtype)
|
||||
BLOCK_SIZE = 1024
|
||||
max_grid_size = driver.active.utils.get_device_properties(
|
||||
torch.npu.current_device())["num_vectorcore"]
|
||||
max_grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"]
|
||||
|
||||
grid = (min(n_rows, max_grid_size), )
|
||||
grid = (min(n_rows, max_grid_size),)
|
||||
|
||||
_rms_norm_kernel[grid](
|
||||
input_2d,
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@@ -25,22 +24,20 @@ from .utils import input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
def chunk_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=torch.float32)
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
@@ -75,20 +72,21 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
@@ -110,17 +108,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
def chunk_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
@@ -186,41 +186,39 @@ def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead.",
|
||||
stacklevel=2)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
||||
(q, k, v, beta, g))
|
||||
stacklevel=2,
|
||||
)
|
||||
q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g))
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
scale = k.shape[-1] ** -0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel
|
||||
)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
return o, final_state
|
||||
o = rearrange(o, "b t h ... -> b h t ...")
|
||||
return o, final_state
|
||||
|
||||
@@ -8,23 +8,24 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp
|
||||
|
||||
_CONDITIONS = ("seq7168", )
|
||||
_CONDITIONS = ("seq7168",)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
@@ -85,28 +86,20 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
if USE_INITIAL_STATE:
|
||||
h0_ptr = h0 + i_nh * K * V
|
||||
ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1
|
||||
b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1,
|
||||
other=0.0).to(tl.float32)
|
||||
b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, other=0.0).to(tl.float32)
|
||||
|
||||
ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1
|
||||
b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2,
|
||||
other=0.0).to(tl.float32)
|
||||
b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, other=0.0).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
h_base = h + (boh + i_t) * H * K * V + i_h * K * V
|
||||
|
||||
p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_h1_bv1,
|
||||
b_h1_bv1.to(p_h1_bv1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
|
||||
tl.store(p_h1_bv1, b_h1_bv1.to(p_h1_bv1.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_h1_bv2,
|
||||
b_h1_bv2.to(p_h1_bv2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
|
||||
tl.store(p_h1_bv2, b_h1_bv2.to(p_h1_bv2.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None]
|
||||
offs_k_wv = tl.arange(0, 128)[None, :]
|
||||
@@ -117,8 +110,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
b_w = tl.load(ptr_w, mask=mask_w, other=0.0)
|
||||
|
||||
k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K
|
||||
p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT),
|
||||
(128, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), (128, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
|
||||
v_new_base = v_new + bos * H * V + i_h * V
|
||||
@@ -144,12 +136,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
|
||||
(i_t * BT, v_start1), (BT, 64),
|
||||
(1, 0))
|
||||
tl.store(p_v_new1,
|
||||
b_v_new1.to(p_v_new1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start1), (BT, 64), (1, 0))
|
||||
tl.store(p_v_new1, b_v_new1.to(p_v_new1.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
b_v_new1 = b_v_new1 * b_g[:, None]
|
||||
@@ -165,12 +153,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
|
||||
(i_t * BT, v_start2), (BT, 64),
|
||||
(1, 0))
|
||||
tl.store(p_v_new2,
|
||||
b_v_new2.to(p_v_new2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start2), (BT, 64), (1, 0))
|
||||
tl.store(p_v_new2, b_v_new2.to(p_v_new2.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
b_v_new2 = b_v_new2 * b_g[:, None]
|
||||
@@ -183,29 +167,23 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
if STORE_FINAL_STATE:
|
||||
ht_ptr = ht + i_nh * K * V
|
||||
|
||||
p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv1,
|
||||
b_h1_bv1.to(p_ht1_bv1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv1, b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv2,
|
||||
b_h1_bv2.to(p_ht1_bv2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv2, b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
g: torch.Tensor | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
@@ -213,8 +191,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None else None)
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
@@ -227,8 +204,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
|
||||
if output_final_state else None)
|
||||
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
g = g.transpose(1, 2).contiguous()
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -17,11 +16,13 @@ from vllm.triton_utils import tl, triton
|
||||
from .utils import prepare_chunk_offsets, safe_exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
@@ -48,8 +49,7 @@ def chunk_fwd_kernel_o(
|
||||
T_max = T
|
||||
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
|
||||
@@ -71,12 +71,9 @@ def chunk_fwd_kernel_o(
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1),
|
||||
(i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K),
|
||||
(i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1),
|
||||
(i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
@@ -102,10 +99,8 @@ def chunk_fwd_kernel_o(
|
||||
m_A = o_i[:, None] >= o_i[None, :]
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# to fix mma -> mma layout conversion
|
||||
@@ -119,9 +114,9 @@ def chunk_fwd_o(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
g: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
@@ -129,7 +124,7 @@ def chunk_fwd_o(
|
||||
BT = chunk_size
|
||||
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
if cu_seqlens is None:
|
||||
@@ -141,7 +136,7 @@ def chunk_fwd_o(
|
||||
)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), N * H)
|
||||
return (triton.cdiv(V, meta["BV"]), N * H)
|
||||
|
||||
g = g.transpose(1, 2).contiguous()
|
||||
chunk_fwd_kernel_o[grid](
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -16,11 +15,13 @@ from vllm.triton_utils import tl, triton
|
||||
from .utils import prepare_chunk_indices, safe_exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
'USE_G': lambda args: args['g_cumsum'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
@triton.heuristics(
|
||||
{
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
"USE_G": lambda args: args["g_cumsum"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
beta, # [H, B, T]
|
||||
@@ -44,10 +45,11 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
for i_bh in range(B * H):
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t_i * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
i_n, i_t = (
|
||||
tl.load(chunk_indices + i_t_i * 2).to(tl.int32),
|
||||
tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32),
|
||||
)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
@@ -55,39 +57,37 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
o_t = tl.arange(0, BT)
|
||||
o_t_fp32 = o_t.to(tl.float32)
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0, ))
|
||||
p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K,
|
||||
(T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
||||
(BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(
|
||||
k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_A += tl.dot(b_k, tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T, ),
|
||||
(1, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A *= safe_exp(b_g_diff)
|
||||
|
||||
b_A *= b_beta[:, None]
|
||||
b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Compute beta * K * K^T.
|
||||
|
||||
@@ -117,8 +117,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
BT = chunk_size
|
||||
if cu_seqlens is not None:
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
|
||||
if cu_seqlens is not None else None)
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
chunk_indices = chunk_indices.npu()
|
||||
cu_seqlens = cu_seqlens.npu()
|
||||
else:
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -16,11 +15,10 @@ from vllm.triton_utils import tl, triton
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'HAS_SCALE': lambda args: args['scale'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
||||
})
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
@triton.heuristics(
|
||||
{"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
@@ -41,20 +39,19 @@ def chunk_local_cumsum_scalar_kernel(
|
||||
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
|
||||
|
||||
if IS_VARLEN:
|
||||
i_s, i_block = tl.load(chunk_indices + i_block * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_s).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
|
||||
i_s, i_block = (
|
||||
tl.load(chunk_indices + i_block * 2).to(tl.int32),
|
||||
tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
|
||||
)
|
||||
bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
if HEAD_FIRST:
|
||||
ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1),
|
||||
(0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
|
||||
ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1),
|
||||
(0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
|
||||
b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
|
||||
ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
|
||||
ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
|
||||
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
|
||||
b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE))
|
||||
b_s = tl.trans(b_s, (2, 0, 1))
|
||||
b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
|
||||
@@ -63,11 +60,9 @@ def chunk_local_cumsum_scalar_kernel(
|
||||
b_o = tl.trans(b_o, (2, 0, 1))
|
||||
b_o = tl.reshape(b_o, (H, BLOCK_T))
|
||||
else:
|
||||
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1),
|
||||
(i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1),
|
||||
(i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
|
||||
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
|
||||
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
|
||||
b_s = tl.trans(b_s, (1, 0, 2))
|
||||
b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
|
||||
@@ -76,7 +71,7 @@ def chunk_local_cumsum_scalar_kernel(
|
||||
b_o = tl.trans(b_o, (1, 0, 2))
|
||||
b_o = tl.reshape(b_o, (BLOCK_T, H))
|
||||
|
||||
tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, ))
|
||||
tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0,))
|
||||
return
|
||||
|
||||
|
||||
@@ -85,61 +80,64 @@ def chunk_local_cumsum_scalar(
|
||||
chunk_size,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.Tensor] = torch.float,
|
||||
output_dtype: torch.Tensor | None = torch.float,
|
||||
):
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
else:
|
||||
B, T, H = g.shape
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
|
||||
OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
|
||||
block_indices = prepare_chunk_indices(
|
||||
cu_seqlens,
|
||||
chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
|
||||
num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(
|
||||
T, OPTIM_BLOCK_SIZE)
|
||||
block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
|
||||
num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (num_blocks, B)
|
||||
chunk_local_cumsum_scalar_kernel[grid](s=g_org,
|
||||
o=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=block_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BLOCK_T=OPTIM_BLOCK_SIZE,
|
||||
CHUNK_SIZE=chunk_size,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
num_warps=8,
|
||||
num_stages=3)
|
||||
chunk_local_cumsum_scalar_kernel[grid](
|
||||
s=g_org,
|
||||
o=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=block_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BLOCK_T=OPTIM_BLOCK_SIZE,
|
||||
CHUNK_SIZE=chunk_size,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
num_warps=8,
|
||||
num_stages=3,
|
||||
)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum(g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs) -> torch.Tensor:
|
||||
def chunk_local_cumsum(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: torch.dtype | None = torch.float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if cu_seqlens is not None:
|
||||
assert g.shape[
|
||||
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(g=g,
|
||||
chunk_size=chunk_size,
|
||||
reverse=reverse,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=head_first,
|
||||
output_dtype=output_dtype)
|
||||
return chunk_local_cumsum_scalar(
|
||||
g=g,
|
||||
chunk_size=chunk_size,
|
||||
reverse=reverse,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=head_first,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape {g.shape}, "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise")
|
||||
raise ValueError(
|
||||
f"Unsupported input shape {g.shape}, "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise"
|
||||
)
|
||||
|
||||
@@ -31,29 +31,30 @@ def fused_qkvzba_split_reshape_cat_kernel(
|
||||
BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
|
||||
QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
|
||||
q_end: tl.constexpr = HEAD_QK
|
||||
blk_q_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
|
||||
i_qk * QKVZ_DIM_T + tl.arange(0, q_end))
|
||||
blk_q_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(0, q_end)
|
||||
k_end: tl.constexpr = q_end + HEAD_QK
|
||||
blk_k_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
|
||||
i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end))
|
||||
blk_k_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end)
|
||||
v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
|
||||
blk_v_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
|
||||
i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end))
|
||||
blk_v_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end)
|
||||
z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
|
||||
blk_z_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
|
||||
i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end))
|
||||
blk_q_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
|
||||
i_qk * HEAD_QK + tl.arange(0, HEAD_QK))
|
||||
blk_k_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
|
||||
NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK +
|
||||
tl.arange(0, HEAD_QK))
|
||||
blk_v_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
|
||||
NUM_HEADS_QK * HEAD_QK * 2 +
|
||||
i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
|
||||
tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
|
||||
blk_z_st_ptr = (z + i_bs * NUM_HEADS_V * HEAD_V +
|
||||
i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
|
||||
tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
|
||||
blk_z_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end)
|
||||
blk_q_st_ptr = mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
|
||||
blk_k_st_ptr = (
|
||||
mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
|
||||
)
|
||||
blk_v_st_ptr = (
|
||||
mixed_qkv
|
||||
+ i_bs * NUM_HEADS_QK * QKV_DIM_T
|
||||
+ NUM_HEADS_QK * HEAD_QK * 2
|
||||
+ i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
|
||||
+ tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
|
||||
)
|
||||
blk_z_st_ptr = (
|
||||
z
|
||||
+ i_bs * NUM_HEADS_V * HEAD_V
|
||||
+ i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
|
||||
+ tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
|
||||
)
|
||||
tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
|
||||
tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
|
||||
tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
|
||||
@@ -66,8 +67,7 @@ def fused_qkvzba_split_reshape_cat_kernel(
|
||||
tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
|
||||
for i in tl.static_range(b_end, a_end):
|
||||
blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
|
||||
blk_a_st_ptr = (a + i_bs * NUM_HEADS_V +
|
||||
i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end))
|
||||
blk_a_st_ptr = a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
|
||||
tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
|
||||
|
||||
|
||||
|
||||
@@ -15,8 +15,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
|
||||
MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
|
||||
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
|
||||
base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
|
||||
rindex = tl.arange(0, N)[None, :]
|
||||
|
||||
@@ -24,8 +23,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
|
||||
row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
|
||||
xmask = row_idx < M
|
||||
|
||||
xs = tl.load(X + (rindex + N * row_idx), mask=xmask,
|
||||
other=0.0).to(tl.float32)
|
||||
xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
|
||||
square = xs * xs
|
||||
square_sum = tl.sum(square, 1)[:, None]
|
||||
rsqrt = tl.rsqrt(square_sum + eps)
|
||||
@@ -33,9 +31,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
|
||||
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||
|
||||
|
||||
def l2norm_fwd(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: torch.dtype | None = None):
|
||||
def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None):
|
||||
x_shape_og = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
# allocate output
|
||||
@@ -56,7 +52,7 @@ def l2norm_fwd(x: torch.Tensor,
|
||||
num_core = get_vectorcore_num()
|
||||
main_bs = triton.cdiv(T, num_core)
|
||||
num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
|
||||
grid = (num_core, )
|
||||
grid = (num_core,)
|
||||
l2norm_fwd_kernel2_loop[grid](
|
||||
X=x,
|
||||
Y=y,
|
||||
|
||||
@@ -12,10 +12,12 @@ from vllm.triton_utils import tl, triton
|
||||
MAX_CORES = 65535
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def layer_norm_fwd_kernel(
|
||||
X, # pointer to the input
|
||||
@@ -49,13 +51,10 @@ def layer_norm_fwd_kernel(
|
||||
n_iters = n_iters + 1
|
||||
|
||||
for i in tl.range(n_iters):
|
||||
X_base = X + (i * BLOCK_ROWS *
|
||||
stride_x_row) + row * stride_x_row + group * N
|
||||
Y_base = Y + (i * BLOCK_ROWS *
|
||||
stride_y_row) + row * stride_y_row + group * N
|
||||
X_base = X + (i * BLOCK_ROWS * stride_x_row) + row * stride_x_row + group * N
|
||||
Y_base = Y + (i * BLOCK_ROWS * stride_y_row) + row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z_base = Z + (i * BLOCK_ROWS *
|
||||
stride_z_row) + row * stride_z_row + group * N
|
||||
Z_base = Z + (i * BLOCK_ROWS * stride_z_row) + row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean_base = Mean + (i * BLOCK_ROWS) + group * M
|
||||
Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
|
||||
@@ -64,17 +63,17 @@ def layer_norm_fwd_kernel(
|
||||
B_base = B + group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
x = tl.load(X_base + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean_base + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd_base + row, rstd)
|
||||
@@ -112,26 +111,24 @@ def _layer_norm_fwd(
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm else None)
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M if M < MAX_CORES else MAX_CORES, ngroups)
|
||||
@@ -160,7 +157,6 @@ def _layer_norm_fwd(
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
|
||||
@@ -14,7 +14,7 @@ import os
|
||||
import torch
|
||||
from vllm.triton_utils import tl, tldevice, triton
|
||||
|
||||
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
|
||||
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
|
||||
div = tldevice.fast_dividef
|
||||
exp = tldevice.fast_expf
|
||||
log = tldevice.fast_logf
|
||||
@@ -31,17 +31,15 @@ else:
|
||||
log2 = tl.log2
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_INITIAL_STATE':
|
||||
lambda args: args['h0'] is not None,
|
||||
'IS_VARLEN':
|
||||
lambda args: args['cu_seqlens'] is not None,
|
||||
"IS_CONTINUOUS_BATCHING":
|
||||
lambda args: args['ssm_state_indices'] is not None,
|
||||
"IS_SPEC_DECODING":
|
||||
lambda args: args['num_accepted_tokens'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['N', 'T'])
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
|
||||
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["N", "T"])
|
||||
def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
@@ -70,8 +68,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
stride_indices_tok: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
|
||||
IS_BETA_HEADWISE: tl.
|
||||
constexpr, # whether beta is headwise vector or scalar,
|
||||
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
@@ -82,8 +79,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
@@ -108,8 +104,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_init_state_token
|
||||
p_h0 = (
|
||||
h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token
|
||||
)
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * K * V
|
||||
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
@@ -164,18 +161,21 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_final_state_token
|
||||
p_ht = (
|
||||
ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_final_state_token
|
||||
)
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def fused_sigmoid_gating_delta_rule_update_kernel(
|
||||
A_log,
|
||||
@@ -245,8 +245,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
|
||||
idx = tl.load(h0_indices + i_n)
|
||||
# if idx >= 0:
|
||||
tmp0 = tl.where(idx < 0, 0, idx)
|
||||
p_h0 = (h0_source + tmp0 * HV * K * V + i_hv * K * V +
|
||||
o_k[:, None] * V + o_v[None, :])
|
||||
p_h0 = h0_source + tmp0 * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
temp2 = tl.zeros_like(temp1)
|
||||
value0 = tl.where(idx < 0, temp2, temp1)
|
||||
@@ -314,8 +313,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
|
||||
if USE_INITIAL_STATE:
|
||||
idx = tl.load(h0_indices + i_n)
|
||||
if idx >= 0:
|
||||
p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V +
|
||||
o_k[:, None] * V + o_v[None, :])
|
||||
p_h0 = h0_source + idx * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
@@ -350,7 +348,7 @@ def fused_sigmoid_gating_delta_rule_update(
|
||||
num_warps = 1
|
||||
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
scale = k.shape[-1] ** -0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -66,9 +65,13 @@ def solve_tril_16x16_kernel(
|
||||
offs_cols_in_block = tl.arange(0, 16)
|
||||
|
||||
# 2 Calculate the pointer of each element
|
||||
ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o +
|
||||
offs_rows_in_block[:, None] * H * BT +
|
||||
offs_cols_in_block[None, :])
|
||||
ptr_A_subrec16 = (
|
||||
A
|
||||
+ row_start_o * H * BT
|
||||
+ col_start_o
|
||||
+ offs_rows_in_block[:, None] * H * BT
|
||||
+ offs_cols_in_block[None, :]
|
||||
)
|
||||
|
||||
# 3 Create a mask to prevent out-of-bounds access
|
||||
global_rows = row_start_o + offs_rows_in_block[:, None]
|
||||
@@ -76,14 +79,14 @@ def solve_tril_16x16_kernel(
|
||||
load_mask = (global_rows < T) & (global_cols < BT)
|
||||
|
||||
# 4 Use mask to safely load data
|
||||
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask,
|
||||
other=0.0).to(tl.float32)
|
||||
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
|
||||
b_A = tl.insert_slice(
|
||||
ful=b_A,
|
||||
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
|
||||
offsets=[blkid, 0, 0],
|
||||
sizes=[1, 16, 16],
|
||||
strides=[1, 1, 1])
|
||||
strides=[1, 1, 1],
|
||||
)
|
||||
|
||||
local_ori_A = tl.trans(b_A, (1, 0, 2))
|
||||
local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))
|
||||
@@ -97,9 +100,7 @@ def solve_tril_16x16_kernel(
|
||||
|
||||
# for loop to update N_BLOCKS row vector
|
||||
for i in range(1, 16):
|
||||
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0),
|
||||
(1, 16 * N_BLOCKS),
|
||||
(16 * N_BLOCKS, 1))
|
||||
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
|
||||
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
|
||||
|
||||
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
|
||||
@@ -107,34 +108,27 @@ def solve_tril_16x16_kernel(
|
||||
b_a = b_a + dot_product
|
||||
|
||||
b_a_new_expanded = b_a[:, None, :]
|
||||
b_A = tl.insert_slice(ful=b_A,
|
||||
sub=b_a_new_expanded,
|
||||
offsets=[0, i, 0],
|
||||
sizes=[N_BLOCKS, 1, 16],
|
||||
strides=[1, 1, 1])
|
||||
b_A = tl.insert_slice(
|
||||
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
|
||||
)
|
||||
|
||||
on_diagonal = (rows == cols)
|
||||
on_diagonal = rows == cols
|
||||
b_A = tl.where(on_diagonal, b_A + 1.0, b_A)
|
||||
|
||||
b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))
|
||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0),
|
||||
(N_BLOCKS * 16, 16), (1, 0))
|
||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), (N_BLOCKS * 16, 16), (1, 0))
|
||||
|
||||
# 1 Create in-block offset
|
||||
offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
|
||||
offs_cols_to_store = tl.arange(0, 16)
|
||||
|
||||
# 2 Calculate the pointer of each element
|
||||
p_Ai = (Ad + base_t * H * 16 + 0 +
|
||||
offs_rows_to_store[:, None] * H * 16 +
|
||||
offs_cols_to_store[None, :])
|
||||
p_Ai = Ad + base_t * H * 16 + 0 + offs_rows_to_store[:, None] * H * 16 + offs_cols_to_store[None, :]
|
||||
# 3 Create a mask to prevent out-of-bounds access, only check rows
|
||||
global_store_rows = base_t + offs_rows_to_store[:, None]
|
||||
store_mask = global_store_rows < T
|
||||
# 4 use mask to save data safely
|
||||
tl.store(p_Ai,
|
||||
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=store_mask)
|
||||
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=store_mask)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@@ -169,18 +163,12 @@ def merge_16x16_to_32x32_inverse_kernel(
|
||||
Ad += (bos * H + i_h) * 16
|
||||
Ai += (bos * H + i_h) * 32
|
||||
|
||||
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
|
||||
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
|
||||
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
|
||||
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
|
||||
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
|
||||
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
|
||||
|
||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
||||
@@ -313,26 +301,20 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
offs_n = tl.arange(0, 32)
|
||||
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
tl.store(ptr_Ai,
|
||||
Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=mask_store)
|
||||
tl.store(ptr_Ai, Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
|
||||
|
||||
# store Ai_22_32 to (i_t * 64 + 32, 32)
|
||||
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
|
||||
offs_n = 32 + tl.arange(0, 32)
|
||||
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
tl.store(ptr_Ai,
|
||||
Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=mask_store)
|
||||
tl.store(ptr_Ai, Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
|
||||
|
||||
# store Ai_21_32 to (i_t * 64 + 32, 32)
|
||||
offs_n = tl.arange(0, 32)
|
||||
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
tl.store(ptr_Ai,
|
||||
Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=mask_store)
|
||||
tl.store(ptr_Ai, Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
|
||||
|
||||
# zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63)
|
||||
offs_m = i_t * 64 + tl.arange(0, 32)
|
||||
@@ -345,7 +327,7 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
|
||||
def solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -367,19 +349,12 @@ def solve_tril(
|
||||
assert A.shape[-1] in [16, 32, 64]
|
||||
|
||||
B, T, H, BT = A.shape
|
||||
Ad = torch.empty(B,
|
||||
T,
|
||||
H,
|
||||
16,
|
||||
device=A.device,
|
||||
dtype=torch.float if BT != 16 else output_dtype)
|
||||
Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
|
||||
|
||||
LARGE_BLOCK_T = 608 * 2
|
||||
|
||||
chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
|
||||
if cu_seqlens is not None else None)
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(
|
||||
T, LARGE_BLOCK_T)
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if cu_seqlens is not None else None
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)
|
||||
|
||||
solve_tril_16x16_kernel[NT, B * H](
|
||||
A=A,
|
||||
@@ -398,10 +373,8 @@ def solve_tril(
|
||||
return Ad
|
||||
|
||||
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
|
||||
merge_fn = (merge_16x16_to_32x32_inverse_kernel
|
||||
if BT == 32 else merge_16x16_to_64x64_inverse_kernel)
|
||||
chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
|
||||
if cu_seqlens is not None else None)
|
||||
merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
||||
|
||||
merge_fn[NT, B * H](
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
# ruff: noqa: E501
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -19,38 +19,24 @@ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
indices = torch.cat([
|
||||
torch.arange(n)
|
||||
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
])
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
|
||||
1).to(cu_seqlens)
|
||||
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
|
||||
indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
|
||||
|
||||
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
return torch.cat([
|
||||
cu_seqlens.new_tensor([0]),
|
||||
triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
|
||||
]).cumsum(-1)
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
|
||||
return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
|
||||
|
||||
|
||||
def input_guard(
|
||||
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
contiguous_args = (i if not isinstance(i, torch.Tensor) else
|
||||
i.contiguous() for i in args)
|
||||
contiguous_kwargs = {
|
||||
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
|
||||
contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
|
||||
|
||||
tensor = None
|
||||
for arg in args:
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -17,23 +16,39 @@ from vllm.triton_utils import tl, triton
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
|
||||
T, H: tl.constexpr, Hg: tl.constexpr,
|
||||
K: tl.constexpr, V: tl.constexpr,
|
||||
BT: tl.constexpr, BK: tl.constexpr,
|
||||
BV: tl.constexpr, IS_VARLEN: tl.constexpr):
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def recompute_w_u_fwd_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
T_max = T
|
||||
i_t_o = tl.program_id(0)
|
||||
|
||||
for i_bh in range(H):
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t_o * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
i_n, i_t = (
|
||||
tl.load(chunk_indices + i_t_o * 2).to(tl.int32),
|
||||
tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32),
|
||||
)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
@@ -44,7 +59,7 @@ def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
|
||||
|
||||
offs_t_2d = global_offs_t[:, None]
|
||||
offs_bt = tl.arange(0, BT)[None, :]
|
||||
ptr_A = (A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1)
|
||||
ptr_A = A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1
|
||||
mask_A = mask_t[:, None]
|
||||
b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
|
||||
|
||||
@@ -58,29 +73,25 @@ def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
|
||||
offs_v = i_v * BV + tl.arange(0, BV)[None, :]
|
||||
mask_v = (mask_t[:, None]) & (offs_v < V)
|
||||
|
||||
ptr_v = (v + (bos * H + i_h) * V + offs_t_2d * (H * V) +
|
||||
offs_v * 1)
|
||||
ptr_v = v + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
|
||||
b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32)
|
||||
|
||||
b_vb = (b_v * b_beta[:, None])
|
||||
b_vb = b_v * b_beta[:, None]
|
||||
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
|
||||
|
||||
ptr_u = (u + (bos * H + i_h) * V + offs_t_2d * (H * V) +
|
||||
offs_v * 1)
|
||||
ptr_u = u + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
|
||||
tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
offs_k = i_k * BK + tl.arange(0, BK)[None, :]
|
||||
mask_k = (mask_t[:, None]) & (offs_k < K)
|
||||
ptr_k = (k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d *
|
||||
(Hg * K) + offs_k * 1)
|
||||
ptr_k = k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d * (Hg * K) + offs_k * 1
|
||||
b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32)
|
||||
|
||||
b_kb = (b_k * b_beta[:, None] * b_g[:, None])
|
||||
b_kb = b_k * b_beta[:, None] * b_g[:, None]
|
||||
b_w = tl.dot(b_A, b_kb)
|
||||
|
||||
ptr_w = (w + (bos * H + i_h) * K + offs_t_2d * (H * K) +
|
||||
offs_k * 1)
|
||||
ptr_w = w + (bos * H + i_h) * K + offs_t_2d * (H * K) + offs_k * 1
|
||||
tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k)
|
||||
|
||||
|
||||
@@ -90,14 +101,13 @@ def recompute_w_u_fwd(
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) \
|
||||
if cu_seqlens is not None else None
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
|
||||
BK = 64
|
||||
|
||||
@@ -30,15 +30,12 @@ def fused_gdn_gating_kernel(
|
||||
):
|
||||
i_b, i_s = tl.program_id(0), tl.program_id(1)
|
||||
for row_idx in range(0, ROW_ITER):
|
||||
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(
|
||||
0, BLK_BATCHES)
|
||||
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
|
||||
|
||||
for col_idx in range(0, COL_ITER):
|
||||
head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
|
||||
|
||||
off = batch_off[:,
|
||||
None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[
|
||||
None, :]
|
||||
off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
|
||||
head_mask = head_off < NUM_HEADS
|
||||
mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
|
||||
|
||||
@@ -48,17 +45,14 @@ def fused_gdn_gating_kernel(
|
||||
blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
|
||||
|
||||
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
|
||||
softplus_x = tl.where(beta * x <= threshold,
|
||||
(1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
|
||||
softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
|
||||
|
||||
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
||||
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
||||
|
||||
# compute beta_output = sigmoid(b)
|
||||
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
|
||||
tl.store(beta_output + off,
|
||||
blk_beta_output.to(beta_output.dtype.element_ty),
|
||||
mask=mask)
|
||||
tl.store(beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
def fused_gdn_gating_patch(
|
||||
@@ -85,17 +79,13 @@ def fused_gdn_gating_patch(
|
||||
progs = num_cores
|
||||
FACTOR = 8 * num_heads
|
||||
row_per_core = triton.cdiv(batch, num_cores)
|
||||
BLK_BATCHES = triton.next_power_of_2(
|
||||
triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) //
|
||||
a.element_size()) // 2
|
||||
BLK_BATCHES = (
|
||||
triton.next_power_of_2(triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) // a.element_size()) // 2
|
||||
)
|
||||
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
|
||||
|
||||
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
||||
beta_output = torch.empty(1,
|
||||
batch,
|
||||
num_heads,
|
||||
dtype=b.dtype,
|
||||
device=b.device)
|
||||
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
|
||||
|
||||
grid = (progs, seq_len)
|
||||
fused_gdn_gating_kernel[grid](
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
@@ -131,11 +132,7 @@ def layer_norm_fwd_npu(
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (
|
||||
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
@@ -168,4 +165,4 @@ def layer_norm_fwd_npu(
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
# Remove multibuffer if not needed
|
||||
)
|
||||
return out, mean, rstd
|
||||
return out, mean, rstd
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton # type: ignore
|
||||
@@ -61,22 +60,19 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
for row_idx in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
|
||||
valid_mask = col_indices < q_hidden_size
|
||||
input_values = (tl.load(input_ptr + input_offset + col_indices,
|
||||
mask=valid_mask,
|
||||
other=0.0).to(tl.float32).reshape(
|
||||
Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
|
||||
input_values = (
|
||||
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
.to(tl.float32)
|
||||
.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
|
||||
)
|
||||
squares = input_values * input_values
|
||||
variances = tl.sum(squares, axis=1) / HEAD_DIM
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
|
||||
Q_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = (input_values * reciprocal_std
|
||||
) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
|
||||
if BIAS:
|
||||
normalized_values = (normalized_values * weight_values +
|
||||
bias_values).to(tl.bfloat16)
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(
|
||||
tl.bfloat16)
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
|
||||
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
|
||||
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
@@ -93,8 +89,7 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
|
||||
dtype=tl.bfloat16)
|
||||
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
-x2,
|
||||
@@ -127,22 +122,19 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
for row_idx in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
|
||||
valid_mask = col_indices < kv_hidden_size
|
||||
input_values = (tl.load(input_ptr + input_offset + col_indices,
|
||||
mask=valid_mask,
|
||||
other=0.0).to(tl.float32).reshape(
|
||||
KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
|
||||
input_values = (
|
||||
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
.to(tl.float32)
|
||||
.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
|
||||
)
|
||||
squares = input_values * input_values
|
||||
variances = tl.sum(squares, axis=1) / HEAD_DIM
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
|
||||
KV_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = (input_values * reciprocal_std
|
||||
) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
|
||||
if BIAS:
|
||||
normalized_values = (normalized_values * weight_values +
|
||||
bias_values).to(tl.bfloat16)
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(
|
||||
tl.bfloat16)
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
|
||||
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
@@ -158,8 +150,7 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
|
||||
dtype=tl.bfloat16)
|
||||
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
-x2,
|
||||
@@ -189,12 +180,8 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
for _ in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
|
||||
valid_mask = col_indices < kv_hidden_size
|
||||
input_values = tl.load(input_ptr + input_offset + col_indices,
|
||||
mask=valid_mask,
|
||||
other=0.0)
|
||||
tl.store(v_ptr + output_offset + col_indices,
|
||||
input_values,
|
||||
mask=valid_mask)
|
||||
input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask)
|
||||
input_offset += input_offset_step
|
||||
output_offset += output_offset_step
|
||||
|
||||
@@ -209,27 +196,18 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
kv_hidden_size: int,
|
||||
head_dim: int,
|
||||
eps: float,
|
||||
q_bias: Optional[torch.Tensor] = None,
|
||||
k_bias: Optional[torch.Tensor] = None,
|
||||
q_bias: torch.Tensor | None = None,
|
||||
k_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
|
||||
assert KV_BLOCK_SIZE == head_dim
|
||||
assert head_dim == KV_BLOCK_SIZE
|
||||
assert q_hidden_size % kv_hidden_size == 0
|
||||
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
|
||||
batch_size = input.shape[0]
|
||||
total_hidden_size = q_hidden_size + kv_hidden_size * 2
|
||||
q_output = torch.empty(batch_size,
|
||||
q_hidden_size,
|
||||
device=input.device,
|
||||
dtype=input.dtype)
|
||||
k_output = torch.empty(batch_size,
|
||||
kv_hidden_size,
|
||||
device=input.device,
|
||||
dtype=input.dtype)
|
||||
v_output = torch.empty(batch_size,
|
||||
kv_hidden_size,
|
||||
device=input.device,
|
||||
dtype=input.dtype)
|
||||
q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype)
|
||||
k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
|
||||
v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
|
||||
n_cols = kv_hidden_size // KV_BLOCK_SIZE
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
assert num_vectorcore % n_cols == 0
|
||||
@@ -271,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl_fake(
|
||||
kv_hidden_size: int,
|
||||
head_dim: int,
|
||||
eps: float,
|
||||
q_bias: Optional[torch.Tensor] = None,
|
||||
k_bias: Optional[torch.Tensor] = None,
|
||||
q_bias: torch.Tensor | None = None,
|
||||
k_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# Fake implementation for shape inference during Dynamo/AOT tracing.
|
||||
# Note: sin and cos are not used in shape computation, but must be present in signature.
|
||||
@@ -298,8 +276,10 @@ def split_qkv_rmsnorm_rope_impl_fake(
|
||||
return q_output, k_output, v_output
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="qkv_rmsnorm_rope",
|
||||
op_func=split_qkv_rmsnorm_rope_impl,
|
||||
fake_impl=split_qkv_rmsnorm_rope_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
direct_register_custom_op(
|
||||
op_name="qkv_rmsnorm_rope",
|
||||
op_func=split_qkv_rmsnorm_rope_impl,
|
||||
fake_impl=split_qkv_rmsnorm_rope_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
|
||||
@@ -7,24 +7,23 @@
|
||||
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# mypy: ignore-errors
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
initial_states: torch.Tensor | None = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
final_states_out: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
@@ -42,19 +41,14 @@ def causal_conv1d_ref(
|
||||
dim, width = weight.shape
|
||||
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
@@ -66,13 +60,13 @@ def causal_conv1d_ref(
|
||||
def causal_conv1d_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
metadata: Optional[Any] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
conv_states: torch.Tensor | None = None,
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
cache_indices: torch.Tensor | None = None,
|
||||
query_start_loc: torch.Tensor | None = None,
|
||||
metadata: Any | None = None,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
):
|
||||
"""
|
||||
@@ -126,10 +120,10 @@ def causal_conv1d_fn(
|
||||
bias,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=conv_states[cache_indices[i]][..., :(
|
||||
width - 1)].unsqueeze(0),
|
||||
initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
|
||||
if has_initial_state[i] else None))
|
||||
final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0),
|
||||
initial_states=conv_states[cache_indices[i]][..., : (width - 1)] if has_initial_state[i] else None,
|
||||
)
|
||||
)
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
return out_ref_tensor
|
||||
@@ -137,54 +131,50 @@ def causal_conv1d_fn(
|
||||
|
||||
@triton.jit
|
||||
def _causal_conv1d_update_kernel_npu_tiled(
|
||||
# Pointers
|
||||
x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
conv_state_ptr, # (num_cache_lines, dim, state_len)
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
query_start_loc_ptr, # (batch + 1)
|
||||
block_idx_last_scheduled_token, # (batch,)
|
||||
initial_state_idx, # (batch,)
|
||||
o_ptr, # same shape as x_ptr
|
||||
batch: tl.int32,
|
||||
dim: tl.constexpr,
|
||||
seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen
|
||||
state_len: tl.constexpr, # effective state_len computed in wrapper
|
||||
num_cache_lines: tl.constexpr,
|
||||
|
||||
# Strides
|
||||
stride_x_seq: tl.constexpr,
|
||||
stride_x_dim: tl.constexpr,
|
||||
stride_x_token: tl.constexpr,
|
||||
stride_w_dim: tl.constexpr,
|
||||
stride_w_width: tl.constexpr,
|
||||
stride_conv_state_seq: tl.constexpr,
|
||||
stride_conv_state_dim: tl.constexpr,
|
||||
stride_conv_state_tok: tl.constexpr,
|
||||
stride_state_indices: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
|
||||
# Meta
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr, # <= 6
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
|
||||
# tiling
|
||||
BLOCK_N: tl.constexpr, # channel tile (C_TILE)
|
||||
B_TILE: tl.constexpr, # batch tile
|
||||
T_CHUNK: tl.constexpr, # token chunk for state update
|
||||
# Pointers
|
||||
x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
conv_state_ptr, # (num_cache_lines, dim, state_len)
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
query_start_loc_ptr, # (batch + 1)
|
||||
block_idx_last_scheduled_token, # (batch,)
|
||||
initial_state_idx, # (batch,)
|
||||
o_ptr, # same shape as x_ptr
|
||||
batch: tl.int32,
|
||||
dim: tl.constexpr,
|
||||
seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen
|
||||
state_len: tl.constexpr, # effective state_len computed in wrapper
|
||||
num_cache_lines: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seq: tl.constexpr,
|
||||
stride_x_dim: tl.constexpr,
|
||||
stride_x_token: tl.constexpr,
|
||||
stride_w_dim: tl.constexpr,
|
||||
stride_w_width: tl.constexpr,
|
||||
stride_conv_state_seq: tl.constexpr,
|
||||
stride_conv_state_dim: tl.constexpr,
|
||||
stride_conv_state_tok: tl.constexpr,
|
||||
stride_state_indices: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
# Meta
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr, # <= 6
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
# tiling
|
||||
BLOCK_N: tl.constexpr, # channel tile (C_TILE)
|
||||
B_TILE: tl.constexpr, # batch tile
|
||||
T_CHUNK: tl.constexpr, # token chunk for state update
|
||||
):
|
||||
# program ids
|
||||
pid_b = tl.program_id(0) # batch-tile id
|
||||
@@ -197,37 +187,30 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
# preload weights once per program (shared by B_TILE sequences)
|
||||
w_base = w_ptr + idx_feats * stride_w_dim
|
||||
# define to avoid "undefined" in branches
|
||||
w_col0 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
w_col1 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
w_col2 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
w_col3 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
w_col4 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
w_col5 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
w_col0 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
w_col1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
w_col2 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
w_col3 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
w_col4 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
w_col5 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
if KERNEL_WIDTH >= 1:
|
||||
w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w,
|
||||
other=0.0).to(tl.float32)
|
||||
w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
|
||||
if KERNEL_WIDTH >= 2:
|
||||
w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w,
|
||||
other=0.0).to(tl.float32)
|
||||
w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w,
|
||||
other=0.0).to(tl.float32)
|
||||
w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w,
|
||||
other=0.0).to(tl.float32)
|
||||
w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
|
||||
if KERNEL_WIDTH >= 5:
|
||||
w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w,
|
||||
other=0.0).to(tl.float32)
|
||||
w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
|
||||
if KERNEL_WIDTH >= 6:
|
||||
w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w,
|
||||
other=0.0).to(tl.float32)
|
||||
w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
|
||||
|
||||
# bias vector once per program
|
||||
if HAS_BIAS:
|
||||
acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w,
|
||||
other=0.0).to(tl.float32)
|
||||
acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32)
|
||||
else:
|
||||
acc_bias = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
acc_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
|
||||
# token index vector for chunked copy
|
||||
tok_vec = tl.arange(0, T_CHUNK) # [T_CHUNK]
|
||||
@@ -241,36 +224,26 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
# APC mapping (optional)
|
||||
# -------------------------
|
||||
if IS_APC_ENABLED:
|
||||
conv_state_init = tl.load(initial_state_idx + b,
|
||||
mask=lane_active,
|
||||
other=0).to(tl.int32)
|
||||
current_last_index = tl.load(block_idx_last_scheduled_token + b,
|
||||
mask=lane_active,
|
||||
other=0).to(tl.int32)
|
||||
conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32)
|
||||
current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32)
|
||||
else:
|
||||
conv_state_init = tl.full((), 0, tl.int32)
|
||||
current_last_index = tl.full((), 0, tl.int32)
|
||||
|
||||
# input cache line
|
||||
conv_states_input_coord = tl.load(conv_state_indices_ptr +
|
||||
b * stride_state_indices +
|
||||
conv_state_init,
|
||||
mask=lane_active,
|
||||
other=0).to(tl.int64)
|
||||
conv_states_input_coord = tl.load(
|
||||
conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0
|
||||
).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT:
|
||||
lane_active = lane_active & (conv_states_input_coord
|
||||
!= pad_slot_id)
|
||||
lane_active = lane_active & (conv_states_input_coord != pad_slot_id)
|
||||
|
||||
# -------------------------
|
||||
# varlen (optional): revise seqlen_run and state_len_run like original kernel does
|
||||
# -------------------------
|
||||
if IS_VARLEN:
|
||||
qs = tl.load(query_start_loc_ptr + b, mask=lane_active,
|
||||
other=0).to(tl.int64)
|
||||
qe = tl.load(query_start_loc_ptr + (b + 1),
|
||||
mask=lane_active,
|
||||
other=0).to(tl.int64)
|
||||
qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64)
|
||||
qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64)
|
||||
seqlen_run = (qe - qs).to(tl.int32)
|
||||
# revise effective state_len for shorter sequences (same formula as original)
|
||||
state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32)
|
||||
@@ -289,9 +262,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
# spec decoding offset (optional)
|
||||
# -------------------------
|
||||
if IS_SPEC_DECODING:
|
||||
conv_state_token_offset = (
|
||||
tl.load(num_accepted_tokens_ptr + b, mask=lane_active,
|
||||
other=1).to(tl.int64) - 1)
|
||||
conv_state_token_offset = tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1
|
||||
shift = tl.full((), 1, tl.int32) # sliding by 1 in spec mode
|
||||
else:
|
||||
conv_state_token_offset = tl.full((), 0, tl.int64)
|
||||
@@ -300,37 +271,37 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
# -------------------------
|
||||
# STEP 1: read initial history cols BEFORE state update (out==x safe)
|
||||
# -------------------------
|
||||
conv_states_base = (conv_state_ptr +
|
||||
conv_states_input_coord * stride_conv_state_seq +
|
||||
idx_feats * stride_conv_state_dim)
|
||||
conv_states_base = (
|
||||
conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim
|
||||
)
|
||||
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
|
||||
|
||||
# define history vectors as zeros then load conditionally
|
||||
col0 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
|
||||
col1 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
|
||||
col2 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
|
||||
col3 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
|
||||
col4 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
|
||||
col0 = tl.zeros((BLOCK_N,), dtype=tl.float16)
|
||||
col1 = tl.zeros((BLOCK_N,), dtype=tl.float16)
|
||||
col2 = tl.zeros((BLOCK_N,), dtype=tl.float16)
|
||||
col3 = tl.zeros((BLOCK_N,), dtype=tl.float16)
|
||||
col4 = tl.zeros((BLOCK_N,), dtype=tl.float16)
|
||||
if KERNEL_WIDTH >= 2:
|
||||
col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
|
||||
tl.float16
|
||||
)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
|
||||
tl.float16
|
||||
)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
|
||||
tl.float16
|
||||
)
|
||||
if KERNEL_WIDTH >= 5:
|
||||
col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
|
||||
tl.float16
|
||||
)
|
||||
if KERNEL_WIDTH >= 6:
|
||||
col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
|
||||
tl.float16
|
||||
)
|
||||
|
||||
# -------------------------
|
||||
# STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block)
|
||||
@@ -340,29 +311,25 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
# dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run]
|
||||
# -------------------------
|
||||
# output cache line
|
||||
conv_states_offset = tl.load(conv_state_indices_ptr +
|
||||
b * stride_state_indices +
|
||||
current_last_index,
|
||||
mask=lane_active,
|
||||
other=0).to(tl.int64)
|
||||
conv_states_offset = tl.load(
|
||||
conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0
|
||||
).to(tl.int64)
|
||||
|
||||
use_shift = (seqlen_run < state_len_run)
|
||||
use_tail = (seqlen_run >= state_len_run)
|
||||
use_shift = seqlen_run < state_len_run
|
||||
use_tail = seqlen_run >= state_len_run
|
||||
|
||||
zero_i32 = tl.full((), 0, tl.int32)
|
||||
keep_shift = tl.where(use_shift, (state_len_run - seqlen_run),
|
||||
zero_i32).to(tl.int32)
|
||||
tail_start = tl.where(use_tail, (seqlen_run - state_len_run),
|
||||
zero_i32).to(tl.int32)
|
||||
keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32)
|
||||
tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32)
|
||||
|
||||
# base pointers
|
||||
state_src_base = (conv_state_ptr +
|
||||
conv_states_input_coord * stride_conv_state_seq +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
idx_feats * stride_conv_state_dim)
|
||||
state_dst_base = (conv_state_ptr +
|
||||
conv_states_offset * stride_conv_state_seq +
|
||||
idx_feats * stride_conv_state_dim)
|
||||
state_src_base = (
|
||||
conv_state_ptr
|
||||
+ conv_states_input_coord * stride_conv_state_seq
|
||||
+ conv_state_token_offset * stride_conv_state_tok
|
||||
+ idx_feats * stride_conv_state_dim
|
||||
)
|
||||
state_dst_base = conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim
|
||||
|
||||
x_base = x_ptr + x_offset + idx_feats * stride_x_dim
|
||||
|
||||
@@ -370,16 +337,16 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
|
||||
dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK]
|
||||
src_tok = (dst_tok + shift).to(tl.int32) # [T_CHUNK]
|
||||
m_tok = use_shift & (dst_tok < keep_shift) & (
|
||||
src_tok < state_len_run) & (dst_tok < state_len_run)
|
||||
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
|
||||
conv_states_input_coord
|
||||
< num_cache_lines) & (conv_states_offset < num_cache_lines)
|
||||
m_tok = use_shift & (dst_tok < keep_shift) & (src_tok < state_len_run) & (dst_tok < state_len_run)
|
||||
m = (
|
||||
(lane_active & m_tok)[:, None]
|
||||
& mask_w[None, :]
|
||||
& (conv_states_input_coord < num_cache_lines)
|
||||
& (conv_states_offset < num_cache_lines)
|
||||
)
|
||||
|
||||
src_ptrs = state_src_base[
|
||||
None, :] + src_tok[:, None] * stride_conv_state_tok
|
||||
dst_ptrs = state_dst_base[
|
||||
None, :] + dst_tok[:, None] * stride_conv_state_tok
|
||||
src_ptrs = state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok
|
||||
dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
|
||||
vals = tl.load(src_ptrs, mask=m, other=0.0)
|
||||
tl.store(dst_ptrs, vals, mask=m)
|
||||
|
||||
@@ -387,14 +354,11 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
for t0 in tl.static_range(0, seqlen, T_CHUNK):
|
||||
x_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK]
|
||||
dst_tok = (keep_shift + x_tok).to(tl.int32) # [T_CHUNK]
|
||||
m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok
|
||||
< state_len_run)
|
||||
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
|
||||
conv_states_offset < num_cache_lines)
|
||||
m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run)
|
||||
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
|
||||
|
||||
x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
|
||||
dst_ptrs = state_dst_base[
|
||||
None, :] + dst_tok[:, None] * stride_conv_state_tok
|
||||
dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
|
||||
x_vals = tl.load(x_ptrs, mask=m, other=0.0)
|
||||
tl.store(dst_ptrs, x_vals, mask=m)
|
||||
|
||||
@@ -403,12 +367,10 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK]
|
||||
x_tok = (tail_start + dst_tok).to(tl.int32) # [T_CHUNK]
|
||||
m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run)
|
||||
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
|
||||
conv_states_offset < num_cache_lines)
|
||||
m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
|
||||
|
||||
x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
|
||||
dst_ptrs = state_dst_base[
|
||||
None, :] + dst_tok[:, None] * stride_conv_state_tok
|
||||
dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
|
||||
x_vals = tl.load(x_ptrs, mask=m, other=0.0)
|
||||
tl.store(dst_ptrs, x_vals, mask=m)
|
||||
|
||||
@@ -433,17 +395,13 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
if KERNEL_WIDTH == 1:
|
||||
# only x[t] * w0
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
|
||||
matrix_x = tl.load(x_ptrs_1d,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
|
||||
matrix_w = w_col0
|
||||
elif KERNEL_WIDTH == 2:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
|
||||
matrix_x = tl.load(x_ptrs_1d,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
|
||||
elif KERNEL_WIDTH == 3:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
@@ -451,9 +409,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
elif j == 2:
|
||||
matrix_w = w_col2
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
|
||||
matrix_x = tl.load(x_ptrs_1d,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
|
||||
elif KERNEL_WIDTH == 4:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
@@ -464,9 +420,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
elif j == 3:
|
||||
matrix_w = w_col3
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
|
||||
matrix_x = tl.load(x_ptrs_1d,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
|
||||
elif KERNEL_WIDTH == 5:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
@@ -480,9 +434,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
elif j == 4:
|
||||
matrix_w = w_col4
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
|
||||
matrix_x = tl.load(x_ptrs_1d,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
|
||||
elif KERNEL_WIDTH == 6:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
@@ -499,9 +451,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
elif j == 5:
|
||||
matrix_w = w_col5
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token
|
||||
matrix_x = tl.load(x_ptrs_1d,
|
||||
mask=lane_active & mask_w,
|
||||
other=0.0).to(tl.float16)
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
|
||||
|
||||
acc += matrix_x.to(tl.float32) * matrix_w # [BLOCK_N]
|
||||
|
||||
@@ -606,7 +556,7 @@ def causal_conv1d_update_npu(
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
if query_start_loc is None:
|
||||
batch, seqlen, dim = x.shape
|
||||
batch, seqlen, dim = x.shape
|
||||
else:
|
||||
assert conv_state_indices is not None
|
||||
batch = conv_state_indices.size(0)
|
||||
@@ -614,14 +564,14 @@ def causal_conv1d_update_npu(
|
||||
seqlen = max_query_len
|
||||
|
||||
width, _ = weight.shape
|
||||
num_cache_lines, state_len_total,_ = conv_state.size()
|
||||
num_cache_lines, state_len_total, _ = conv_state.size()
|
||||
|
||||
# overwrite-on-x strategy same as original
|
||||
out = x
|
||||
|
||||
stride_w_width, stride_w_dim = weight.stride()
|
||||
if query_start_loc is None:
|
||||
stride_x_seq, stride_x_token,stride_x_dim = x.stride()
|
||||
stride_x_seq, stride_x_token, stride_x_dim = x.stride()
|
||||
stride_o_seq, stride_o_token, stride_o_dim = out.stride()
|
||||
else:
|
||||
stride_x_token, stride_x_dim = x.stride()
|
||||
@@ -629,10 +579,8 @@ def causal_conv1d_update_npu(
|
||||
stride_o_token, stride_o_dim = out.stride()
|
||||
stride_o_seq = 0
|
||||
|
||||
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
|
||||
)
|
||||
stride_state_indices = conv_state_indices.stride(
|
||||
0) if conv_state_indices is not None else 0
|
||||
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride()
|
||||
stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0
|
||||
|
||||
# effective state_len exactly as original
|
||||
if num_accepted_tokens is not None:
|
||||
@@ -642,10 +590,10 @@ def causal_conv1d_update_npu(
|
||||
np2_statelen = triton.next_power_of_2(eff_state_len)
|
||||
|
||||
# -------- tiling heuristic--------
|
||||
#keep program count around ~[80..160]
|
||||
# keep program count around ~[80..160]
|
||||
# vector core 40
|
||||
# TODO: use driver to get the vector core num
|
||||
CORE_HINT = 40
|
||||
CORE_HINT = 40
|
||||
# channel tile: 512 when dim large (reduce tasks), else 256
|
||||
block_n = 512 if dim >= 512 else 256
|
||||
g = triton.cdiv(dim, block_n)
|
||||
@@ -669,6 +617,7 @@ def causal_conv1d_update_npu(
|
||||
triton.cdiv(batch, META["B_TILE"]),
|
||||
triton.cdiv(dim, META["BLOCK_N"]),
|
||||
)
|
||||
|
||||
_causal_conv1d_update_kernel_npu_tiled[grid](
|
||||
x,
|
||||
weight,
|
||||
|
||||
@@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton(
|
||||
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
|
||||
|
||||
for pos in tl.range(0, BLOCK_SIZE):
|
||||
draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
|
||||
target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
|
||||
draft_token_id1 = tl.get_element(draft_token_id, (pos,))
|
||||
target_argmax1 = tl.get_element(target_argmax_id, (pos,))
|
||||
position = block_idx * BLOCK_SIZE + pos
|
||||
if draft_token_id1 == target_argmax1:
|
||||
bonus_renew_1(
|
||||
@@ -79,9 +79,7 @@ def bonus_renew(
|
||||
num_tokens1,
|
||||
):
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
|
||||
tl.store(
|
||||
output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
|
||||
bonus_token_id)
|
||||
tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
@@ -106,17 +104,15 @@ def rejection_greedy_sample_triton(
|
||||
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
|
||||
is_greedy_mask = mask & (is_greedy != 0)
|
||||
|
||||
start_idx = tl.where(
|
||||
offset == 0, 0,
|
||||
tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
|
||||
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
for pos in tl.range(0, BLOCK_SIZE):
|
||||
num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
|
||||
num_tokens1 = tl.get_element(num_draft_tokens, (pos,))
|
||||
rejected = False
|
||||
start_idx1 = tl.get_element(start_idx, (pos, ))
|
||||
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
|
||||
start_idx1 = tl.get_element(start_idx, (pos,))
|
||||
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos,))
|
||||
position = block_idx * BLOCK_SIZE + pos
|
||||
for i in range(num_tokens1):
|
||||
if not rejected:
|
||||
@@ -142,50 +138,44 @@ def rejection_greedy_sample_triton(
|
||||
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_random_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
vec_len,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
vec_len,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
block_idx = tl.program_id(0)
|
||||
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < vec_len
|
||||
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
|
||||
not_greedy_mask = is_greedy == 0
|
||||
start_idxs = tl.where(
|
||||
offsets == 0, 0,
|
||||
tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
|
||||
start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
|
||||
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
|
||||
n_num_draft_tokens = end_idxs - start_idxs
|
||||
for req_i in range(BLOCK_SIZE):
|
||||
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
|
||||
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
|
||||
if not_greedy:
|
||||
rejected = False
|
||||
start_idx = tl.get_element(start_idxs, (req_i, ))
|
||||
start_idx = tl.get_element(start_idxs, (req_i,))
|
||||
req_idx = block_idx * BLOCK_SIZE + req_i
|
||||
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
|
||||
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx +
|
||||
pos)
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
@@ -195,17 +185,13 @@ def rejection_random_sample_kernel(
|
||||
else:
|
||||
# Reject. Use recovered token.
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr +
|
||||
start_idx + pos)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
pos, token_id)
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens,
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
|
||||
bonus_token_id,
|
||||
)
|
||||
|
||||
@@ -225,8 +211,7 @@ def expand_kernel(
|
||||
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
len_mask = offset < vec_len
|
||||
|
||||
start_idx = tl.where(offset == 0, 0,
|
||||
tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
|
||||
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
|
||||
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
@@ -234,13 +219,11 @@ def expand_kernel(
|
||||
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||
|
||||
for i in tl.range(0, BLOCK_SIZE):
|
||||
num_tokens1 = tl.get_element(num_tokens, (i, ))
|
||||
start_idx1 = tl.get_element(start_idx, (i, ))
|
||||
src_val1 = tl.get_element(src_val, (i, ))
|
||||
num_tokens1 = tl.get_element(num_tokens, (i,))
|
||||
start_idx1 = tl.get_element(start_idx, (i,))
|
||||
src_val1 = tl.get_element(src_val, (i,))
|
||||
offset1 = tl.arange(0, MAX_NUM_TOKENS)
|
||||
tl.store(output_ptr + start_idx1 + offset1,
|
||||
src_val1,
|
||||
mask=offset1 < num_tokens1)
|
||||
tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -257,8 +240,7 @@ def sample_recovered_tokens_kernel(
|
||||
SUB_BLOCK: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
|
||||
req_idx - 1)
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
@@ -272,27 +254,25 @@ def sample_recovered_tokens_kernel(
|
||||
global_max_p = -1.0
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
# Temporarily zero out the probability of the draft token.
|
||||
# This is essentially the same as target_prob - draft_prob, except that
|
||||
# n-gram does not have draft_prob. We regard it as 1.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
0)
|
||||
tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0)
|
||||
for loop_i in range(loop):
|
||||
vocab_start = loop_i * SUB_BLOCK
|
||||
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
|
||||
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"))
|
||||
prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
q = tl.load(
|
||||
q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")
|
||||
)
|
||||
new_p = prob / q
|
||||
recovered_id = tl.argmax(new_p, axis=-1)
|
||||
max_p = tl.get_element(new_p, (recovered_id, ))
|
||||
max_p = tl.get_element(new_p, (recovered_id,))
|
||||
if max_p > global_max_p:
|
||||
global_max_p = max_p
|
||||
global_recovered_id = vocab_start + recovered_id
|
||||
@@ -300,25 +280,24 @@ def sample_recovered_tokens_kernel(
|
||||
for loop_i in range(loop):
|
||||
vocab_start = loop_i * SUB_BLOCK
|
||||
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
|
||||
draft_prob = tl.load(draft_probs_ptr +
|
||||
(start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=0
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` will select the maximum value.
|
||||
|
||||
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"))
|
||||
q = tl.load(
|
||||
q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")
|
||||
)
|
||||
new_p = prob / q
|
||||
recovered_id = tl.argmax(new_p, axis=-1)
|
||||
max_p = tl.get_element(new_p, (recovered_id, ))
|
||||
max_p = tl.get_element(new_p, (recovered_id,))
|
||||
if max_p > global_max_p:
|
||||
global_max_p = max_p
|
||||
global_recovered_id = vocab_start + recovered_id
|
||||
@@ -327,21 +306,25 @@ def sample_recovered_tokens_kernel(
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
# Restore the original probability.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
orig_prob)
|
||||
tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob)
|
||||
|
||||
|
||||
def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens,
|
||||
cu_num_draft_tokens, draft_token_ids,
|
||||
target_argmax, bonus_token_ids,
|
||||
is_greedy, max_spec_len, grid,
|
||||
block_size):
|
||||
def rejection_greedy_sample_with_triton(
|
||||
output_token_ids,
|
||||
num_draft_tokens,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
grid,
|
||||
block_size,
|
||||
):
|
||||
vec_len = output_token_ids.shape[0]
|
||||
|
||||
if min(num_draft_tokens) == 1 and max(
|
||||
num_draft_tokens) == 1 and is_greedy is None:
|
||||
rejection_greedy_sample_spec_len_1_triton[(grid, )](
|
||||
if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and is_greedy is None:
|
||||
rejection_greedy_sample_spec_len_1_triton[(grid,)](
|
||||
output_token_ids,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
@@ -350,7 +333,7 @@ def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens,
|
||||
BLOCK_SIZE=block_size,
|
||||
)
|
||||
else:
|
||||
rejection_greedy_sample_triton[(grid, )](
|
||||
rejection_greedy_sample_triton[(grid,)](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
@@ -363,12 +346,11 @@ def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens,
|
||||
)
|
||||
|
||||
|
||||
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
|
||||
replace_to, max_num_tokens):
|
||||
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens):
|
||||
vec_len = batch_size
|
||||
grid, block_size = cal_grid_and_block_size(batch_size)
|
||||
|
||||
expand_kernel[(grid, )](
|
||||
expand_kernel[(grid,)](
|
||||
expanded_x,
|
||||
x,
|
||||
cu_num_tokens,
|
||||
@@ -382,56 +364,50 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
|
||||
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_random_sample_block_verify_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
vec_len,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
vec_len,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
block_idx = tl.program_id(0)
|
||||
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < vec_len
|
||||
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
|
||||
not_greedy_mask = is_greedy == 0
|
||||
start_idxs = tl.where(
|
||||
offsets == 0, 0,
|
||||
tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
|
||||
start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
|
||||
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
|
||||
n_num_draft_tokens = end_idxs - start_idxs
|
||||
for req_i in range(BLOCK_SIZE):
|
||||
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
|
||||
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
|
||||
if not_greedy:
|
||||
|
||||
rejected = False
|
||||
pi = 1.0
|
||||
uniform_prob = 1.0
|
||||
last_accepted_token_pos = -1
|
||||
start_idx = tl.get_element(start_idxs, (req_i, ))
|
||||
start_idx = tl.get_element(start_idxs, (req_i,))
|
||||
req_idx = block_idx * BLOCK_SIZE + req_i
|
||||
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
|
||||
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
|
||||
|
||||
for pos in range(num_draft_tokens):
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
uniform_prob = uniform_prob * tmp_uniform_prob
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
|
||||
pi = min(pi * target_prob / draft_prob, 1.0)
|
||||
if draft_prob > 0 and pi >= uniform_prob:
|
||||
@@ -443,19 +419,14 @@ def rejection_random_sample_block_verify_kernel(
|
||||
if last_accepted_token_pos > -1:
|
||||
for pos in range(last_accepted_token_pos + 1):
|
||||
token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
pos, token_id)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
|
||||
|
||||
if rejected:
|
||||
recovered_token_id = tl.load(recovered_token_ids_ptr +
|
||||
start_idx +
|
||||
last_accepted_token_pos + 1)
|
||||
recovered_token_id = tl.load(recovered_token_ids_ptr + start_idx + last_accepted_token_pos + 1)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
last_accepted_token_pos + 1, recovered_token_id)
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + last_accepted_token_pos + 1,
|
||||
recovered_token_id,
|
||||
)
|
||||
else:
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens, bonus_token_id)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, bonus_token_id)
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
import torch
|
||||
from typing import Tuple
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
|
||||
@@ -48,16 +49,16 @@ def _triton_rope(
|
||||
This triton kernel applies rotary embedding on q and k.
|
||||
It supports rope_dim != head_dim scenario.
|
||||
It supports both neox style and non-neox style rope computation.
|
||||
|
||||
|
||||
Input tensor layout assumptions:
|
||||
|
||||
|
||||
q size: (num_tokens, num_q_heads, head_dim)
|
||||
q stride: (num_q_heads * head_dim, head_dim, 1)
|
||||
k size: (num_tokens, num_kv_heads, head_dim)
|
||||
k stride: (num_kv_heads * head_dim, head_dim, 1)
|
||||
cos/sin size: (num_tokens, rope_dim/2)
|
||||
cos/sin stride: (rope_dim/2, 1)
|
||||
|
||||
|
||||
Different compute pattern of IS_NEOX_STYLE:
|
||||
|
||||
if IS_NEOX_STYLE:
|
||||
@@ -88,10 +89,8 @@ def _triton_rope(
|
||||
|
||||
cos_offsets = tl.arange(0, pad_rope_dim // 2)
|
||||
cos_mask = cos_offsets < (rope_dim // 2)
|
||||
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask,
|
||||
other=0).to(tl.float32)
|
||||
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask,
|
||||
other=0).to(tl.float32)
|
||||
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
|
||||
# ####################################################################
|
||||
# Load the left and right half of q and k for the current
|
||||
@@ -99,28 +98,20 @@ def _triton_rope(
|
||||
# ####################################################################
|
||||
# left half of the head
|
||||
if IS_NEOX_STYLE:
|
||||
first_half_q_offsets = tl.arange(
|
||||
0, pad_n_qh)[:, None] * hd + tl.arange(
|
||||
0, pad_rope_dim // 2)[None, :]
|
||||
first_half_k_offsets = tl.arange(
|
||||
0, pad_n_kh)[:, None] * hd + tl.arange(
|
||||
0, pad_rope_dim // 2)[None, :]
|
||||
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
|
||||
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
|
||||
else:
|
||||
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (
|
||||
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
|
||||
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (
|
||||
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
|
||||
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
|
||||
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
|
||||
|
||||
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
|
||||
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
|
||||
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
|
||||
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
|
||||
q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets,
|
||||
mask=first_q_mask,
|
||||
other=0).to(sin_row.dtype)
|
||||
k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets,
|
||||
mask=first_k_mask,
|
||||
other=0).to(sin_row.dtype)
|
||||
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
|
||||
tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
|
||||
)
|
||||
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
|
||||
tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
|
||||
)
|
||||
q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
|
||||
k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
|
||||
|
||||
# right half of the head
|
||||
if IS_NEOX_STYLE:
|
||||
@@ -131,41 +122,29 @@ def _triton_rope(
|
||||
second_half_k_offsets = first_half_k_offsets + 1
|
||||
second_q_mask = first_q_mask
|
||||
second_k_mask = first_k_mask
|
||||
q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets,
|
||||
mask=second_q_mask,
|
||||
other=0).to(sin_row.dtype)
|
||||
k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets,
|
||||
mask=second_k_mask,
|
||||
other=0).to(sin_row.dtype)
|
||||
q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
|
||||
k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
|
||||
|
||||
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
|
||||
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
|
||||
tl.store(q_start_ptr + first_half_q_offsets,
|
||||
new_q_tile_1,
|
||||
mask=first_q_mask)
|
||||
tl.store(q_start_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
|
||||
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
|
||||
tl.store(q_start_ptr + second_half_q_offsets,
|
||||
new_q_tile_2,
|
||||
mask=second_q_mask)
|
||||
tl.store(q_start_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
|
||||
|
||||
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
|
||||
tl.store(k_start_ptr + first_half_k_offsets,
|
||||
new_k_tile_1,
|
||||
mask=first_k_mask)
|
||||
tl.store(k_start_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
|
||||
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
|
||||
tl.store(k_start_ptr + second_half_k_offsets,
|
||||
new_k_tile_2,
|
||||
mask=second_k_mask)
|
||||
tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
||||
|
||||
|
||||
def rope_forward_triton(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
rope_dim: int = -1,
|
||||
is_neox_style: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
rope_dim: int = -1,
|
||||
is_neox_style: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if not q.is_contiguous():
|
||||
q = q.contiguous()
|
||||
if not k.is_contiguous():
|
||||
@@ -187,7 +166,7 @@ def rope_forward_triton(
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
n_row = min(num_tokens, num_vectorcore)
|
||||
|
||||
_triton_rope[(n_row, )](
|
||||
_triton_rope[(n_row,)](
|
||||
q,
|
||||
q.stride(0),
|
||||
k,
|
||||
|
||||
@@ -49,20 +49,15 @@ def prepare_inputs_padded_kernel(
|
||||
other=0,
|
||||
)
|
||||
|
||||
num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev,
|
||||
cu_draft_curr)
|
||||
num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev, cu_draft_curr)
|
||||
|
||||
valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets,
|
||||
mask=mask)
|
||||
valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets, mask=mask)
|
||||
num_rejected = num_draft_tokens + 1 - valid_count
|
||||
num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)
|
||||
|
||||
# query_start_loc[req_idx + 1] is the start position of the next request,
|
||||
# which is one past the last token of this request.
|
||||
q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1,
|
||||
mask=mask) - 1
|
||||
q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1, mask=mask) - 1
|
||||
|
||||
index_to_sample = q_last_tok_idx - num_rejected
|
||||
tl.store(token_indices_to_sample_ptr + offsets,
|
||||
index_to_sample,
|
||||
mask=mask)
|
||||
tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
@@ -10,9 +10,9 @@ _NUM_VECTORCORE = -1
|
||||
def init_device_properties_triton():
|
||||
global _NUM_AICORE, _NUM_VECTORCORE
|
||||
if _NUM_AICORE == -1 and HAS_TRITON:
|
||||
device_properties: Dict[str, Any] = (
|
||||
triton.runtime.driver.active.utils.get_device_properties(
|
||||
torch.npu.current_device()))
|
||||
device_properties: dict[str, Any] = triton.runtime.driver.active.utils.get_device_properties(
|
||||
torch.npu.current_device()
|
||||
)
|
||||
_NUM_AICORE = device_properties.get("num_aicore", -1)
|
||||
_NUM_VECTORCORE = device_properties.get("num_vectorcore", -1)
|
||||
assert _NUM_AICORE > 0 and _NUM_VECTORCORE > 0, "Failed to detect device properties."
|
||||
|
||||
Reference in New Issue
Block a user