### What this PR does / why we need it?
Add Qwen3-next support.
### Does this PR introduce _any_ user-facing change?
Yes, users can now run Qwen3-Next models.
Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916. The
tutorial will be available
[here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html).
### How was this patch tested?
Doc CI passed
Related: https://github.com/vllm-project/vllm-ascend/issues/2884
Co-Authored-By: Angazenn <supperccell@163.com>
Co-Authored-By: zzzzwwjj <1183291235@qq.com>
Co-Authored-By: MengqingCao <cmq0113@163.com>
Co-Authored-By: linfeng-yuan <1102311262@qq.com>
Co-Authored-By: hust17yixuan <303660421@qq.com>
Co-Authored-By: SunnyLee219 <3294305115@qq.com>
Co-Authored-By: maoxx241 <maoxx241@umn.edu>
- vLLM version: v0.10.2
- vLLM main:
b834b4cbf1
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Your Name <you@example.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: hust17yixuan <303660421@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: hust17yixuan <303660421@qq.com>
382 lines
12 KiB
Python
382 lines
12 KiB
Python
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
# mypy: ignore-errors
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import triton
|
|
import triton.language as tl
|
|
from einops import rearrange
|
|
|
|
|
|
def rms_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    upcast=True,
):
    """Pure-PyTorch reference for (optionally grouped, optionally gated) RMSNorm.

    Args:
        x: Input tensor; normalization runs over the last dimension.
        weight: Per-feature scale, broadcast against the last dimension of x.
        bias: Optional per-feature shift; skipped when None.
        z: Optional gate tensor. When given, ``silu(z)`` multiplies the
            output (``norm_before_gate=True``) or the input
            (``norm_before_gate=False``).
        eps: Value added to the mean square before the square root.
        group_size: When set, the last dimension is split into chunks of
            this many elements and each chunk is normalized on its own.
            None normalizes over the whole last dimension.
        norm_before_gate: Chooses ``norm(x) * silu(z)`` (True) or
            ``norm(x * silu(z))`` (False).
        upcast: When True, x and z are converted to float32 first.

    Returns:
        The normalized tensor, cast back to the dtype x had on entry.
    """
    dtype = x.dtype
    # weight/bias are always promoted, independent of `upcast`.
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z

    # Gate applied to the input: norm(x * silu(z)).
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)

    if group_size is None:
        rstd = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
        out = x * rstd * weight
    else:
        # (..., g * d) -> (..., g, d) so the mean is taken per group.
        x_group = x.reshape(*x.shape[:-1], -1, group_size)
        rstd = 1 / torch.sqrt(
            x_group.square().mean(dim=-1, keepdim=True) + eps)
        out = (x_group * rstd).reshape(x.shape) * weight

    if bias is not None:
        out = out + bias

    # Gate applied to the output: norm(x) * silu(z).
    if z is not None and norm_before_gate:
        out = out * F.silu(z)
    return out.to(dtype)
|
|
|
|
|
|
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    """Single-pass forward kernel for gated LayerNorm / RMSNorm.

    Each program handles one (row, group) pair: program id 0 selects the
    row, program id 1 selects the group. N is the number of columns the
    program normalizes over; the launcher in this file passes the group
    size here. HAS_BIAS and HAS_Z are derived by the heuristics above from
    whether B and Z are None.

    Statistics are written at index ``group * M + row`` of Rstd (and of
    Mean when IS_RMS_NORM is false).
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        # Gate the input first: norm(x * silu(z)), silu(z) = z * sigmoid(z).
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Zero the padded lanes so they do not contribute to the variance.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        # Gate the output: norm(x) * silu(z).
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)
|
|
|
|
|
|
def _layer_norm_fwd(
|
|
x,
|
|
weight,
|
|
bias,
|
|
eps,
|
|
z=None,
|
|
out=None,
|
|
group_size=None,
|
|
norm_before_gate=True,
|
|
is_rms_norm=False,
|
|
):
|
|
M, N = x.shape
|
|
if group_size is None:
|
|
group_size = N
|
|
assert N % group_size == 0
|
|
ngroups = N // group_size
|
|
assert x.stride(-1) == 1
|
|
if z is not None:
|
|
assert z.stride(-1) == 1
|
|
assert z.shape == (M, N)
|
|
assert weight.shape == (N, )
|
|
assert weight.stride(-1) == 1
|
|
if bias is not None:
|
|
assert bias.stride(-1) == 1
|
|
assert bias.shape == (N, )
|
|
# allocate output
|
|
if out is not None:
|
|
assert out.shape == x.shape
|
|
else:
|
|
out = torch.empty_like(x)
|
|
assert out.stride(-1) == 1
|
|
mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
|
if not is_rms_norm else None)
|
|
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
|
# Less than 64KB per feature: enqueue fused kernel
|
|
MAX_FUSED_SIZE = 65536 // x.element_size()
|
|
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
|
if group_size > BLOCK_N:
|
|
raise RuntimeError(
|
|
"This layer norm doesn't support feature dim >= 64KB.")
|
|
# heuristics for number of warps
|
|
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
|
grid = (M, ngroups)
|
|
with torch.npu.device(x.device.index):
|
|
_layer_norm_fwd_1pass_kernel[grid](
|
|
x,
|
|
out,
|
|
weight,
|
|
bias,
|
|
z,
|
|
mean,
|
|
rstd,
|
|
x.stride(0),
|
|
out.stride(0),
|
|
z.stride(0) if z is not None else 0,
|
|
M,
|
|
group_size,
|
|
eps,
|
|
BLOCK_N=BLOCK_N,
|
|
NORM_BEFORE_GATE=norm_before_gate,
|
|
IS_RMS_NORM=is_rms_norm,
|
|
num_warps=num_warps,
|
|
)
|
|
return out, mean, rstd
|
|
|
|
|
|
class LayerNormFn(torch.autograd.Function):
    """Autograd wrapper around the fused gated LayerNorm/RMSNorm forward.

    Only ``forward`` is defined in this class and nothing is saved on
    ``ctx``, so no backward pass is provided here.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        # The mean / rstd statistics are not needed by this forward-only path.
        y, _, _ = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        # Restore the caller's original leading dimensions.
        return y.reshape(x_shape_og)
|
|
|
|
|
|
def layernorm_fn(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Apply the fused gated norm through ``LayerNormFn``.

    All arguments are passed positionally to ``LayerNormFn.apply`` in the
    order of ``LayerNormFn.forward``. If z is not None the result is
    norm(x) * silu(z) when norm_before_gate is True, else norm(x * silu(z)).
    """
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, is_rms_norm)
|
|
|
|
|
|
def rmsnorm_fn(x,
               weight,
               bias,
               z=None,
               eps=1e-6,
               group_size=None,
               norm_before_gate=True):
    """Apply the fused gated norm in RMSNorm mode.

    Same as calling ``LayerNormFn.apply`` with ``is_rms_norm`` fixed to
    True. If z is not None the result is norm(x) * silu(z) when
    norm_before_gate is True, else norm(x * silu(z)).
    """
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, True)
|
|
|
|
|
|
class LayerNorm(torch.nn.Module):
    """Gated LayerNorm module with learnable weight and bias."""

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(
            torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(
            torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        # Parameters are created uninitialized; fill them now.
        self.reset_parameters()

    def reset_parameters(self):
        """Set weight to ones and bias to zeros."""
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            group_size=self.group_size,
            eps=self.eps,
            norm_before_gate=self.norm_before_gate,
        )
|
|
|
|
|
|
class RMSNormGated(torch.nn.Module):
    """Gated RMSNorm module with a learnable weight and no bias."""

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(
            torch.empty(hidden_size, **factory_kwargs))
        # Registered as None so `self.bias` exists but holds no parameter.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        # The weight is created uninitialized; fill it now.
        self.reset_parameters()

    def reset_parameters(self):
        """Set weight to ones."""
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
|
|
|
|
|
|
@triton.jit
def fused_gdn_gating_kernel(
    g,
    A_log,
    a,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    """Compute ``g = -exp(A_log) * softplus(a + dt_bias)`` for a block of heads.

    Program ids 0, 1 and 2 index the batch, the sequence position and the
    head block. ``a`` and ``g`` are addressed with the flat offset
    ``(i_b * seq_len + i_s) * NUM_HEADS + head``, while ``A_log`` and
    ``dt_bias`` are addressed by head only. Softplus uses ``beta`` as its
    scale and falls back to the identity where ``beta * x > threshold``.
    """
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
    # Guard the tail block when NUM_HEADS is not a multiple of BLK_HEADS.
    mask = head_off < NUM_HEADS
    blk_A_log = tl.load(A_log + head_off, mask=mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=mask)
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
    softplus_x = tl.where(beta * x <= threshold,
                          (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
|
|
|
|
|
def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """Launch ``fused_gdn_gating_kernel`` over a 2-D ``a``.

    Args:
        A_log: Tensor read by the kernel once per head.
        a: Tensor of shape (batch, num_heads).
        dt_bias: Tensor read by the kernel once per head.
        beta: Softplus scale passed to the kernel.
        threshold: Softplus linear-fallback threshold passed to the kernel.

    Returns:
        A float32 tensor ``g`` created with ``torch.empty_like(a)`` and
        filled by the kernel.
    """
    batch, num_heads = a.shape
    # `a` carries no sequence axis here, so the kernel sees one position.
    seq_len = 1
    # Heads per program; used for both the grid size and BLK_HEADS.
    blk_heads = 8
    grid = (batch, seq_len, triton.cdiv(num_heads, blk_heads))
    g = torch.empty_like(a, dtype=torch.float32)
    fused_gdn_gating_kernel[grid](g,
                                  A_log,
                                  a,
                                  dt_bias,
                                  seq_len,
                                  num_heads,
                                  beta,
                                  threshold,
                                  blk_heads,
                                  num_warps=1)
    return g
|