First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,158 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
if hasattr(tl, "libdevice"):
tl_math = tl.libdevice
else:
tl_math = tl.math
@triton.jit
def _rms_norm_kernel(
x_ptr,
h1_ptr,
w_ptr,
eps,
stride,
N_COLS,
BLOCK_SIZE: tl.constexpr,
INCLUDE_WEIGHT: tl.constexpr,
):
row = tl.program_id(0)
x_ptr += row * stride
h1_ptr += row * stride
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
a = tl.load(
x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
).to(tl.float32)
_mean += a * a
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
mask = cols < N_COLS
a = tl.load(
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
).to(tl.float32)
if INCLUDE_WEIGHT:
w = tl.load(w_ptr + cols, mask=mask)
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
else:
tl.store(h1_ptr + cols, a * rstd, mask=mask)
@triton.jit
def _rms_norm_add_kernel(
x_ptr,
y_ptr,
h1_ptr,
w_ptr,
eps,
stride,
N_COLS,
BLOCK_SIZE: tl.constexpr,
INCLUDE_WEIGHT: tl.constexpr,
):
row = tl.program_id(0)
x_ptr += row * stride
y_ptr += row * stride
h1_ptr += row * stride
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
mask = cols < N_COLS
ax = tl.load(
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last"
).to(tl.float32)
ay = tl.load(
y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
).to(tl.float32)
a = ax + ay
tl.store(x_ptr + cols, a, mask=mask)
_mean += a * a
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
mask = cols < N_COLS
a = tl.load(
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
).to(tl.float32)
if INCLUDE_WEIGHT:
w = tl.load(w_ptr + cols, mask=mask)
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
else:
tl.store(h1_ptr + cols, a * rstd, mask=mask)
def _rms_norm_forward(x, attn_norm_weights, eps):
if not x.is_contiguous():
raise ValueError("data must be contiguous")
if attn_norm_weights is not None:
if not attn_norm_weights.is_contiguous():
raise ValueError("weights must be contiguous")
out = torch.empty_like(x)
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_rms_norm_kernel[(M,)](
x_arg,
out,
attn_norm_weights,
eps,
x_arg.stride(0),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
INCLUDE_WEIGHT=attn_norm_weights is not None,
)
return out
def _rms_norm_add_forward(x, y, attn_norm_weights, eps):
# x, y contiguous of same shape [..., n]
# output of same shape, normed over the last dim.
if not x.is_contiguous():
raise ValueError("x must be contiguous")
if not y.is_contiguous():
raise ValueError("y must be contiguous")
if attn_norm_weights is not None:
if not attn_norm_weights.is_contiguous():
raise ValueError("weights must be contiguous")
out = torch.empty_like(x)
x_arg = x.reshape(-1, x.shape[-1])
y_arg = y.reshape(-1, x.shape[-1])
M, N = x_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_rms_norm_add_kernel[(M,)](
x_arg,
y_arg,
out,
attn_norm_weights,
eps,
x_arg.stride(0),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
INCLUDE_WEIGHT=attn_norm_weights is not None,
)
return out

View File

@@ -0,0 +1,161 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import triton # type: ignore
import triton.language as tl # type: ignore
if hasattr(tl, "libdevice"):
tl_math = tl.libdevice
else:
tl_math = tl.math
@triton.jit
def _rope_padded_kernel(
xq,
xk,
xv,
out_q,
cache_k,
cache_v,
seqstartq,
seqstartk,
seqlenk,
theta,
k_start: tl.constexpr,
v_start: tl.constexpr,
dim: tl.constexpr, # dimension of each head
stride_xqM,
stride_xqH,
stride_xkM,
stride_xkH,
stride_xvM,
stride_xvH,
stride_cachekM,
stride_cachekH,
stride_cachevM,
stride_cachevH,
stride_seqstartq,
stride_seqstartk,
stride_seqlenk,
stride_outqM,
stride_outqH,
internal_dtype: tl.constexpr,
# If True, seqstartq and seqstartk are not used but rather we
# assume that every batch element has the same number of
# queries (i.e. num_queries := tl.num_programs(1) )
# and the same cache space cache_padding_length.
# Always False when called below.
const_batch_strides: tl.constexpr,
# If const_batch_strides==True, the common cache length for each batch element.
# (Only the first seqlenk[i] elements are actually in use, and only the last
# num_queries of those are actually written to.)
cache_padding_length,
# offset added to all values in seqlenk before using them.
# Always 0 when called below.
seqlenk_shift: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
adjacents: tl.constexpr,
):
"""
Each letter in this diagram is a whole row of length dim.
INPUT xq xk xv
head_dim ─►
batch qqqqqq kk vv
│ qqqqqq kk vv
▼ qqqqqq kk vv
head_idx: (goes across all heads of all 3 inputs)
▲ ▲ ▲ ▲ ▲ ▲
│ │ │ │ │ │
│ │
0 k_start │v_start │n_total_heads
│ │
│ │
k_start v_start
Output is to out_q (same shape as xq), an xk-shaped part
of cache_k and an xv-shaped part of cache_v
"""
batch_elt = tl.program_id(0)
query_pos_in_batch_elt = tl.program_id(1)
head_idx = tl.program_id(2)
if internal_dtype == "f32":
theta = theta.to(tl.float32)
elif internal_dtype == "f64":
theta = theta.to(tl.float64)
if const_batch_strides:
query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt
end_query_pos = tl.num_programs(1) * (batch_elt + 1)
else:
query_pos = query_pos_in_batch_elt + tl.load(
seqstartq + batch_elt * stride_seqstartq
)
end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq)
if query_pos >= end_query_pos:
return
is_q = head_idx < k_start
is_v = head_idx >= v_start
xq += query_pos * stride_xqM + head_idx * stride_xqH
out_q += query_pos * stride_outqM + head_idx * stride_outqH
if const_batch_strides:
cache_start = cache_padding_length * batch_elt
else:
cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk)
end_of_batch_elt_cache = (
cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift
)
cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos)
seq_pos = cache_pos - cache_start
cache_k += (head_idx - k_start) * stride_cachekH + cache_pos * stride_cachekM
xk += query_pos * stride_xkM + (head_idx - k_start) * stride_xkH
in_qk = tl.where(is_q, xq, xk)
out_qk = tl.where(is_q, out_q, cache_k)
cache_v += (head_idx - v_start) * stride_cachevH + cache_pos * stride_cachevM
xv += query_pos * stride_xvM + (head_idx - v_start) * stride_xvH
out = tl.where(is_v, cache_v, out_qk)
x_in = tl.where(is_v, xv, in_qk)
for offset in range(0, dim // 2, BLOCK_SIZE // 2):
c = tl.arange(0, BLOCK_SIZE // 2)
powers = (offset + c) * 2.0
if adjacents:
cols_re = (offset + c) * 2
cols_im = cols_re + 1
else:
cols_re = offset + c
cols_im = cols_re + dim // 2
mask = cols_im < dim
re_x = tl.load(x_in + cols_re, mask=mask)
im_x = tl.load(x_in + cols_im, mask=mask)
# freqs = seq_pos / (theta ** (powers / dim))
freqs = seq_pos * tl_math.pow(theta, powers / (-dim))
sines = tl.sin(freqs)
cosines = tl.cos(freqs)
re_out = re_x * cosines - im_x * sines
im_out = im_x * cosines + re_x * sines
re_out_ = tl.where(is_v, re_x, re_out)
im_out_ = tl.where(is_v, im_x, im_out)
if internal_dtype == "f64":
if re_x.dtype == tl.bfloat16:
# triton 2.0.0 crashes if you try to convert
# float64 directly to bfloat16, so make an intermediate step.
re_out_ = re_out_.to(tl.float32)
im_out_ = im_out_.to(tl.float32)
tl.store(out + cols_re, re_out_, mask=mask)
tl.store(out + cols_im, im_out_, mask=mask)